diff --git a/TODO.org b/TODO.org index 250dd79..1fad144 100644 --- a/TODO.org +++ b/TODO.org @@ -14,4 +14,5 @@ context. * Performance info * Benchmark interpolation of basis functions * Complex numbers +* Adjustable number for derivatives (1,2,3) diff --git a/src/qmckl_distance.org b/src/qmckl_distance.org index 5bdbb53..6795b98 100644 --- a/src/qmckl_distance.org +++ b/src/qmckl_distance.org @@ -43,12 +43,14 @@ MunitResult test_qmckl_distance() { Computes the matrix of the squared distances between all pairs of points in two sets, one point within each set: \[ - C_{ij} = \sum_{k=1}^3 (A_{i,k}-B_{j,k})^2 + C_{ij} = \sum_{k=1}^3 (A_{k,i}-B_{k,j})^2 \] *** Arguments | =context= | input | Global state | + | =transa= | input | Array =A= is =N=: Normal, =T=: Transposed | + | =transb= | input | Array =B= is =N=: Normal, =T=: Transposed | | =m= | input | Number of points in the first set | | =n= | input | Number of points in the second set | | =A(lda,3)= | input | Array containing the $m \times 3$ matrix $A$ | @@ -63,16 +65,24 @@ MunitResult test_qmckl_distance() { - =context= is not 0 - =m= > 0 - =n= > 0 - - =lda= >= m - - =ldb= >= n - - =ldc= >= m + - =lda= >= 3 if =transa= is =N= + - =lda= >= m if =transa= is =T= + - =ldb= >= 3 if =transb= is =N= + - =ldb= >= n if =transb= is =T= + - =ldc= >= m if =transa= is = - =A= is allocated with at least $3 \times m \times 8$ bytes - =B= is allocated with at least $3 \times n \times 8$ bytes - =C= is allocated with at least $m \times n \times 8$ bytes +*** Performance + + This function might be more efficient when =A= and =B= are + transposed. + *** Header #+BEGIN_SRC C :comments link :tangle qmckl_distance.h qmckl_exit_code qmckl_distance_sq(qmckl_context context, + char transa, char transb, int64_t m, int64_t n, double *A, int64_t lda, double *B, int64_t ldb, @@ -81,19 +91,21 @@ qmckl_exit_code qmckl_distance_sq(qmckl_context context, *** Source #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90 -integer function qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) result(info) +integer function qmckl_distance_sq_f(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) result(info) implicit none integer*8 , intent(in) :: context + character , intent(in) :: transa, transb integer*8 , intent(in) :: m, n integer*8 , intent(in) :: lda - real*8 , intent(in) :: A(lda,3) + real*8 , intent(in) :: A(lda,*) integer*8 , intent(in) :: ldb - real*8 , intent(in) :: B(ldb,3) + real*8 , intent(in) :: B(ldb,*) integer*8 , intent(in) :: ldc - real*8 , intent(out) :: C(ldc,n) + real*8 , intent(out) :: C(ldc,*) integer*8 :: i,j real*8 :: x, y, z + integer :: transab info = 0 @@ -112,40 +124,107 @@ integer function qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) resu return endif - if (LDA < m) then - info = -4 - return + if (transa == 'N' .or. transa == 'n') then + transab = 0 + else if (transa == 'T' .or. transa == 't') then + transab = 1 + else + transab = -100 endif - if (LDB < n) then + if (transb == 'N' .or. transb == 'n') then + continue + else if (transa == 'T' .or. transa == 't') then + transab = transab + 2 + else + transab = -100 + endif + + if (transab < 0) then + info = -4 + return + endif + + if (iand(transab,1) == 0 .and. LDA < 3) then info = -5 return endif - if (LDC < m) then + if (iand(transab,1) == 1 .and. LDA < m) then info = -6 return endif - do j=1,n - do i=1,m - x = A(i,1) - B(j,1) - y = A(i,2) - B(j,2) - z = A(i,3) - B(j,3) - C(i,j) = x*x + y*y + z*z - end do - end do + if (iand(transab,2) == 0 .and. LDA < 3) then + info = -6 + return + endif + if (iand(transab,2) == 2 .and. LDA < m) then + info = -7 + return + endif + + + select case (transab) + + case(0) + + do j=1,n + do i=1,m + x = A(1,i) - B(1,j) + y = A(2,i) - B(2,j) + z = A(3,i) - B(3,j) + C(i,j) = x*x + y*y + z*z + end do + end do + + case(1) + + do j=1,n + do i=1,m + x = A(i,1) - B(1,j) + y = A(i,2) - B(2,j) + z = A(i,3) - B(3,j) + C(i,j) = x*x + y*y + z*z + end do + end do + + case(2) + + do j=1,n + do i=1,m + x = A(1,i) - B(j,1) + y = A(2,i) - B(j,2) + z = A(3,i) - B(j,3) + C(i,j) = x*x + y*y + z*z + end do + end do + + case(3) + + do j=1,n + do i=1,m + x = A(i,1) - B(j,1) + y = A(i,2) - B(j,2) + z = A(i,3) - B(j,3) + C(i,j) = x*x + y*y + z*z + end do + end do + + end select + end function qmckl_distance_sq_f #+END_SRC *** C interface :noexport: #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.f90 -integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) & +integer(c_int32_t) function qmckl_distance_sq(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) & bind(C) result(info) use, intrinsic :: iso_c_binding implicit none integer (c_int64_t) , intent(in) , value :: context + character (c_char) , intent(in) , value :: transa, transb integer (c_int64_t) , intent(in) , value :: m, n integer (c_int64_t) , intent(in) , value :: lda real (c_double) , intent(in) :: A(lda,3) @@ -155,17 +234,18 @@ integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, real (c_double) , intent(out) :: C(ldc,n) integer, external :: qmckl_distance_sq_f - info = qmckl_distance_sq_f(context, m, n, A, LDA, B, LDB, C, LDC) + info = qmckl_distance_sq_f(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) end function qmckl_distance_sq #+END_SRC #+BEGIN_SRC f90 :comments link :tangle qmckl_distance.fh interface - integer(c_int32_t) function qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) & + integer(c_int32_t) function qmckl_distance_sq(context, transa, transb, m, n, A, LDA, B, LDB, C, LDC) & bind(C) use, intrinsic :: iso_c_binding implicit none integer (c_int64_t) , intent(in) , value :: context + character (c_char) , intent(in) , value :: transa, transb integer (c_int64_t) , intent(in) , value :: m, n integer (c_int64_t) , intent(in) , value :: lda integer (c_int64_t) , intent(in) , value :: ldb @@ -192,22 +272,30 @@ integer(c_int32_t) function test_qmckl_distance_sq(context) bind(C) m = 5 n = 6 - LDA = 6 - LDB = 10 + LDA = m + LDB = n LDC = 5 - allocate( A(LDA,3), B(LDB,3), C(LDC,n) ) + allocate( A(LDA,m), B(LDB,n), C(LDC,n) ) - do j=1,3 + do j=1,m do i=1,m A(i,j) = -10.d0 + dble(i+j) end do + end do + do j=1,n do i=1,n B(i,j) = -1.d0 + dble(i*j) end do end do - test_qmckl_distance_sq = qmckl_distance_sq(context, m, n, A, LDA, B, LDB, C, LDC) + test_qmckl_distance_sq = qmckl_distance_sq(context, 'X', 't', m, n, A, LDA, B, LDB, C, LDC) + if (test_qmckl_distance_sq == 0) return + + test_qmckl_distance_sq = qmckl_distance_sq(context, 't', 'X', m, n, A, LDA, B, LDB, C, LDC) + if (test_qmckl_distance_sq == 0) return + + test_qmckl_distance_sq = qmckl_distance_sq(context, 'T', 't', m, n, A, LDA, B, LDB, C, LDC) if (test_qmckl_distance_sq /= 0) return test_qmckl_distance_sq = -1 @@ -220,6 +308,49 @@ integer(c_int32_t) function test_qmckl_distance_sq(context) bind(C) if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return end do end do + + test_qmckl_distance_sq = qmckl_distance_sq(context, 'n', 'T', m, n, A, LDA, B, LDB, C, LDC) + if (test_qmckl_distance_sq /= 0) return + + test_qmckl_distance_sq = -1 + + do j=1,n + do i=1,m + x = (A(1,i)-B(j,1))**2 + & + (A(2,i)-B(j,2))**2 + & + (A(3,i)-B(j,3))**2 + if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return + end do + end do + + test_qmckl_distance_sq = qmckl_distance_sq(context, 'T', 'n', m, n, A, LDA, B, LDB, C, LDC) + if (test_qmckl_distance_sq /= 0) return + + test_qmckl_distance_sq = -1 + + do j=1,n + do i=1,m + x = (A(i,1)-B(1,j))**2 + & + (A(i,2)-B(2,j))**2 + & + (A(i,3)-B(3,j))**2 + if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return + end do + end do + + test_qmckl_distance_sq = qmckl_distance_sq(context, 'n', 'N', m, n, A, LDA, B, LDB, C, LDC) + if (test_qmckl_distance_sq /= 0) return + + test_qmckl_distance_sq = -1 + + do j=1,n + do i=1,m + x = (A(1,i)-B(1,j))**2 + & + (A(2,i)-B(2,j))**2 + & + (A(3,i)-B(3,j))**2 + if ( dabs(1.d0 - C(i,j)/x) > 1.d-14 ) return + end do + end do + test_qmckl_distance_sq = 0 deallocate(A,B,C) @@ -230,14 +361,14 @@ end function test_qmckl_distance_sq int test_qmckl_distance_sq(qmckl_context context); munit_assert_int(0, ==, test_qmckl_distance_sq(context)); #+END_SRC -* End of files +* End of files :noexport: -*** Header :noexport: +*** Header #+BEGIN_SRC C :comments link :tangle qmckl_distance.h #endif #+END_SRC -*** Test :noexport: +*** Test #+BEGIN_SRC C :comments link :tangle test_qmckl_distance.c if (qmckl_context_destroy(context) != QMCKL_SUCCESS) return QMCKL_FAILURE;