10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2025-01-07 03:43:14 +01:00

Accelerated BH Jastrow

This commit is contained in:
Anthony Scemama 2024-03-15 18:19:00 +01:00
parent cfdaf722df
commit 0a8d57abd9

View File

@ -167,9 +167,9 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
integer :: jpoint integer :: jpoint
integer :: i_nucl, p, mpA, npA, opA integer :: i_nucl, p, mpA, npA, opA
double precision :: r2(3) double precision :: r2(3)
double precision :: dx, dy, dz, r12, tmp double precision :: dx, dy, dz, r12, tmp, r12_inv
double precision :: mu_val, mu_tmp, mu_der(3) double precision :: mu_val, mu_tmp, mu_der(3)
double precision :: rn(3), f1A, gard1_f1A(3), f2A, gard2_f2A(3), g12, gard1_g12(3) double precision :: rn(3), f1A, grad1_f1A(3), f2A, grad2_f2A(3), g12, grad1_g12(3)
double precision :: tmp1, tmp2 double precision :: tmp1, tmp2
@ -191,15 +191,19 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
dy = r1(2) - r2(2) dy = r1(2) - r2(2)
dz = r1(3) - r2(3) dz = r1(3) - r2(3)
r12 = dsqrt(dx * dx + dy * dy + dz * dz) r12 = dx * dx + dy * dy + dz * dz
if(r12 .lt. 1d-10) then
if(r12 .lt. 1d-20) then
gradx(jpoint) = 0.d0 gradx(jpoint) = 0.d0
grady(jpoint) = 0.d0 grady(jpoint) = 0.d0
gradz(jpoint) = 0.d0 gradz(jpoint) = 0.d0
cycle cycle
endif endif
tmp = 0.5d0 * (1.d0 - derf(mu_erf * r12)) / r12 r12_inv = 1.d0/dsqrt(r12)
r12 = r12*r12_inv
tmp = 0.5d0 * (1.d0 - derf(mu_erf * r12)) * r12_inv
gradx(jpoint) = tmp * dx gradx(jpoint) = tmp * dx
grady(jpoint) = tmp * dy grady(jpoint) = tmp * dy
@ -220,23 +224,29 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
dx = r1(1) - r2(1) dx = r1(1) - r2(1)
dy = r1(2) - r2(2) dy = r1(2) - r2(2)
dz = r1(3) - r2(3) dz = r1(3) - r2(3)
r12 = dsqrt(dx * dx + dy * dy + dz * dz)
call mu_r_val_and_grad(r1, r2, mu_val, mu_der) r12 = dx * dx + dy * dy + dz * dz
mu_tmp = mu_val * r12
tmp = inv_sq_pi_2 * dexp(-mu_tmp*mu_tmp) / (mu_val * mu_val)
gradx(jpoint) = tmp * mu_der(1)
grady(jpoint) = tmp * mu_der(2)
gradz(jpoint) = tmp * mu_der(3)
if(r12 .lt. 1d-10) then if(r12 .lt. 1d-20) then
gradx(jpoint) = 0.d0 gradx(jpoint) = 0.d0
grady(jpoint) = 0.d0 grady(jpoint) = 0.d0
gradz(jpoint) = 0.d0 gradz(jpoint) = 0.d0
cycle cycle
endif endif
tmp = 0.5d0 * (1.d0 - derf(mu_tmp)) / r12 r12_inv = 1.d0/dsqrt(r12)
r12 = r12*r12_inv
call mu_r_val_and_grad(r1, r2, mu_val, mu_der)
mu_tmp = mu_val * r12
tmp = inv_sq_pi_2 * dexp(-mu_tmp*mu_tmp) / (mu_val * mu_val)
gradx(jpoint) = tmp * mu_der(1)
grady(jpoint) = tmp * mu_der(2)
gradz(jpoint) = tmp * mu_der(3)
tmp = 0.5d0 * (1.d0 - derf(mu_tmp)) * r12_inv
gradx(jpoint) = gradx(jpoint) + tmp * dx gradx(jpoint) = gradx(jpoint) + tmp * dx
grady(jpoint) = grady(jpoint) + tmp * dy grady(jpoint) = grady(jpoint) + tmp * dy
@ -263,7 +273,8 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
dx = r1(1) - r2(1) dx = r1(1) - r2(1)
dy = r1(2) - r2(2) dy = r1(2) - r2(2)
dz = r1(3) - r2(3) dz = r1(3) - r2(3)
r12 = dsqrt(dx * dx + dy * dy + dz * dz) r12 = dx * dx + dy * dy + dz * dz
if(r12 .lt. 1d-10) then if(r12 .lt. 1d-10) then
gradx(jpoint) = 0.d0 gradx(jpoint) = 0.d0
grady(jpoint) = 0.d0 grady(jpoint) = 0.d0
@ -271,6 +282,8 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
cycle cycle
endif endif
r12 = dsqrt(r12)
tmp = 1.d0 + a_boys * r12 tmp = 1.d0 + a_boys * r12
tmp = 0.5d0 / (r12 * tmp * tmp) tmp = 0.5d0 / (r12 * tmp * tmp)
@ -281,6 +294,24 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
elseif(j2e_type .eq. "Boys_Handy") then elseif(j2e_type .eq. "Boys_Handy") then
integer :: powmax
powmax = max(maxval(jBH_m),maxval(jBH_n))
double precision, allocatable :: f1A_power(:), f2A_power(:), double_p(:), g12_power(:)
allocate (f1A_power(-1:powmax), f2A_power(-1:powmax), g12_power(-1:powmax), double_p(0:powmax))
do p=0,powmax
double_p(p) = dble(p)
enddo
f1A_power(-1) = 0.d0
f2A_power(-1) = 0.d0
g12_power(-1) = 0.d0
f1A_power(0) = 1.d0
f2A_power(0) = 1.d0
g12_power(0) = 1.d0
do jpoint = 1, n_points_extra_final_grid ! r2 do jpoint = 1, n_points_extra_final_grid ! r2
r2(1) = final_grid_points_extra(1,jpoint) r2(1) = final_grid_points_extra(1,jpoint)
@ -290,15 +321,33 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
gradx(jpoint) = 0.d0 gradx(jpoint) = 0.d0
grady(jpoint) = 0.d0 grady(jpoint) = 0.d0
gradz(jpoint) = 0.d0 gradz(jpoint) = 0.d0
do i_nucl = 1, nucl_num do i_nucl = 1, nucl_num
rn(1) = nucl_coord(i_nucl,1) rn(1) = nucl_coord(i_nucl,1)
rn(2) = nucl_coord(i_nucl,2) rn(2) = nucl_coord(i_nucl,2)
rn(3) = nucl_coord(i_nucl,3) rn(3) = nucl_coord(i_nucl,3)
call jBH_elem_fct_grad(jBH_en(i_nucl), r1, rn, f1A, gard1_f1A) call jBH_elem_fct_grad(jBH_en(i_nucl), r1, rn, f1A, grad1_f1A)
call jBH_elem_fct_grad(jBH_en(i_nucl), r2, rn, f2A, gard2_f2A) call jBH_elem_fct_grad(jBH_en(i_nucl), r2, rn, f2A, grad2_f2A)
call jBH_elem_fct_grad(jBH_ee(i_nucl), r1, r2, g12, gard1_g12) call jBH_elem_fct_grad(jBH_ee(i_nucl), r1, r2, g12, grad1_g12)
! Compute powers of f1A and f2A
do p = 1, maxval(jBH_m(:,i_nucl))
f1A_power(p) = f1A_power(p-1) * f1A
enddo
do p = 1, maxval(jBH_n(:,i_nucl))
f2A_power(p) = f2A_power(p-1) * f2A
enddo
do p = 1, maxval(jBH_o(:,i_nucl))
g12_power(p) = g12_power(p-1) * g12
enddo
do p = 1, jBH_size do p = 1, jBH_size
mpA = jBH_m(p,i_nucl) mpA = jBH_m(p,i_nucl)
@ -309,23 +358,31 @@ subroutine grad1_j12_r1_seq(r1, n_grid2, gradx, grady, gradz)
tmp = tmp * 0.5d0 tmp = tmp * 0.5d0
endif endif
tmp1 = 0.d0 !TODO : Powers to optimize here
if(mpA .gt. 0) then
tmp1 = tmp1 + dble(mpA) * f1A**dble(mpA-1) * f2A**dble(npA)
endif
if(npA .gt. 0) then
tmp1 = tmp1 + dble(npA) * f1A**dble(npA-1) * f2A**dble(mpA)
endif
tmp1 = tmp1 * g12**dble(opA)
tmp2 = 0.d0 ! tmp1 = 0.d0
if(opA .gt. 0) then ! if(mpA .gt. 0) then
tmp2 = tmp2 + dble(opA) * g12**dble(opA-1) * (f1A**dble(mpA) * f2A**dble(npA) + f1A**dble(npA) * f2A**dble(mpA)) ! tmp1 = tmp1 + dble(mpA) * f1A**(mpA-1) * f2A**npA
endif ! endif
! if(npA .gt. 0) then
! tmp1 = tmp1 + dble(npA) * f1A**(npA-1) * f2A**mpA
! endif
! tmp1 = tmp1 * g12**(opA)
!
! tmp2 = 0.d0
! if(opA .gt. 0) then
! tmp2 = tmp2 + dble(opA) * g12**(opA-1) * (f1A**(mpA) * f2A**(npA) + f1A**(npA) * f2A**(mpA))
! endif
gradx(jpoint) = gradx(jpoint) + tmp * (tmp1 * gard1_f1A(1) + tmp2 * gard1_g12(1)) tmp1 = double_p(mpA) * f1A_power(mpA-1) * f2A_power(npA) + double_p(npA) * f1A_power(npA-1) * f2A_power(mpA)
grady(jpoint) = grady(jpoint) + tmp * (tmp1 * gard1_f1A(2) + tmp2 * gard1_g12(2)) tmp1 = tmp1 * g12_power(opA)
gradz(jpoint) = gradz(jpoint) + tmp * (tmp1 * gard1_f1A(3) + tmp2 * gard1_g12(3))
tmp2 = double_p(opA) * g12_power(opA-1) * (f1A_power(mpA) * f2A_power(npA) + f1A_power(npA) * f2A_power(mpA))
gradx(jpoint) = gradx(jpoint) + tmp * (tmp1 * grad1_f1A(1) + tmp2 * grad1_g12(1))
grady(jpoint) = grady(jpoint) + tmp * (tmp1 * grad1_f1A(2) + tmp2 * grad1_g12(2))
gradz(jpoint) = gradz(jpoint) + tmp * (tmp1 * grad1_f1A(3) + tmp2 * grad1_g12(3))
enddo ! p enddo ! p
enddo ! i_nucl enddo ! i_nucl
enddo ! jpoint enddo ! jpoint
@ -361,7 +418,7 @@ subroutine grad1_jmu_r1_seq(mu, r1, n_grid2, gradx, grady, gradz)
integer :: jpoint integer :: jpoint
double precision :: r2(3) double precision :: r2(3)
double precision :: dx, dy, dz, r12, tmp double precision :: dx, dy, dz, r12, r12_inv, tmp
do jpoint = 1, n_points_extra_final_grid ! r2 do jpoint = 1, n_points_extra_final_grid ! r2
@ -374,15 +431,19 @@ subroutine grad1_jmu_r1_seq(mu, r1, n_grid2, gradx, grady, gradz)
dy = r1(2) - r2(2) dy = r1(2) - r2(2)
dz = r1(3) - r2(3) dz = r1(3) - r2(3)
r12 = dsqrt(dx * dx + dy * dy + dz * dz) r12 = dx * dx + dy * dy + dz * dz
if(r12 .lt. 1d-10) then
if(r12 .lt. 1d-20) then
gradx(jpoint) = 0.d0 gradx(jpoint) = 0.d0
grady(jpoint) = 0.d0 grady(jpoint) = 0.d0
gradz(jpoint) = 0.d0 gradz(jpoint) = 0.d0
cycle cycle
endif endif
tmp = 0.5d0 * (1.d0 - derf(mu * r12)) / r12 r12_inv = 1.d0 / dsqrt(r12)
r12 = r12 * r12_inv
tmp = 0.5d0 * (1.d0 - derf(mu * r12)) * r12_inv
gradx(jpoint) = tmp * dx gradx(jpoint) = tmp * dx
grady(jpoint) = tmp * dy grady(jpoint) = tmp * dy
@ -406,7 +467,7 @@ subroutine j12_r1_seq(r1, n_grid2, res)
integer :: jpoint integer :: jpoint
double precision :: r2(3) double precision :: r2(3)
double precision :: dx, dy, dz double precision :: dx, dy, dz
double precision :: mu_tmp, r12 double precision :: mu_tmp, r12, mu_erf_inv
PROVIDE final_grid_points_extra PROVIDE final_grid_points_extra
@ -414,6 +475,7 @@ subroutine j12_r1_seq(r1, n_grid2, res)
PROVIDE mu_erf PROVIDE mu_erf
mu_erf_inv = 1.d0 / mu_erf
do jpoint = 1, n_points_extra_final_grid ! r2 do jpoint = 1, n_points_extra_final_grid ! r2
r2(1) = final_grid_points_extra(1,jpoint) r2(1) = final_grid_points_extra(1,jpoint)
@ -427,7 +489,7 @@ subroutine j12_r1_seq(r1, n_grid2, res)
mu_tmp = mu_erf * r12 mu_tmp = mu_erf * r12
res(jpoint) = 0.5d0 * r12 * (1.d0 - derf(mu_tmp)) - inv_sq_pi_2 * dexp(-mu_tmp*mu_tmp) / mu_erf res(jpoint) = 0.5d0 * r12 * (1.d0 - derf(mu_tmp)) - inv_sq_pi_2 * dexp(-mu_tmp*mu_tmp) * mu_erf_inv
enddo enddo
elseif(j2e_type .eq. "Boys") then elseif(j2e_type .eq. "Boys") then
@ -820,11 +882,11 @@ end
! --- ! ---
subroutine jBH_elem_fct_grad(alpha, r1, r2, fct, gard1_fct) subroutine jBH_elem_fct_grad(alpha, r1, r2, fct, grad1_fct)
implicit none implicit none
double precision, intent(in) :: alpha, r1(3), r2(3) double precision, intent(in) :: alpha, r1(3), r2(3)
double precision, intent(out) :: fct, gard1_fct(3) double precision, intent(out) :: fct, grad1_fct(3)
double precision :: dist, tmp1, tmp2 double precision :: dist, tmp1, tmp2
dist = dsqrt( (r1(1) - r2(1)) * (r1(1) - r2(1)) & dist = dsqrt( (r1(1) - r2(1)) * (r1(1) - r2(1)) &
@ -836,14 +898,14 @@ subroutine jBH_elem_fct_grad(alpha, r1, r2, fct, gard1_fct)
fct = alpha * dist * tmp1 fct = alpha * dist * tmp1
if(dist .lt. 1d-10) then if(dist .lt. 1d-10) then
gard1_fct(1) = 0.d0 grad1_fct(1) = 0.d0
gard1_fct(2) = 0.d0 grad1_fct(2) = 0.d0
gard1_fct(3) = 0.d0 grad1_fct(3) = 0.d0
else else
tmp2 = alpha * tmp1 * tmp1 / dist tmp2 = alpha * tmp1 * tmp1 / dist
gard1_fct(1) = tmp2 * (r1(1) - r2(1)) grad1_fct(1) = tmp2 * (r1(1) - r2(1))
gard1_fct(2) = tmp2 * (r1(2) - r2(2)) grad1_fct(2) = tmp2 * (r1(2) - r2(2))
gard1_fct(3) = tmp2 * (r1(3) - r2(3)) grad1_fct(3) = tmp2 * (r1(3) - r2(3))
endif endif
return return