diff --git a/org/qmckl_blas.org b/org/qmckl_blas.org index b05cac5..4bb7384 100644 --- a/org/qmckl_blas.org +++ b/org/qmckl_blas.org @@ -95,11 +95,36 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA, real*8 , intent(in) :: B(k,n) integer*8 , intent(in) :: ldc real*8 , intent(out) :: C(m,n) + real*8, allocatable :: AT(:,:), BT(:,:), CT(:,:) - integer*8 :: i,j,l + integer*8 :: i,j,l, LDA_2, LDB_2 info = QMCKL_SUCCESS + if (TransA) then + allocate(AT(k,m)) + do i = 1, m + do j = 1, k + AT(j,i) = A(i,j) + end do + end do + LDA_2 = M + else + LDA_2 = LDA + endif + + if (TransB) then + allocate(BT(n,k)) + do i = 1, k + do j = 1, n + BT(j,i) = B(i,j) + end do + end do + LDB_2 = K + else + LDB_2 = LDB + endif + if (context == QMCKL_NULL_CONTEXT) then info = QMCKL_INVALID_CONTEXT return @@ -120,12 +145,12 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA, return endif - if (LDA .ne. m) then + if (LDA_2 .ne. m) then info = QMCKL_INVALID_ARG_9 return endif - if (LDB .ne. k) then + if (LDB_2 .ne. k) then info = QMCKL_INVALID_ARG_10 return endif @@ -135,7 +160,13 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA, return endif - C = matmul(A,B) + if (TransA) then + C = matmul(AT,B) + else if (TransB) then + C = matmul(A,BT) + else + C = matmul(A,B) + endif end function qmckl_dgemm_f #+end_src @@ -153,8 +184,8 @@ end function qmckl_dgemm_f implicit none integer (c_int64_t) , intent(in) , value :: context - logical , intent(in) , value :: TransA - logical , intent(in) , value :: TransB + logical*8 , intent(in) , value :: TransA + logical*8 , intent(in) , value :: TransB integer (c_int64_t) , intent(in) , value :: m integer (c_int64_t) , intent(in) , value :: n integer (c_int64_t) , intent(in) , value :: k @@ -188,8 +219,8 @@ end function qmckl_dgemm_f implicit none integer (c_int64_t) , intent(in) , value :: context - logical , intent(in) , value :: TransA - logical , intent(in) , value :: TransB + logical*8 , intent(in) , value :: TransA + logical*8 , intent(in) , value :: TransB integer (c_int64_t) , intent(in) , value :: m integer (c_int64_t) , intent(in) , value :: n integer (c_int64_t) , intent(in) , value :: k @@ -216,7 +247,7 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C) double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:) integer*8 :: m, n, k, LDA, LDB, LDC integer*8 :: i,j,l - logical :: TransA, TransB + logical*8 :: TransA, TransB double precision :: x, alpha, beta TransA = .False. @@ -271,7 +302,7 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C) deallocate(A,B,C,D) end function test_qmckl_dgemm #+end_src - + #+begin_src c :comments link :tangle (eval c_test) qmckl_exit_code test_qmckl_dgemm(qmckl_context context); assert(QMCKL_SUCCESS == test_qmckl_dgemm(context));