mirror of
https://github.com/TREX-CoE/qmckl.git
synced 2025-01-09 12:44:12 +01:00
Working on adding TransA and TransB to DGEMM. #32
This commit is contained in:
parent
a52d6683cc
commit
ff5e7882d0
@ -95,11 +95,36 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
|
|||||||
real*8 , intent(in) :: B(k,n)
|
real*8 , intent(in) :: B(k,n)
|
||||||
integer*8 , intent(in) :: ldc
|
integer*8 , intent(in) :: ldc
|
||||||
real*8 , intent(out) :: C(m,n)
|
real*8 , intent(out) :: C(m,n)
|
||||||
|
real*8, allocatable :: AT(:,:), BT(:,:), CT(:,:)
|
||||||
|
|
||||||
integer*8 :: i,j,l
|
integer*8 :: i,j,l, LDA_2, LDB_2
|
||||||
|
|
||||||
info = QMCKL_SUCCESS
|
info = QMCKL_SUCCESS
|
||||||
|
|
||||||
|
if (TransA) then
|
||||||
|
allocate(AT(k,m))
|
||||||
|
do i = 1, m
|
||||||
|
do j = 1, k
|
||||||
|
AT(j,i) = A(i,j)
|
||||||
|
end do
|
||||||
|
end do
|
||||||
|
LDA_2 = M
|
||||||
|
else
|
||||||
|
LDA_2 = LDA
|
||||||
|
endif
|
||||||
|
|
||||||
|
if (TransB) then
|
||||||
|
allocate(BT(n,k))
|
||||||
|
do i = 1, k
|
||||||
|
do j = 1, n
|
||||||
|
BT(j,i) = B(i,j)
|
||||||
|
end do
|
||||||
|
end do
|
||||||
|
LDB_2 = K
|
||||||
|
else
|
||||||
|
LDB_2 = LDB
|
||||||
|
endif
|
||||||
|
|
||||||
if (context == QMCKL_NULL_CONTEXT) then
|
if (context == QMCKL_NULL_CONTEXT) then
|
||||||
info = QMCKL_INVALID_CONTEXT
|
info = QMCKL_INVALID_CONTEXT
|
||||||
return
|
return
|
||||||
@ -120,12 +145,12 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
|
|||||||
return
|
return
|
||||||
endif
|
endif
|
||||||
|
|
||||||
if (LDA .ne. m) then
|
if (LDA_2 .ne. m) then
|
||||||
info = QMCKL_INVALID_ARG_9
|
info = QMCKL_INVALID_ARG_9
|
||||||
return
|
return
|
||||||
endif
|
endif
|
||||||
|
|
||||||
if (LDB .ne. k) then
|
if (LDB_2 .ne. k) then
|
||||||
info = QMCKL_INVALID_ARG_10
|
info = QMCKL_INVALID_ARG_10
|
||||||
return
|
return
|
||||||
endif
|
endif
|
||||||
@ -135,7 +160,13 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
|
|||||||
return
|
return
|
||||||
endif
|
endif
|
||||||
|
|
||||||
C = matmul(A,B)
|
if (TransA) then
|
||||||
|
C = matmul(AT,B)
|
||||||
|
else if (TransB) then
|
||||||
|
C = matmul(A,BT)
|
||||||
|
else
|
||||||
|
C = matmul(A,B)
|
||||||
|
endif
|
||||||
end function qmckl_dgemm_f
|
end function qmckl_dgemm_f
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
@ -153,8 +184,8 @@ end function qmckl_dgemm_f
|
|||||||
implicit none
|
implicit none
|
||||||
|
|
||||||
integer (c_int64_t) , intent(in) , value :: context
|
integer (c_int64_t) , intent(in) , value :: context
|
||||||
logical , intent(in) , value :: TransA
|
logical*8 , intent(in) , value :: TransA
|
||||||
logical , intent(in) , value :: TransB
|
logical*8 , intent(in) , value :: TransB
|
||||||
integer (c_int64_t) , intent(in) , value :: m
|
integer (c_int64_t) , intent(in) , value :: m
|
||||||
integer (c_int64_t) , intent(in) , value :: n
|
integer (c_int64_t) , intent(in) , value :: n
|
||||||
integer (c_int64_t) , intent(in) , value :: k
|
integer (c_int64_t) , intent(in) , value :: k
|
||||||
@ -188,8 +219,8 @@ end function qmckl_dgemm_f
|
|||||||
implicit none
|
implicit none
|
||||||
|
|
||||||
integer (c_int64_t) , intent(in) , value :: context
|
integer (c_int64_t) , intent(in) , value :: context
|
||||||
logical , intent(in) , value :: TransA
|
logical*8 , intent(in) , value :: TransA
|
||||||
logical , intent(in) , value :: TransB
|
logical*8 , intent(in) , value :: TransB
|
||||||
integer (c_int64_t) , intent(in) , value :: m
|
integer (c_int64_t) , intent(in) , value :: m
|
||||||
integer (c_int64_t) , intent(in) , value :: n
|
integer (c_int64_t) , intent(in) , value :: n
|
||||||
integer (c_int64_t) , intent(in) , value :: k
|
integer (c_int64_t) , intent(in) , value :: k
|
||||||
@ -216,7 +247,7 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
|
|||||||
double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:)
|
double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:)
|
||||||
integer*8 :: m, n, k, LDA, LDB, LDC
|
integer*8 :: m, n, k, LDA, LDB, LDC
|
||||||
integer*8 :: i,j,l
|
integer*8 :: i,j,l
|
||||||
logical :: TransA, TransB
|
logical*8 :: TransA, TransB
|
||||||
double precision :: x, alpha, beta
|
double precision :: x, alpha, beta
|
||||||
|
|
||||||
TransA = .False.
|
TransA = .False.
|
||||||
|
Loading…
Reference in New Issue
Block a user