1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-08 20:33:40 +01:00

Working on adding TransA and TransB to DGEMM. #32

This commit is contained in:
v1j4y 2021-09-27 18:16:51 +02:00
parent a52d6683cc
commit ff5e7882d0

View File

@ -95,11 +95,36 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
real*8 , intent(in) :: B(k,n)
integer*8 , intent(in) :: ldc
real*8 , intent(out) :: C(m,n)
real*8, allocatable :: AT(:,:), BT(:,:), CT(:,:)
integer*8 :: i,j,l
integer*8 :: i,j,l, LDA_2, LDB_2
info = QMCKL_SUCCESS
if (TransA) then
allocate(AT(k,m))
do i = 1, m
do j = 1, k
AT(j,i) = A(i,j)
end do
end do
LDA_2 = M
else
LDA_2 = LDA
endif
if (TransB) then
allocate(BT(n,k))
do i = 1, k
do j = 1, n
BT(j,i) = B(i,j)
end do
end do
LDB_2 = K
else
LDB_2 = LDB
endif
if (context == QMCKL_NULL_CONTEXT) then
info = QMCKL_INVALID_CONTEXT
return
@ -120,12 +145,12 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
return
endif
if (LDA .ne. m) then
if (LDA_2 .ne. m) then
info = QMCKL_INVALID_ARG_9
return
endif
if (LDB .ne. k) then
if (LDB_2 .ne. k) then
info = QMCKL_INVALID_ARG_10
return
endif
@ -135,7 +160,13 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
return
endif
C = matmul(A,B)
if (TransA) then
C = matmul(AT,B)
else if (TransB) then
C = matmul(A,BT)
else
C = matmul(A,B)
endif
end function qmckl_dgemm_f
#+end_src
@ -153,8 +184,8 @@ end function qmckl_dgemm_f
implicit none
integer (c_int64_t) , intent(in) , value :: context
logical , intent(in) , value :: TransA
logical , intent(in) , value :: TransB
logical*8 , intent(in) , value :: TransA
logical*8 , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k
@ -188,8 +219,8 @@ end function qmckl_dgemm_f
implicit none
integer (c_int64_t) , intent(in) , value :: context
logical , intent(in) , value :: TransA
logical , intent(in) , value :: TransB
logical*8 , intent(in) , value :: TransA
logical*8 , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k
@ -216,7 +247,7 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:)
integer*8 :: m, n, k, LDA, LDB, LDC
integer*8 :: i,j,l
logical :: TransA, TransB
logical*8 :: TransA, TransB
double precision :: x, alpha, beta
TransA = .False.