1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-07-03 01:46:12 +02:00

Added transa/transb in distances

This commit is contained in:
Anthony Scemama 2020-10-28 20:15:36 +01:00
parent 8147ad22a7
commit 3aaaabfad3
2 changed files with 164 additions and 32 deletions

View File

@ -14,4 +14,5 @@ context.
* Performance info * Performance info
* Benchmark interpolation of basis functions * Benchmark interpolation of basis functions
* Complex numbers * Complex numbers
* Adjustable number for derivatives (1,2,3)

View File

@ -43,12 +43,14 @@ MunitResult test_qmckl_distance() {
Computes the matrix of the squared distances between all pairs of Computes the matrix of the squared distances between all pairs of
points in two sets, one point within each set: points in two sets, one point within each set:
\[ \[
C_{ij} = \sum_{k=1}^3 (A_{i,k}-B_{j,k})^2 C_{ij} = \sum_{k=1}^3 (A_{k,i}-B_{k,j})^2
\] \]
*** Arguments *** Arguments
| =context= | input | Global state | | =context= | input | Global state |
| =transa= | input | Array =A= is =N=: Normal, =T=: Transposed |
| =transb= | input | Array =B= is =N=: Normal, =T=: Transposed |
| =m= | input | Number of points in the first set | | =m= | input | Number of points in the first set |
| =n= | input | Number of points in the second set | | =n= | input | Number of points in the second set |
| =A(lda,3)= | input | Array containing the $m \times 3$ matrix $A$ | | =A(lda,3)= | input | Array containing the $m \times 3$ matrix $A$ |
@ -63,16 +65,24 @@ MunitResult test_qmckl_distance() {
- =context= is not 0 - =context= is not 0
- =m= > 0 - =m= > 0
- =n= > 0 - =n= > 0
- =lda= >= m - =lda= >= 3 if =transa= is =N=
- =ldb= >= n - =lda= >= m if =transa= is =T=
- =ldc= >= m - =ldb= >= 3 if =transb= is =N=
- =ldb= >= n if =transb= is =T=
- =ldc= >= m
- =A= is allocated with at least $3 \times m \times 8$ bytes - =A= is allocated with at least $3 \times m \times 8$ bytes
- =B= is allocated with at least $3 \times n \times 8$ bytes - =B= is allocated with at least $3 \times n \times 8$ bytes
- =C= is allocated with at least $m \times n \times 8$ bytes - =C= is allocated with at least $m \times n \times 8$ bytes
*** Performance
This function might be more efficient when =A= and =B= are
transposed.
*** Header *** Header
#+BEGIN_SRC C :comments link :tangle qmckl_distance.h #+BEGIN_SRC C :comments link :tangle qmckl_distance.h
qmckl_exit_code qmckl_distance_sq(qmckl_context context, qmckl_exit_code qmckl_distance_sq(qmckl_context context,
char transa, char transb,
int64_t m, int64_t n, int64_t m, int64_t n,
double *A, int64_t lda, double *A, int64_t lda,
double *B, int64_t ldb, double *B, int64_t ldb,
@ -81,19 +91,21 @@ qmckl_exit_code qmckl_distance_sq(qmckl_context context,
*** Source *** Source
#+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90 #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90
integer function qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) result(info) integer function qmckl_distance_sq_f(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) result(info)
implicit none implicit none
integer*8 , intent(in) :: context integer*8 , intent(in) :: context
character , intent(in) :: transa, transb
integer*8 , intent(in) :: m, n integer*8 , intent(in) :: m, n
integer*8 , intent(in) :: lda integer*8 , intent(in) :: lda
real*8 , intent(in) :: A(lda,3) real*8 , intent(in) :: A(lda,*)
integer*8 , intent(in) :: ldb integer*8 , intent(in) :: ldb
real*8 , intent(in) :: B(ldb,3) real*8 , intent(in) :: B(ldb,*)
integer*8 , intent(in) :: ldc integer*8 , intent(in) :: ldc
real*8 , intent(out) :: C(ldc,n) real*8 , intent(out) :: C(ldc,*)
integer*8 :: i,j integer*8 :: i,j
real*8 :: x, y, z real*8 :: x, y, z
integer :: transab
info = 0 info = 0
@ -112,40 +124,107 @@ integer function qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) resu
return return
endif endif
if (LDA < m) then if (transa == 'N' .or. transa == 'n') then
info = -4 transab = 0
return else if (transa == 'T' .or. transa == 't') then
transab = 1
else
transab = -100
endif endif
if (LDB < n) then if (transb == 'N' .or. transb == 'n') then
continue
else if (transb == 'T' .or. transb == 't') then
transab = transab + 2
else
transab = -100
endif
if (transab < 0) then
info = -4
return
endif
if (iand(transab,1) == 0 .and. LDA < 3) then
info = -5 info = -5
return return
endif endif
if (LDC < m) then if (iand(transab,1) == 1 .and. LDA < m) then
info = -6 info = -6
return return
endif endif
do j=1,n if (iand(transab,2) == 0 .and. LDB < 3) then
do i=1,m info = -6
x = A(i,1) - B(j,1) return
y = A(i,2) - B(j,2) endif
z = A(i,3) - B(j,3)
C(i,j) = x*x + y*y + z*z
end do
end do
if (iand(transab,2) == 2 .and. LDB < n) then
info = -7
return
endif
select case (transab)
case(0)
do j=1,n
do i=1,m
x = A(1,i) - B(1,j)
y = A(2,i) - B(2,j)
z = A(3,i) - B(3,j)
C(i,j) = x*x + y*y + z*z
end do
end do
case(1)
do j=1,n
do i=1,m
x = A(i,1) - B(1,j)
y = A(i,2) - B(2,j)
z = A(i,3) - B(3,j)
C(i,j) = x*x + y*y + z*z
end do
end do
case(2)
do j=1,n
do i=1,m
x = A(1,i) - B(j,1)
y = A(2,i) - B(j,2)
z = A(3,i) - B(j,3)
C(i,j) = x*x + y*y + z*z
end do
end do
case(3)
do j=1,n
do i=1,m
x = A(i,1) - B(j,1)
y = A(i,2) - B(j,2)
z = A(i,3) - B(j,3)
C(i,j) = x*x + y*y + z*z
end do
end do
end select
end function qmckl_distance_sq_f end function qmckl_distance_sq_f
#+END_SRC #+END_SRC
*** C interface :noexport: *** C interface :noexport:
#+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90 #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90
integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) & integer(c_int32_t) function qmckl_distance_sq(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) &
bind(C) result(info) bind(C) result(info)
use, intrinsic :: iso_c_binding use, intrinsic :: iso_c_binding
implicit none implicit none
integer (c_int64_t) , intent(in) , value :: context integer (c_int64_t) , intent(in) , value :: context
character (c_char) , intent(in) , value :: transa, transb
integer (c_int64_t) , intent(in) , value :: m, n integer (c_int64_t) , intent(in) , value :: m, n
integer (c_int64_t) , intent(in) , value :: lda integer (c_int64_t) , intent(in) , value :: lda
real (c_double) , intent(in) :: A(lda,3) real (c_double) , intent(in) :: A(lda,3)
@ -155,17 +234,18 @@ integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C,
real (c_double) , intent(out) :: C(ldc,n) real (c_double) , intent(out) :: C(ldc,n)
integer, external :: qmckl_distance_sq_f integer, external :: qmckl_distance_sq_f
info = qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) info = qmckl_distance_sq_f(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC)
end function qmckl_distance_sq end function qmckl_distance_sq
#+END_SRC #+END_SRC
#+BEGIN_SRC f90 :comments link :tangle qmckl_distance.fh #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.fh
interface interface
integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) & integer(c_int32_t) function qmckl_distance_sq(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) &
bind(C) bind(C)
use, intrinsic :: iso_c_binding use, intrinsic :: iso_c_binding
implicit none implicit none
integer (c_int64_t) , intent(in) , value :: context integer (c_int64_t) , intent(in) , value :: context
character (c_char) , intent(in) , value :: transa, transb
integer (c_int64_t) , intent(in) , value :: m, n integer (c_int64_t) , intent(in) , value :: m, n
integer (c_int64_t) , intent(in) , value :: lda integer (c_int64_t) , intent(in) , value :: lda
integer (c_int64_t) , intent(in) , value :: ldb integer (c_int64_t) , intent(in) , value :: ldb
@ -192,22 +272,30 @@ integer(c_int32_t) function test_qmckl_distance_sq(context) bind(C)
m = 5 m = 5
n = 6 n = 6
LDA = 6 LDA = m
LDB = 10 LDB = n
LDC = 5 LDC = 5
allocate( A(LDA,3), B(LDB,3), C(LDC,n) ) allocate( A(LDA,m), B(LDB,n), C(LDC,n) )
do j=1,3 do j=1,m
do i=1,m do i=1,m
A(i,j) = -10.d0 + dble(i+j) A(i,j) = -10.d0 + dble(i+j)
end do end do
end do
do j=1,n
do i=1,n do i=1,n
B(i,j) = -1.d0 + dble(i*j) B(i,j) = -1.d0 + dble(i*j)
end do end do
end do end do
test_qmckl_distance_sq = qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) test_qmckl_distance_sq = qmckl_distance_sq(context, 'X', 't', m, n, A, LDA, B, LDB, C, LDC)
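if (test_qmckl_distance_sq == 0) then
! An invalid transa must be rejected: report failure instead of returning 0 (success)
test_qmckl_distance_sq = -1
return
end if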
test_qmckl_distance_sq = qmckl_distance_sq(context, 't', 'X', m, n, A, LDA, B, LDB, C, LDC)
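if (test_qmckl_distance_sq == 0) then
! An invalid transb must be rejected: report failure instead of returning 0 (success)
test_qmckl_distance_sq = -1
return
end if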
test_qmckl_distance_sq = qmckl_distance_sq(context, 'T', 't', m, n, A, LDA, B, LDB, C, LDC)
if (test_qmckl_distance_sq /= 0) return if (test_qmckl_distance_sq /= 0) return
test_qmckl_distance_sq = -1 test_qmckl_distance_sq = -1
@ -220,6 +308,49 @@ integer(c_int32_t) function test_qmckl_distance_sq(context) bind(C)
if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return
end do end do
end do end do
test_qmckl_distance_sq = qmckl_distance_sq(context, 'n', 'T', m, n, A, LDA, B, LDB, C, LDC)
if (test_qmckl_distance_sq /= 0) return
test_qmckl_distance_sq = -1
do j=1,n
do i=1,m
x = (A(1,i)-B(j,1))**2 + &
(A(2,i)-B(j,2))**2 + &
(A(3,i)-B(j,3))**2
if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return
end do
end do
test_qmckl_distance_sq = qmckl_distance_sq(context, 'T', 'n', m, n, A, LDA, B, LDB, C, LDC)
if (test_qmckl_distance_sq /= 0) return
test_qmckl_distance_sq = -1
do j=1,n
do i=1,m
x = (A(i,1)-B(1,j))**2 + &
(A(i,2)-B(2,j))**2 + &
(A(i,3)-B(3,j))**2
if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return
end do
end do
test_qmckl_distance_sq = qmckl_distance_sq(context, 'n', 'N', m, n, A, LDA, B, LDB, C, LDC)
if (test_qmckl_distance_sq /= 0) return
test_qmckl_distance_sq = -1
do j=1,n
do i=1,m
x = (A(1,i)-B(1,j))**2 + &
(A(2,i)-B(2,j))**2 + &
(A(3,i)-B(3,j))**2
if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return
end do
end do
test_qmckl_distance_sq = 0 test_qmckl_distance_sq = 0
deallocate(A,B,C) deallocate(A,B,C)
@ -230,14 +361,14 @@ end function test_qmckl_distance_sq
int test_qmckl_distance_sq(qmckl_context context); int test_qmckl_distance_sq(qmckl_context context);
munit_assert_int(0, ==, test_qmckl_distance_sq(context)); munit_assert_int(0, ==, test_qmckl_distance_sq(context));
#+END_SRC #+END_SRC
* End of files * End of files :noexport:
*** Header :noexport: *** Header
#+BEGIN_SRC C :comments link :tangle qmckl_distance.h #+BEGIN_SRC C :comments link :tangle qmckl_distance.h
#endif #endif
#+END_SRC #+END_SRC
*** Test :noexport: *** Test
#+BEGIN_SRC C :comments link :tangle test_qmckl_distance.c #+BEGIN_SRC C :comments link :tangle test_qmckl_distance.c
if (qmckl_context_destroy(context) != QMCKL_SUCCESS) if (qmckl_context_destroy(context) != QMCKL_SUCCESS)
return QMCKL_FAILURE; return QMCKL_FAILURE;