1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-03 18:16:28 +01:00

Tests pass for qmckl_dgemm. #32.

This commit is contained in:
v1j4y 2021-09-15 16:21:42 +02:00
parent cf3550b6b7
commit eaede28a73

View File

@ -90,13 +90,13 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
integer*8 , intent(in) :: m, n, k integer*8 , intent(in) :: m, n, k
real*8 , intent(in) :: alpha, beta real*8 , intent(in) :: alpha, beta
integer*8 , intent(in) :: lda integer*8 , intent(in) :: lda
real*8 , intent(in) :: A(m,n) real*8 , intent(in) :: A(m,k)
integer*8 , intent(in) :: ldb integer*8 , intent(in) :: ldb
real*8 , intent(in) :: B(n,k) real*8 , intent(in) :: B(k,n)
integer*8 , intent(in) :: ldc integer*8 , intent(in) :: ldc
real*8 , intent(out) :: C(m,n) real*8 , intent(out) :: C(m,n)
integer*8 :: i,j integer*8 :: i,j,l
info = QMCKL_SUCCESS info = QMCKL_SUCCESS
@ -120,17 +120,17 @@ integer function qmckl_dgemm_f(context, TransA, TransB, m, n, k, alpha, A, LDA,
return return
endif endif
if (LDA < m) then if (LDA .ne. m) then
info = QMCKL_INVALID_ARG_9 info = QMCKL_INVALID_ARG_9
return return
endif endif
if (LDB < n) then if (LDB .ne. k) then
info = QMCKL_INVALID_ARG_10 info = QMCKL_INVALID_ARG_10
return return
endif endif
if (LDB < n) then if (LDC .ne. m) then
info = QMCKL_INVALID_ARG_13 info = QMCKL_INVALID_ARG_13
return return
endif endif
@ -153,8 +153,8 @@ end function qmckl_dgemm_f
implicit none implicit none
integer (c_int64_t) , intent(in) , value :: context integer (c_int64_t) , intent(in) , value :: context
logical*8 , intent(in) , value :: TransA logical , intent(in) , value :: TransA
logical*8 , intent(in) , value :: TransB logical , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k integer (c_int64_t) , intent(in) , value :: k
@ -188,17 +188,17 @@ end function qmckl_dgemm_f
implicit none implicit none
integer (c_int64_t) , intent(in) , value :: context integer (c_int64_t) , intent(in) , value :: context
logical*8 , intent(in) , value :: TransA logical , intent(in) , value :: TransA
logical*8 , intent(in) , value :: TransB logical , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k integer (c_int64_t) , intent(in) , value :: k
real (c_double ) , intent(in) :: alpha real (c_double ) , intent(in) , value :: alpha
integer (c_int64_t) , intent(in) , value :: lda integer (c_int64_t) , intent(in) , value :: lda
real (c_double ) , intent(in) :: A(lda,*) real (c_double ) , intent(in) :: A(lda,*)
integer (c_int64_t) , intent(in) , value :: ldb integer (c_int64_t) , intent(in) , value :: ldb
real (c_double ) , intent(in) :: B(ldb,*) real (c_double ) , intent(in) :: B(ldb,*)
real (c_double ) , intent(in) :: beta real (c_double ) , intent(in) , value :: beta
integer (c_int64_t) , intent(in) , value :: ldc integer (c_int64_t) , intent(in) , value :: ldc
real (c_double ) , intent(out) :: C(ldc,*) real (c_double ) , intent(out) :: C(ldc,*)
@ -216,11 +216,11 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:) double precision, allocatable :: A(:,:), B(:,:), C(:,:), D(:,:)
integer*8 :: m, n, k, LDA, LDB, LDC integer*8 :: m, n, k, LDA, LDB, LDC
integer*8 :: i,j,l integer*8 :: i,j,l
logical*8 :: TransA, TransB logical :: TransA, TransB
double precision :: x, alpha, beta double precision :: x, alpha, beta
TransA = .False. TransA = .False.
TransB = .False. TransB = .False.
m = 5_8 m = 5_8
k = 4_8 k = 4_8
n = 6_8 n = 6_8
@ -234,14 +234,16 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
B = 0.d0 B = 0.d0
C = 0.d0 C = 0.d0
D = 0.d0 D = 0.d0
do j=1,m alpha = 1.0d0
do i=1,k beta = 0.0d0
do j=1,k
do i=1,m
A(i,j) = -10.d0 + dble(i+j) A(i,j) = -10.d0 + dble(i+j)
end do end do
end do end do
do j=1,k do j=1,n
do i=1,n do i=1,k
B(i,j) = -10.d0 + dble(i+j) B(i,j) = -10.d0 + dble(i+j)
end do end do
end do end do
@ -253,21 +255,20 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
test_qmckl_dgemm = QMCKL_FAILURE test_qmckl_dgemm = QMCKL_FAILURE
x = 0.d0 x = 0.d0
do j=1,m do j=1,n
do i=l,n do i=1,m
do l=1,k do l=1,k
D(i,j) = D(i,j) + A(i,k)*B(k,j) D(i,j) = D(i,j) + A(i,l)*B(l,j)
end do end do
x = x + (D(i,j) - C(i,j))**2 x = x + (D(i,j) - C(i,j))**2
end do end do
end do end do
print *,"DABS(X)= ",dabs(x)
if (dabs(x) <= 1.d-15) then if (dabs(x) <= 1.d-15) then
test_qmckl_dgemm = QMCKL_SUCCESS test_qmckl_dgemm = QMCKL_SUCCESS
endif endif
deallocate(A,B) deallocate(A,B,C,D)
end function test_qmckl_dgemm end function test_qmckl_dgemm
#+end_src #+end_src