From 3024ffdad9414be985cd8b48a3cf1404b2d9407b Mon Sep 17 00:00:00 2001
From: Anthony Scemama
Date: Mon, 3 Mar 2025 14:45:05 +0100
Subject: [PATCH] dgemm in qmckl_compute_jastrow_champ_delta_p_g_hpc

---
 org/qmckl_jastrow_champ_single.org | 176 +++++++++++++++++++++++++----
 1 file changed, 151 insertions(+), 25 deletions(-)

diff --git a/org/qmckl_jastrow_champ_single.org b/org/qmckl_jastrow_champ_single.org
index a22eeb1..86bed1e 100644
--- a/org/qmckl_jastrow_champ_single.org
+++ b/org/qmckl_jastrow_champ_single.org
@@ -3846,6 +3846,124 @@ end do
 
 end function qmckl_compute_jastrow_champ_factor_single_een_gl_doc
   #+end_src
+  #+begin_src f90 :comments org :tangle (eval f) :noweb yes
+integer(qmckl_exit_code) function qmckl_compute_jastrow_champ_factor_single_een_gl_hpc( &
+     context, num_in, walk_num, elec_num, nucl_num, cord_num, &
+     dim_c_vector, c_vector_full, lkpm_combined_index, &
+     tmp_c, dtmp_c, delta_p, delta_p_gl, een_rescaled_n, een_rescaled_single_n, &
+     een_rescaled_n_gl, een_rescaled_single_n_gl, delta_een_gl) &
+     result(info) bind(C)  ! fills delta_een_gl(elec_num, 4, walk_num); info carries the qmckl exit code
+  use, intrinsic :: iso_c_binding
+  use qmckl
+  implicit none
+  integer(qmckl_context), intent(in)  :: context
+  integer(c_int64_t)    , intent(in), value :: num_in, walk_num, elec_num, cord_num, nucl_num, dim_c_vector
+  integer(c_int64_t)    , intent(in)  :: lkpm_combined_index(dim_c_vector,4)  ! columns 1..4 are read as l, k, p, m
+  real(c_double)        , intent(in)  :: c_vector_full(nucl_num, dim_c_vector)
+  real(c_double)        , intent(in)  :: tmp_c(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
+  real(c_double)        , intent(in)  :: dtmp_c(elec_num, 4, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
+  real(c_double)        , intent(in)  :: delta_p(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
+  real(c_double)        , intent(in)  :: delta_p_gl(elec_num, nucl_num, 4, 0:cord_num, 0:cord_num-1, walk_num)
+  real(c_double)        , intent(in)  :: een_rescaled_n(elec_num, nucl_num, 0:cord_num, walk_num)
+  real(c_double)        , intent(in)  :: een_rescaled_single_n(nucl_num, 0:cord_num, walk_num)
+  real(c_double)        , intent(in)  :: een_rescaled_n_gl(elec_num, 4, nucl_num, 0:cord_num, walk_num)
+  real(c_double)        , intent(in)  :: een_rescaled_single_n_gl(4, nucl_num, 0:cord_num, walk_num)
+  real(c_double)        , intent(out) :: delta_een_gl(elec_num, 4, walk_num)
+
+  integer*8 :: i, a, j, l, k, p, m, n, nw, kk, num  ! NOTE(review): j is not referenced in this routine
+  double precision :: accu, accu2, cn  ! NOTE(review): accu and accu2 are not referenced in this routine
+  integer*8 :: LDA, LDB, LDC  ! NOTE(review): not referenced in this routine
+
+  double precision :: een_rescaled_delta_n_gl(4, nucl_num, 0:cord_num, walk_num)  ! automatic, argument-sized local
+  double precision :: een_rescaled_delta_n(nucl_num, 0:cord_num, walk_num)  ! automatic, argument-sized local
+  double precision :: dpg1_m, dpg1_ml, dp_m, dp_ml, een_r_m, een_r_ml, een_r_gl_m, een_r_gl_ml  ! inner-loop scalars
+  num = num_in + 1  ! num is used below as a subscript; presumably num_in is a 0-based index - TODO confirm
+
+  info = QMCKL_SUCCESS
+
+  if (context == QMCKL_NULL_CONTEXT) info = QMCKL_INVALID_CONTEXT
+  if (walk_num <= 0)                 info = QMCKL_INVALID_ARG_3
+  if (elec_num <= 0)                 info = QMCKL_INVALID_ARG_4
+  if (nucl_num <= 0)                 info = QMCKL_INVALID_ARG_5
+  if (cord_num <  0)                 info = QMCKL_INVALID_ARG_6
+  if (info /= QMCKL_SUCCESS)         return  ! checks are not chained: the last failing one sets the code
+
+  delta_een_gl = 0.0d0  ! the loops below accumulate into the output, so start from zero
+
+  if (cord_num == 0) return  ! info is QMCKL_SUCCESS here and delta_een_gl stays all zero
+
+  een_rescaled_delta_n(:,:,:) = een_rescaled_single_n(:,:,:) - een_rescaled_n(num, :, :, :)  ! single minus row num
+  een_rescaled_delta_n_gl(:,:,:,:) = een_rescaled_single_n_gl(:,:,:,:) - een_rescaled_n_gl(num, :,:,:,:)
+
+
+  do nw =1, walk_num
+     do n = 1, dim_c_vector
+        l = lkpm_combined_index(n, 1)
+        k = lkpm_combined_index(n, 2)
+        p = lkpm_combined_index(n, 3)  ! NOTE(review): p is assigned but not used afterwards in this routine
+        m = lkpm_combined_index(n, 4)
+
+        do kk = 1, 4
+           do a = 1, nucl_num
+              cn = c_vector_full(a, n)
+              if(cn == 0.d0) cycle  ! skip (a, n) pairs whose coefficient is exactly zero
+              !do i = 1, elec_num
+              !   delta_een_gl(i,kk,nw) = delta_een_gl(i,kk,nw) + ( &
+              !       delta_p_gl(i,a,kk,m  ,k,nw) * een_rescaled_n(i,a,m+l,nw) + &
+              !       delta_p_gl(i,a,kk,m+l,k,nw) * een_rescaled_n(i,a,m  ,nw) + &
+              !       delta_p(i,a,m  ,k,nw) * een_rescaled_n_gl(i,kk,a,m+l,nw) + &
+              !       delta_p(i,a,m+l,k,nw) * een_rescaled_n_gl(i,kk,a,m  ,nw) ) * cn
+              !end do
+              do i = 1, elec_num  ! same four-term sum as the commented-out loop above, operands loaded into scalars
+                 ! Cache repeated accesses
+                 dpg1_m  = delta_p_gl(i,a,kk,m  ,k,nw)
+                 dpg1_ml = delta_p_gl(i,a,kk,m+l,k,nw)
+                 dp_m    = delta_p(i,a,m  ,k,nw)
+                 dp_ml   = delta_p(i,a,m+l,k,nw)
+
+                 een_r_m     = een_rescaled_n(i,a,m  ,nw)
+                 een_r_ml    = een_rescaled_n(i,a,m+l,nw)
+                 een_r_gl_m  = een_rescaled_n_gl(i,kk,a,m  ,nw)
+                 een_r_gl_ml = een_rescaled_n_gl(i,kk,a,m+l,nw)
+
+                 delta_een_gl(i,kk,nw) = delta_een_gl(i,kk,nw) + cn * &
+                      (dpg1_m * een_r_ml + dpg1_ml * een_r_m + dp_m * een_r_gl_ml + dp_ml * een_r_gl_m)
+              end do
+
+              delta_een_gl(num,kk,nw) = delta_een_gl(num,kk,nw) + ( &  ! extra term for row num, built from the delta arrays
+                   (dtmp_c(num,kk,a,m  ,k,nw) + delta_p_gl(num,a,kk,m  ,k,nw)) * een_rescaled_delta_n(a,m+l,nw) + &
+                   (dtmp_c(num,kk,a,m+l,k,nw) + delta_p_gl(num,a,kk,m+l,k,nw)) * een_rescaled_delta_n(a,m  ,nw) + &
+                   (tmp_c(num,a,m  ,k,nw) + delta_p(num,a,m  ,k,nw)) * een_rescaled_delta_n_gl(kk,a,m+l,nw) + &
+                   (tmp_c(num,a,m+l,k,nw) + delta_p(num,a,m+l,k,nw)) * een_rescaled_delta_n_gl(kk,a,m  ,nw) )* cn
+           end do
+        end do
+        do a = 1, nucl_num
+           cn = c_vector_full(a, n)
+           if(cn == 0.d0) cycle  ! same zero-coefficient skip as in the kk loop
+           cn = cn + cn  ! factor 2 applied to the component 1..3 products added into component 4
+           do i = 1, elec_num
+              delta_een_gl(i,4,nw) = delta_een_gl(i,4,nw) + ( &
+                   delta_p_gl(i,a,1,m  ,k,nw) * een_rescaled_n_gl(i,1,a,m+l,nw) + &
+                   delta_p_gl(i,a,1,m+l,k,nw) * een_rescaled_n_gl(i,1,a,m  ,nw) + &
+                   delta_p_gl(i,a,2,m  ,k,nw) * een_rescaled_n_gl(i,2,a,m+l,nw) + &
+                   delta_p_gl(i,a,2,m+l,k,nw) * een_rescaled_n_gl(i,2,a,m  ,nw) + &
+                   delta_p_gl(i,a,3,m  ,k,nw) * een_rescaled_n_gl(i,3,a,m+l,nw) + &
+                   delta_p_gl(i,a,3,m+l,k,nw) * een_rescaled_n_gl(i,3,a,m  ,nw) ) * cn
+           end do
+           delta_een_gl(num,4,nw) = delta_een_gl(num,4,nw) + ( &  ! matching extra term for row num
+                (delta_p_gl(num,a,1,m  ,k,nw) + dtmp_c(num,1,a,m  ,k,nw)) * een_rescaled_delta_n_gl(1,a,m+l,nw) + &
+                (delta_p_gl(num,a,1,m+l,k,nw) + dtmp_c(num,1,a,m+l,k,nw)) * een_rescaled_delta_n_gl(1,a,m  ,nw) + &
+                (delta_p_gl(num,a,2,m  ,k,nw) + dtmp_c(num,2,a,m  ,k,nw)) * een_rescaled_delta_n_gl(2,a,m+l,nw) + &
+                (delta_p_gl(num,a,2,m+l,k,nw) + dtmp_c(num,2,a,m+l,k,nw)) * een_rescaled_delta_n_gl(2,a,m  ,nw) + &
+                (delta_p_gl(num,a,3,m  ,k,nw) + dtmp_c(num,3,a,m  ,k,nw)) * een_rescaled_delta_n_gl(3,a,m+l,nw) + &
+                (delta_p_gl(num,a,3,m+l,k,nw) + dtmp_c(num,3,a,m+l,k,nw)) * een_rescaled_delta_n_gl(3,a,m  ,nw) ) * cn
+        end do
+     end do
+  end do
+
+end function qmckl_compute_jastrow_champ_factor_single_een_gl_hpc
+  #+end_src
+
   #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
 qmckl_exit_code
 qmckl_compute_jastrow_champ_factor_single_een_gl_doc (const qmckl_context context,
@@ -3868,7 +3986,26 @@ qmckl_compute_jastrow_champ_factor_single_een_gl_doc (const qmckl_context contex
       double* const delta_een_gl );
 
 qmckl_exit_code
+qmckl_compute_jastrow_champ_factor_single_een_gl_hpc (const qmckl_context context,
+      const int64_t num,
+      const int64_t walk_num,
+      const int64_t elec_num,
+      const int64_t nucl_num,
+      const int64_t cord_num,
+      const int64_t dim_c_vector,
+      const double* c_vector_full,
+      const int64_t* lkpm_combined_index,
+      const double* tmp_c,
+      const double* dtmp_c,
+      const double* delta_p,
+      const double* delta_p_gl,
+      const double* een_rescaled_n,
+      const double* een_rescaled_single_n,
+      const double* een_rescaled_n_gl,
+      const double* een_rescaled_single_n_gl,
+      double* const delta_een_gl );
 
+qmckl_exit_code
 qmckl_compute_jastrow_champ_factor_single_een_gl (const qmckl_context context,
       const int64_t num,
       const int64_t walk_num,
@@ -3911,7 +4048,7 @@ qmckl_compute_jastrow_champ_factor_single_een_gl (const qmckl_context context,
       double* const delta_een_gl )
 {
 #ifdef HAVE_HPC
-  return qmckl_compute_jastrow_champ_factor_single_een_gl_doc
+  return qmckl_compute_jastrow_champ_factor_single_een_gl_hpc
 #else
   return qmckl_compute_jastrow_champ_factor_single_een_gl_doc
 #endif
@@ -4258,7 +4395,7 @@ integer(qmckl_exit_code) function qmckl_compute_jastrow_champ_delta_p_g_hpc( &
   double precision        :: een_rescaled_delta_n, een_re_n, een_re_single_n
 
   integer*8 :: i, a, j, l, k, p, m, n, nw, num
-  double precision :: tmp, cummu(4)
+  double precision, allocatable :: tmp(:,:,:)
   integer*8 :: LDA, LDB, LDC
 
   num = num_in + 1
@@ -4279,34 +4416,23 @@ integer(qmckl_exit_code) function qmckl_compute_jastrow_champ_delta_p_g_hpc( &
     return
   endif
 
+  allocate( tmp(3,nucl_num,0:cord_num) )  ! scratch that receives the dgemm result below
   do nw=1, walk_num
     do m=1, cord_num-1
-      do j = 1, elec_num
-        do k = 1, 3
-          delta_e_gl(k,j) = een_rescaled_single_e_gl(k,j,m,nw) - een_rescaled_e_gl(num, k, j, m, nw)
-        end do
-      end do
-      do k = 1, 3
-        delta_e_gl(k,num) = 0.0d0
-      end do
+      delta_e_gl(1:3,1:elec_num) = een_rescaled_single_e_gl(1:3,1:elec_num,m,nw) - &
+          een_rescaled_e_gl(num, 1:3, 1:elec_num, m, nw)
+      delta_e_gl(1:3,num) = 0.0d0  ! column num is zeroed, as the removed k-loop did
 
-      do l=0, cord_num
-        do a = 1, nucl_num
-          do k = 1, 3
-            cummu(k) = 0.0d0
-          enddo
-          do i = 1, elec_num
-            do k = 1, 3
-              cummu(k) = cummu(k) + delta_e_gl(k,i) * een_rescaled_n(i,a,l,nw)
-            end do
-          enddo
-          do k = 1, 3
-            delta_p_g(num,a,k,l,m,nw) = cummu(k)
-          end do
-        end do
-      end do
+      call dgemm('N','N', 3, nucl_num*(cord_num+1), elec_num, 1.d0, &  ! replaces the removed cummu(k) sum over i
+          delta_e_gl(1,1), 3, een_rescaled_n(1,1,0,nw), elec_num, 0.d0, &  ! lda=3 assumes size(delta_e_gl,1)==3 - confirm
+          tmp, 3)  ! NOTE(review): check that the integer kinds of all size arguments match the BLAS in use
+
+      delta_p_g(num,1:nucl_num,1,0:cord_num,m,nw) = tmp(1,1:nucl_num,0:cord_num)  ! copy component 1 into row num
+      delta_p_g(num,1:nucl_num,2,0:cord_num,m,nw) = tmp(2,1:nucl_num,0:cord_num)  ! copy component 2 into row num
+      delta_p_g(num,1:nucl_num,3,0:cord_num,m,nw) = tmp(3,1:nucl_num,0:cord_num)  ! copy component 3 into row num
     end do
   end do
+  deallocate(tmp)  ! the return visible in this hunk's leading context precedes the allocate
 
 end function qmckl_compute_jastrow_champ_delta_p_g_hpc
   #+end_src