10
0
mirror of https://github.com/LCPQ/quantum_package synced 2025-01-12 22:18:31 +01:00

Added variance and norm of PT

This commit is contained in:
Anthony Scemama 2018-10-29 14:35:29 +01:00
parent 7a0a14854a
commit 9a70059a11
7 changed files with 184 additions and 51 deletions

View File

@ -1,12 +1,12 @@
program fci_zmq program fci_zmq
implicit none implicit none
integer :: i,j,k integer :: i,j,k
double precision, allocatable :: pt2(:) double precision, allocatable :: pt2(:), variance(:), norm(:)
integer :: degree integer :: degree
integer :: n_det_before, to_select integer :: n_det_before, to_select
double precision :: threshold_davidson_in double precision :: threshold_davidson_in
allocate (pt2(N_states)) allocate (pt2(N_states), norm(N_states), variance(N_states))
double precision :: hf_energy_ref double precision :: hf_energy_ref
logical :: has logical :: has
@ -16,6 +16,8 @@ program fci_zmq
relative_error=PT2_relative_error relative_error=PT2_relative_error
pt2 = -huge(1.e0) pt2 = -huge(1.e0)
norm = 0.d0
variance = huge(1.e0)
threshold_davidson_in = threshold_davidson threshold_davidson_in = threshold_davidson
threshold_davidson = threshold_davidson_in * 100.d0 threshold_davidson = threshold_davidson_in * 100.d0
SOFT_TOUCH threshold_davidson SOFT_TOUCH threshold_davidson
@ -69,10 +71,12 @@ program fci_zmq
if (do_pt2) then if (do_pt2) then
pt2 = 0.d0 pt2 = 0.d0
variance = 0.d0
norm = 0.d0
threshold_selectors = 1.d0 threshold_selectors = 1.d0
threshold_generators = 1.d0 threshold_generators = 1.d0
SOFT_TOUCH threshold_selectors threshold_generators SOFT_TOUCH threshold_selectors threshold_generators
call ZMQ_pt2(CI_energy(1:N_states), pt2,relative_error,error) ! Stochastic PT2 call ZMQ_pt2(CI_energy(1:N_states),pt2,relative_error,error, variance, norm) ! Stochastic PT2
threshold_selectors = threshold_selectors_save threshold_selectors = threshold_selectors_save
threshold_generators = threshold_generators_save threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_selectors threshold_generators SOFT_TOUCH threshold_selectors threshold_generators
@ -87,7 +91,7 @@ program fci_zmq
call ezfio_set_fci_energy_pt2(CI_energy(1:N_states)+pt2) call ezfio_set_fci_energy_pt2(CI_energy(1:N_states)+pt2)
call write_double(6,correlation_energy_ratio, 'Correlation ratio') call write_double(6,correlation_energy_ratio, 'Correlation ratio')
call print_summary(CI_energy(1:N_states),pt2,error) call print_summary(CI_energy(1:N_states),pt2,error,variance,norm)
call save_iterations(CI_energy(1:N_states),pt2,N_det) call save_iterations(CI_energy(1:N_states),pt2,N_det)
call print_extrapolated_energy(CI_energy(1:N_states),pt2) call print_extrapolated_energy(CI_energy(1:N_states),pt2)
N_iter += 1 N_iter += 1
@ -102,7 +106,7 @@ program fci_zmq
to_select = max(N_det, to_select) to_select = max(N_det, to_select)
to_select = min(to_select, N_det_max-n_det_before) to_select = min(to_select, N_det_max-n_det_before)
endif endif
call ZMQ_selection(to_select, pt2) call ZMQ_selection(to_select, pt2, variance, norm)
PROVIDE psi_coef PROVIDE psi_coef
PROVIDE psi_det PROVIDE psi_det
@ -130,7 +134,7 @@ program fci_zmq
threshold_selectors = 1.d0 threshold_selectors = 1.d0
threshold_generators = 1d0 threshold_generators = 1d0
SOFT_TOUCH threshold_selectors threshold_generators SOFT_TOUCH threshold_selectors threshold_generators
call ZMQ_pt2(CI_energy, pt2,relative_error,error) ! Stochastic PT2 call ZMQ_pt2(CI_energy, pt2,relative_error,error,variance,norm) ! Stochastic PT2
threshold_selectors = threshold_selectors_save threshold_selectors = threshold_selectors_save
threshold_generators = threshold_generators_save threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_selectors threshold_generators SOFT_TOUCH threshold_selectors threshold_generators
@ -144,6 +148,6 @@ program fci_zmq
call save_iterations(CI_energy(1:N_states),pt2,N_det) call save_iterations(CI_energy(1:N_states),pt2,N_det)
call write_double(6,correlation_energy_ratio, 'Correlation ratio') call write_double(6,correlation_energy_ratio, 'Correlation ratio')
call print_summary(CI_energy(1:N_states),pt2,error) call print_summary(CI_energy(1:N_states),pt2,error,variance,norm)
end end

View File

@ -85,7 +85,7 @@ end function
subroutine ZMQ_pt2(E, pt2,relative_error, error) subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
use f77_zmq use f77_zmq
use selection_types use selection_types
@ -95,17 +95,20 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error)
integer, external :: omp_get_thread_num integer, external :: omp_get_thread_num
double precision, intent(in) :: relative_error, E(N_states) double precision, intent(in) :: relative_error, E(N_states)
double precision, intent(out) :: pt2(N_states),error(N_states) double precision, intent(out) :: pt2(N_states),error(N_states)
double precision, intent(out) :: variance(N_states),norm(N_states)
integer :: i integer :: i
double precision, external :: omp_get_wtime double precision, external :: omp_get_wtime
double precision :: state_average_weight_save(N_states), w(N_states) double precision :: state_average_weight_save(N_states), w(N_states,4)
integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket
if (N_det < max(10,N_states)) then if (N_det < max(10,N_states)) then
pt2=0.d0 pt2=0.d0
call ZMQ_selection(0, pt2) variance=0.d0
norm=0.d0
call ZMQ_selection(0, pt2, variance, norm)
error(:) = 0.d0 error(:) = 0.d0
else else
@ -116,11 +119,6 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error)
TOUCH state_average_weight pt2_stoch_istate TOUCH state_average_weight pt2_stoch_istate
provide nproc pt2_F mo_bielec_integrals_in_map mo_mono_elec_integral pt2_w psi_selectors provide nproc pt2_F mo_bielec_integrals_in_map mo_mono_elec_integral pt2_w psi_selectors
print *, '========== ================= ================= ================='
print *, ' Samples Energy Stat. Error Seconds '
print *, '========== ================= ================= ================='
call new_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'pt2') call new_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'pt2')
integer, external :: zmq_put_psi integer, external :: zmq_put_psi
@ -213,13 +211,21 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error)
call omp_set_nested(.false.) call omp_set_nested(.false.)
print *, '========== ================= =========== =============== =============== ================='
print *, ' Samples Energy Stat. Err Variance Norm Seconds '
print *, '========== ================= =========== =============== =============== ================='
!$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc_target+1) & !$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc_target+1) &
!$OMP PRIVATE(i) !$OMP PRIVATE(i)
i = omp_get_thread_num() i = omp_get_thread_num()
if (i==0) then if (i==0) then
call pt2_collector(zmq_socket_pull, E(pt2_stoch_istate),relative_error, w, error) call pt2_collector(zmq_socket_pull, E(pt2_stoch_istate),relative_error, w(1,1), w(1,2), w(1,3), w(1,4))
pt2(pt2_stoch_istate) = w(pt2_stoch_istate) pt2(pt2_stoch_istate) = w(pt2_stoch_istate,1)
error(pt2_stoch_istate) = w(pt2_stoch_istate,2)
variance(pt2_stoch_istate) = w(pt2_stoch_istate,3)
norm(pt2_stoch_istate) = w(pt2_stoch_istate,4)
else else
call pt2_slave_inproc(i) call pt2_slave_inproc(i)
@ -251,7 +257,7 @@ subroutine pt2_slave_inproc(i)
end end
subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error) subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, variance, norm)
use f77_zmq use f77_zmq
use selection_types use selection_types
use bitmasks use bitmasks
@ -261,9 +267,12 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision, intent(in) :: relative_error, E double precision, intent(in) :: relative_error, E
double precision, intent(out) :: pt2(N_states), error(N_states) double precision, intent(out) :: pt2(N_states), error(N_states)
double precision, intent(out) :: variance(N_states), norm(N_states)
double precision, allocatable :: eI(:,:), eI_task(:,:), S(:), S2(:) double precision, allocatable :: eI(:,:), eI_task(:,:), S(:), S2(:)
double precision, allocatable :: vI(:,:), vI_task(:,:), T2(:)
double precision, allocatable :: nI(:,:), nI_task(:,:), T3(:)
integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
integer(ZMQ_PTR) :: zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket
integer, external :: zmq_delete_tasks integer, external :: zmq_delete_tasks
@ -275,7 +284,7 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
integer, allocatable :: index(:) integer, allocatable :: index(:)
double precision, external :: omp_get_wtime double precision, external :: omp_get_wtime
double precision :: v, x, avg, eqt, E0 double precision :: v, x, x2, x3, avg, avg2, avg3, eqt, E0, v0, n0
double precision :: time, time0 double precision :: time, time0
integer, allocatable :: f(:) integer, allocatable :: f(:)
@ -287,19 +296,31 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
allocate(task_id(pt2_n_tasks_max), index(pt2_n_tasks_max), f(N_det_generators)) allocate(task_id(pt2_n_tasks_max), index(pt2_n_tasks_max), f(N_det_generators))
allocate(d(N_det_generators+1)) allocate(d(N_det_generators+1))
allocate(eI(N_states, N_det_generators), eI_task(N_states, pt2_n_tasks_max)) allocate(eI(N_states, N_det_generators), eI_task(N_states, pt2_n_tasks_max))
allocate(vI(N_states, N_det_generators), vI_task(N_states, pt2_n_tasks_max))
allocate(nI(N_states, N_det_generators), nI_task(N_states, pt2_n_tasks_max))
allocate(S(pt2_N_teeth+1), S2(pt2_N_teeth+1)) allocate(S(pt2_N_teeth+1), S2(pt2_N_teeth+1))
allocate(T2(pt2_N_teeth+1), T3(pt2_N_teeth+1))
pt2(:) = -huge(1.) pt2(:) = -huge(1.)
error(:) = huge(1.)
variance(:) = huge(1.)
norm(:) = 0.d0
S(:) = 0d0 S(:) = 0d0
S2(:) = 0d0 S2(:) = 0d0
T2(:) = 0d0
T3(:) = 0d0
n = 1 n = 1
t = 0 t = 0
U = 0 U = 0
eI(:,:) = 0d0 eI(:,:) = 0d0
vI(:,:) = 0d0
nI(:,:) = 0d0
f(:) = pt2_F(:) f(:) = pt2_F(:)
d(:) = .false. d(:) = .false.
n_tasks = 0 n_tasks = 0
E0 = E E0 = E
v0 = 0.d0
n0 = 0.d0
more = 1 more = 1
time0 = omp_get_wtime() time0 = omp_get_wtime()
@ -316,8 +337,12 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
if(U >= pt2_n_0(t+1)) then if(U >= pt2_n_0(t+1)) then
t=t+1 t=t+1
E0 = 0.d0 E0 = 0.d0
v0 = 0.d0
n0 = 0.d0
do i=pt2_n_0(t),1,-1 do i=pt2_n_0(t),1,-1
E0 += eI(pt2_stoch_istate, i) E0 += eI(pt2_stoch_istate, i)
v0 += vI(pt2_stoch_istate, i)
n0 += nI(pt2_stoch_istate, i)
end do end do
else else
exit exit
@ -328,25 +353,35 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
c = pt2_R(n) c = pt2_R(n)
if(c > 0) then if(c > 0) then
x = 0d0 x = 0d0
x2 = 0d0
x3 = 0d0
do p=pt2_N_teeth, 1, -1 do p=pt2_N_teeth, 1, -1
v = pt2_u_0 + pt2_W_T * (pt2_u(c) + dble(p-1)) v = pt2_u_0 + pt2_W_T * (pt2_u(c) + dble(p-1))
i = pt2_find_sample_lr(v, pt2_cW,pt2_n_0(p),pt2_n_0(p+1)) i = pt2_find_sample_lr(v, pt2_cW,pt2_n_0(p),pt2_n_0(p+1))
x += eI(pt2_stoch_istate, i) * pt2_W_T / pt2_w(i) x += eI(pt2_stoch_istate, i) * pt2_W_T / pt2_w(i)
x2 += vI(pt2_stoch_istate, i) * pt2_W_T / pt2_w(i)
x3 += nI(pt2_stoch_istate, i) * pt2_W_T / pt2_w(i)
S(p) += x S(p) += x
S2(p) += x*x S2(p) += x*x
T2(p) += x2
T3(p) += x3
end do end do
avg = E0 + S(t) / dble(c) avg = E0 + S(t) / dble(c)
avg2 = v0 + T2(t) / dble(c)
avg3 = n0 + T3(t) / dble(c)
if ((avg /= 0.d0) .or. (n == N_det_generators) ) then if ((avg /= 0.d0) .or. (n == N_det_generators) ) then
do_exit = .true. do_exit = .true.
endif endif
pt2(pt2_stoch_istate) = avg pt2(pt2_stoch_istate) = avg
variance(pt2_stoch_istate) = avg2 !- avg*avg
norm(pt2_stoch_istate) = avg3
! 1/(N-1.5) : see Brugger, The American Statistician (23) 4 p. 32 (1969) ! 1/(N-1.5) : see Brugger, The American Statistician (23) 4 p. 32 (1969)
if(c > 2) then if(c > 2) then
eqt = dabs((S2(t) / c) - (S(t)/c)**2) ! dabs for numerical stability eqt = dabs((S2(t) / c) - (S(t)/c)**2) ! dabs for numerical stability
eqt = sqrt(eqt / (dble(c) - 1.5d0)) eqt = sqrt(eqt / (dble(c) - 1.5d0))
error(pt2_stoch_istate) = eqt error(pt2_stoch_istate) = eqt
if(mod(c,10)==0 .or. n==N_det_generators) then if(mod(c,10)==0 .or. n==N_det_generators) then
print '(G10.3, 2X, F16.10, 2X, G16.3, 2X, F16.4, A20)', c, avg+E, eqt, time-time0, '' print '(G10.3, 2X, F16.10, 2X, G10.3, 2X, F14.10, 2X, F14.10, 2X, F10.4, A10)', c, avg+E, eqt, avg2, avg3, time-time0, ''
if(do_exit .and. (dabs(error(pt2_stoch_istate)) / (1.d-20 + dabs(pt2(pt2_stoch_istate)) ) <= relative_error)) then if(do_exit .and. (dabs(error(pt2_stoch_istate)) / (1.d-20 + dabs(pt2(pt2_stoch_istate)) ) <= relative_error)) then
if (zmq_abort(zmq_to_qp_run_socket) == -1) then if (zmq_abort(zmq_to_qp_run_socket) == -1) then
call sleep(10) call sleep(10)
@ -363,12 +398,14 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error)
else if(more == 0) then else if(more == 0) then
exit exit
else else
call pull_pt2_results(zmq_socket_pull, index, eI_task, task_id, n_tasks) call pull_pt2_results(zmq_socket_pull, index, eI_task, vI_task, nI_task, task_id, n_tasks)
if (zmq_delete_tasks(zmq_to_qp_run_socket,zmq_socket_pull,task_id,n_tasks,more) == -1) then if (zmq_delete_tasks(zmq_to_qp_run_socket,zmq_socket_pull,task_id,n_tasks,more) == -1) then
stop 'Unable to delete tasks' stop 'Unable to delete tasks'
endif endif
do i=1,n_tasks do i=1,n_tasks
eI(:, index(i)) += eI_task(:, i) eI(:, index(i)) += eI_task(:,i)
vI(:, index(i)) += vI_task(:,i)
nI(:, index(i)) += nI_task(:,i)
f(index(i)) -= 1 f(index(i)) -= 1
end do end do
end if end if
@ -465,6 +502,10 @@ BEGIN_PROVIDER[ double precision, pt2_u, (N_det_generators)]
U = 0 U = 0
do while(N_j < pt2_n_tasks) do while(N_j < pt2_n_tasks)
if (N_c+ncache > N_det_generators) then
ncache = N_det_generators - N_c
endif
!$OMP PARALLEL DO DEFAULT(SHARED) PRIVATE(dt,v,t,k) !$OMP PARALLEL DO DEFAULT(SHARED) PRIVATE(dt,v,t,k)
do k=1, ncache do k=1, ncache
dt = pt2_u_0 dt = pt2_u_0

View File

@ -21,12 +21,14 @@ subroutine run_pt2_slave(thread,iproc,energy)
type(selection_buffer) :: buf type(selection_buffer) :: buf
logical :: done logical :: done
double precision,allocatable :: pt2(:,:) double precision,allocatable :: pt2(:,:), variance(:,:), norm(:,:)
integer :: n_tasks, k integer :: n_tasks, k
integer, allocatable :: i_generator(:), subset(:) integer, allocatable :: i_generator(:), subset(:)
allocate(task_id(pt2_n_tasks_max), task(pt2_n_tasks_max)) allocate(task_id(pt2_n_tasks_max), task(pt2_n_tasks_max))
allocate(pt2(N_states,pt2_n_tasks_max), i_generator(pt2_n_tasks_max), subset(pt2_n_tasks_max)) allocate(pt2(N_states,pt2_n_tasks_max), i_generator(pt2_n_tasks_max), subset(pt2_n_tasks_max))
allocate(variance(N_states,pt2_n_tasks_max))
allocate(norm(N_states,pt2_n_tasks_max))
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
@ -65,10 +67,12 @@ subroutine run_pt2_slave(thread,iproc,energy)
call wall_time(time0) call wall_time(time0)
do k=1,n_tasks do k=1,n_tasks
pt2(:,k) = 0.d0 pt2(:,k) = 0.d0
variance(:,k) = 0.d0
norm(:,k) = 0.d0
buf%cur = 0 buf%cur = 0
!double precision :: time2 !double precision :: time2
!call wall_time(time2) !call wall_time(time2)
call select_connected(i_generator(k),energy,pt2(1,k),buf,subset(k),pt2_F(i_generator(k))) call select_connected(i_generator(k),energy,pt2(1,k),variance(1,k),norm(1,k),buf,subset(k),pt2_F(i_generator(k)))
!call wall_time(time1) !call wall_time(time1)
!print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1)) !print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1))
enddo enddo
@ -79,7 +83,7 @@ subroutine run_pt2_slave(thread,iproc,energy)
if (tasks_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id,n_tasks) == -1) then if (tasks_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id,n_tasks) == -1) then
done = .true. done = .true.
endif endif
call push_pt2_results(zmq_socket_push, i_generator, pt2, task_id, n_tasks) call push_pt2_results(zmq_socket_push, i_generator, pt2, variance, norm, task_id, n_tasks)
! Try to adjust n_tasks around nproc/8 seconds per job ! Try to adjust n_tasks around nproc/8 seconds per job
n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/8) / (time1 - time0 + 1.d0))) n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/8) / (time1 - time0 + 1.d0)))
@ -98,13 +102,15 @@ subroutine run_pt2_slave(thread,iproc,energy)
end subroutine end subroutine
subroutine push_pt2_results(zmq_socket_push, index, pt2, task_id, n_tasks) subroutine push_pt2_results(zmq_socket_push, index, pt2, variance, norm, task_id, n_tasks)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_push integer(ZMQ_PTR), intent(in) :: zmq_socket_push
double precision, intent(in) :: pt2(N_states,n_tasks) double precision, intent(in) :: pt2(N_states,n_tasks)
double precision, intent(in) :: variance(N_states,n_tasks)
double precision, intent(in) :: norm(N_states,n_tasks)
integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks) integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks)
integer :: rc integer :: rc
@ -128,6 +134,18 @@ subroutine push_pt2_results(zmq_socket_push, index, pt2, task_id, n_tasks)
endif endif
if(rc /= 8*N_states*n_tasks) stop 'push' if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, variance, 8*N_states*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
endif
if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, norm, 8*N_states*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
endif
if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, task_id, n_tasks*4, 0) rc = f77_zmq_send( zmq_socket_push, task_id, n_tasks*4, 0)
if (rc == -1) then if (rc == -1) then
return return
@ -151,12 +169,14 @@ IRP_ENDIF
end subroutine end subroutine
subroutine pull_pt2_results(zmq_socket_pull, index, pt2, task_id, n_tasks) subroutine pull_pt2_results(zmq_socket_pull, index, pt2, variance, norm, task_id, n_tasks)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision, intent(inout) :: pt2(N_states,*) double precision, intent(inout) :: pt2(N_states,*)
double precision, intent(inout) :: variance(N_states,*)
double precision, intent(inout) :: norm(N_states,*)
integer, intent(out) :: index(*) integer, intent(out) :: index(*)
integer, intent(out) :: n_tasks, task_id(*) integer, intent(out) :: n_tasks, task_id(*)
integer :: rc, rn, i integer :: rc, rn, i
@ -182,6 +202,20 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2, task_id, n_tasks)
endif endif
if(rc /= 8*N_states*n_tasks) stop 'pull' if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, variance, N_states*8*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
endif
if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, norm, N_states*8*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
endif
if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0) rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0)
if (rc == -1) then if (rc == -1) then
n_tasks = 1 n_tasks = 1
@ -205,5 +239,3 @@ IRP_ENDIF
end subroutine end subroutine

View File

@ -19,6 +19,8 @@ subroutine run_selection_slave(thread,iproc,energy)
type(selection_buffer) :: buf, buf2 type(selection_buffer) :: buf, buf2
logical :: done, buffer_ready logical :: done, buffer_ready
double precision :: pt2(N_states) double precision :: pt2(N_states)
double precision :: variance(N_states)
double precision :: norm(N_states)
PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique
PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order
@ -39,6 +41,8 @@ subroutine run_selection_slave(thread,iproc,energy)
buffer_ready = .False. buffer_ready = .False.
ctask = 1 ctask = 1
pt2(:) = 0d0 pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0.d0
do do
integer, external :: get_task_from_taskserver integer, external :: get_task_from_taskserver
@ -59,7 +63,7 @@ subroutine run_selection_slave(thread,iproc,energy)
else else
ASSERT (N == buf%N) ASSERT (N == buf%N)
end if end if
call select_connected(i_generator,energy,pt2,buf,subset,pt2_F(i_generator)) call select_connected(i_generator,energy,pt2,variance,norm,buf,subset,pt2_F(i_generator))
endif endif
integer, external :: task_done_to_taskserver integer, external :: task_done_to_taskserver
@ -78,9 +82,11 @@ subroutine run_selection_slave(thread,iproc,energy)
if(ctask > 0) then if(ctask > 0) then
call sort_selection_buffer(buf) call sort_selection_buffer(buf)
call merge_selection_buffers(buf,buf2) call merge_selection_buffers(buf,buf2)
call push_selection_results(zmq_socket_push, pt2, buf, task_id(1), ctask) call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask)
buf%mini = buf2%mini buf%mini = buf2%mini
pt2(:) = 0d0 pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0d0
buf%cur = 0 buf%cur = 0
end if end if
ctask = 0 ctask = 0
@ -105,13 +111,15 @@ subroutine run_selection_slave(thread,iproc,energy)
end subroutine end subroutine
subroutine push_selection_results(zmq_socket_push, pt2, b, task_id, ntask) subroutine push_selection_results(zmq_socket_push, pt2, variance, norm, b, task_id, ntask)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_push integer(ZMQ_PTR), intent(in) :: zmq_socket_push
double precision, intent(in) :: pt2(N_states) double precision, intent(in) :: pt2(N_states)
double precision, intent(in) :: variance(N_states)
double precision, intent(in) :: norm(N_states)
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
integer, intent(in) :: ntask, task_id(*) integer, intent(in) :: ntask, task_id(*)
integer :: rc integer :: rc
@ -128,6 +136,16 @@ subroutine push_selection_results(zmq_socket_push, pt2, b, task_id, ntask)
print *, 'f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)' print *, 'f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)'
endif endif
rc = f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE)
if(rc /= 8*N_states) then
print *, 'f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE)'
endif
rc = f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE)
if(rc /= 8*N_states) then
print *, 'f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE)'
endif
rc = f77_zmq_send( zmq_socket_push, b%val(1), 8*b%cur, ZMQ_SNDMORE) rc = f77_zmq_send( zmq_socket_push, b%val(1), 8*b%cur, ZMQ_SNDMORE)
if(rc /= 8*b%cur) then if(rc /= 8*b%cur) then
print *, 'f77_zmq_send( zmq_socket_push, b%val(1), 8*b%cur, ZMQ_SNDMORE)' print *, 'f77_zmq_send( zmq_socket_push, b%val(1), 8*b%cur, ZMQ_SNDMORE)'
@ -164,12 +182,14 @@ IRP_ENDIF
end subroutine end subroutine
subroutine pull_selection_results(zmq_socket_pull, pt2, val, det, N, task_id, ntask) subroutine pull_selection_results(zmq_socket_pull, pt2, variance, norm, val, det, N, task_id, ntask)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision, intent(inout) :: pt2(N_states) double precision, intent(inout) :: pt2(N_states)
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)
double precision, intent(out) :: val(*) double precision, intent(out) :: val(*)
integer(bit_kind), intent(out) :: det(N_int, 2, *) integer(bit_kind), intent(out) :: det(N_int, 2, *)
integer, intent(out) :: N, ntask, task_id(*) integer, intent(out) :: N, ntask, task_id(*)
@ -186,6 +206,16 @@ subroutine pull_selection_results(zmq_socket_pull, pt2, val, det, N, task_id, nt
print *, 'f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0)' print *, 'f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0)'
endif endif
rc = f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0)
if(rc /= 8*N_states) then
print *, 'f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0)'
endif
rc = f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)
if(rc /= 8*N_states) then
print *, 'f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)'
endif
rc = f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0) rc = f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0)
if(rc /= 8*N) then if(rc /= 8*N) then
print *, 'f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0)' print *, 'f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0)'

View File

@ -26,13 +26,15 @@ subroutine get_mask_phase(det1, pm, Nint)
end subroutine end subroutine
subroutine select_connected(i_generator,E0,pt2,b,subset,csubset) subroutine select_connected(i_generator,E0,pt2,variance,norm,b,subset,csubset)
use bitmasks use bitmasks
use selection_types use selection_types
implicit none implicit none
integer, intent(in) :: i_generator, subset, csubset integer, intent(in) :: i_generator, subset, csubset
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
double precision, intent(inout) :: pt2(N_states) double precision, intent(inout) :: pt2(N_states)
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)
integer :: k,l integer :: k,l
double precision, intent(in) :: E0(N_states) double precision, intent(in) :: E0(N_states)
@ -51,7 +53,7 @@ subroutine select_connected(i_generator,E0,pt2,b,subset,csubset)
particle_mask(k,1) = iand(generators_bitmask(k,1,s_part,l), not(psi_det_generators(k,1,i_generator)) ) particle_mask(k,1) = iand(generators_bitmask(k,1,s_part,l), not(psi_det_generators(k,1,i_generator)) )
particle_mask(k,2) = iand(generators_bitmask(k,2,s_part,l), not(psi_det_generators(k,2,i_generator)) ) particle_mask(k,2) = iand(generators_bitmask(k,2,s_part,l), not(psi_det_generators(k,2,i_generator)) )
enddo enddo
call select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_diag_tmp,E0,pt2,b,subset,csubset) call select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_diag_tmp,E0,pt2,variance,norm,b,subset,csubset)
enddo enddo
deallocate(fock_diag_tmp) deallocate(fock_diag_tmp)
end subroutine end subroutine
@ -287,7 +289,7 @@ subroutine get_m0(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs)
end end
subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_diag_tmp,E0,pt2,buf,subset,csubset) subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_diag_tmp,E0,pt2,variance,norm,buf,subset,csubset)
use bitmasks use bitmasks
use selection_types use selection_types
implicit none implicit none
@ -300,6 +302,8 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
double precision, intent(in) :: fock_diag_tmp(mo_tot_num) double precision, intent(in) :: fock_diag_tmp(mo_tot_num)
double precision, intent(in) :: E0(N_states) double precision, intent(in) :: E0(N_states)
double precision, intent(inout) :: pt2(N_states) double precision, intent(inout) :: pt2(N_states)
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)
type(selection_buffer), intent(inout) :: buf type(selection_buffer), intent(inout) :: buf
integer :: h1,h2,s1,s2,s3,i1,i2,ib,sp,k,i,j,nt,ii integer :: h1,h2,s1,s2,s3,i1,i2,ib,sp,k,i,j,nt,ii
@ -622,7 +626,7 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
call splash_pq(mask, sp, minilist, i_generator, interesting(0), bannedOrb, banned, mat, interesting) call splash_pq(mask, sp, minilist, i_generator, interesting(0), bannedOrb, banned, mat, interesting)
call fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_diag_tmp, E0, pt2, mat, buf) call fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_diag_tmp, E0, pt2, variance, norm, mat, buf)
end if end if
enddo enddo
if(s1 /= s2) monoBdo = .false. if(s1 /= s2) monoBdo = .false.
@ -635,7 +639,7 @@ end subroutine
subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_diag_tmp, E0, pt2, mat, buf) subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_diag_tmp, E0, pt2, variance, norm, mat, buf)
use bitmasks use bitmasks
use selection_types use selection_types
implicit none implicit none
@ -646,11 +650,13 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
double precision, intent(in) :: fock_diag_tmp(mo_tot_num) double precision, intent(in) :: fock_diag_tmp(mo_tot_num)
double precision, intent(in) :: E0(N_states) double precision, intent(in) :: E0(N_states)
double precision, intent(inout) :: pt2(N_states) double precision, intent(inout) :: pt2(N_states)
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)
type(selection_buffer), intent(inout) :: buf type(selection_buffer), intent(inout) :: buf
logical :: ok logical :: ok
integer :: s1, s2, p1, p2, ib, j, istate integer :: s1, s2, p1, p2, ib, j, istate
integer(bit_kind) :: mask(N_int, 2), det(N_int, 2) integer(bit_kind) :: mask(N_int, 2), det(N_int, 2)
double precision :: e_pert, delta_E, val, Hii, sum_e_pert, tmp double precision :: e_pert, delta_E, val, Hii, sum_e_pert, tmp, alpha_h_psi, coef
double precision, external :: diag_H_mat_elem_fock double precision, external :: diag_H_mat_elem_fock
logical, external :: detEq logical, external :: detEq
@ -681,14 +687,21 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
do istate=1,N_states do istate=1,N_states
delta_E = E0(istate) - Hii delta_E = E0(istate) - Hii
val = mat(istate, p1, p2) + mat(istate, p1, p2) alpha_h_psi = mat(istate, p1, p2)
val = alpha_h_psi + alpha_h_psi
tmp = dsqrt(delta_E * delta_E + val * val) tmp = dsqrt(delta_E * delta_E + val * val)
if (delta_E < 0.d0) then if (delta_E < 0.d0) then
tmp = -tmp tmp = -tmp
endif endif
e_pert = 0.5d0 * (tmp - delta_E) e_pert = 0.5d0 * (tmp - delta_E)
coef = alpha_h_psi / delta_E
pt2(istate) = pt2(istate) + e_pert pt2(istate) = pt2(istate) + e_pert
sum_e_pert = sum_e_pert + e_pert * state_average_weight(istate) sum_e_pert = sum_e_pert + e_pert * state_average_weight(istate)
variance(istate) = variance(istate) + alpha_h_psi * alpha_h_psi * state_average_weight(istate)
norm(istate) = norm(istate) + coef * coef * state_average_weight(istate)
end do end do
if(sum_e_pert <= buf%mini) then if(sum_e_pert <= buf%mini) then

View File

@ -1,4 +1,4 @@
subroutine ZMQ_selection(N_in, pt2) subroutine ZMQ_selection(N_in, pt2, variance, norm)
use f77_zmq use f77_zmq
use selection_types use selection_types
@ -10,6 +10,8 @@ subroutine ZMQ_selection(N_in, pt2)
integer :: i, N integer :: i, N
integer, external :: omp_get_thread_num integer, external :: omp_get_thread_num
double precision, intent(out) :: pt2(N_states) double precision, intent(out) :: pt2(N_states)
double precision, intent(out) :: variance(N_states)
double precision, intent(out) :: norm(N_states)
N = max(N_in,1) N = max(N_in,1)
@ -103,10 +105,10 @@ subroutine ZMQ_selection(N_in, pt2)
f(:) = 1.d0 f(:) = 1.d0
endif endif
!$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2) PRIVATE(i) NUM_THREADS(nproc_target+1) !$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2, variance, norm) PRIVATE(i) NUM_THREADS(nproc_target+1)
i = omp_get_thread_num() i = omp_get_thread_num()
if (i==0) then if (i==0) then
call selection_collector(zmq_socket_pull, b, N, pt2) call selection_collector(zmq_socket_pull, b, N, pt2, variance, norm)
else else
call selection_slave_inproc(i) call selection_slave_inproc(i)
endif endif
@ -114,6 +116,8 @@ subroutine ZMQ_selection(N_in, pt2)
call end_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'selection') call end_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'selection')
do i=N_det+1,N_states do i=N_det+1,N_states
pt2(i) = 0.d0 pt2(i) = 0.d0
variance(i) = 0.d0
norm(i) = 0.d0
enddo enddo
if (N_in > 0) then if (N_in > 0) then
call fill_H_apply_buffer_no_selection(b%cur,b%det,N_int,0) call fill_H_apply_buffer_no_selection(b%cur,b%det,N_int,0)
@ -126,7 +130,10 @@ subroutine ZMQ_selection(N_in, pt2)
call delete_selection_buffer(b) call delete_selection_buffer(b)
do k=1,N_states do k=1,N_states
pt2(k) = pt2(k) * f(k) pt2(k) = pt2(k) * f(k)
variance(k) = variance(k) * f(k)
norm(k) = norm(k) * f(k)
enddo enddo
! variance = variance - pt2*pt2
end subroutine end subroutine
@ -138,7 +145,7 @@ subroutine selection_slave_inproc(i)
call run_selection_slave(1,i,pt2_e0_denominator) call run_selection_slave(1,i,pt2_e0_denominator)
end end
subroutine selection_collector(zmq_socket_pull, b, N, pt2) subroutine selection_collector(zmq_socket_pull, b, N, pt2, variance, norm)
use f77_zmq use f77_zmq
use selection_types use selection_types
use bitmasks use bitmasks
@ -150,6 +157,8 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2)
integer, intent(in) :: N integer, intent(in) :: N
double precision, intent(out) :: pt2(N_states) double precision, intent(out) :: pt2(N_states)
double precision :: pt2_mwen(N_states) double precision :: pt2_mwen(N_states)
double precision :: variance(N_states)
double precision :: norm(N_states)
integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
integer(ZMQ_PTR) :: zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket
@ -172,9 +181,11 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2)
pt2(:) = 0d0 pt2(:) = 0d0
pt2_mwen(:) = 0.d0 pt2_mwen(:) = 0.d0
do while (more == 1) do while (more == 1)
call pull_selection_results(zmq_socket_pull, pt2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask) call pull_selection_results(zmq_socket_pull, pt2_mwen, variance, norm, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
pt2(:) += pt2_mwen(:) pt2(:) += pt2_mwen(:)
variance(:) += variance(:)
norm(:) += norm(:)
do i=1, b2%cur do i=1, b2%cur
call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i)) call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i))
if (b2%val(i) > b%mini) exit if (b2%val(i) > b%mini) exit

View File

@ -1,10 +1,10 @@
subroutine print_summary(e_,pt2_,error_) subroutine print_summary(e_,pt2_,error_,variance_,norm_)
implicit none implicit none
BEGIN_DOC BEGIN_DOC
! Print the extrapolated energy in the output ! Print the extrapolated energy in the output
END_DOC END_DOC
double precision, intent(in) :: e_(N_states), pt2_(N_states), error_(N_states) double precision, intent(in) :: e_(N_states), pt2_(N_states), variance_(N_states), norm_(N_states), error_(N_states)
integer :: i, k integer :: i, k
integer :: N_states_p integer :: N_states_p
character*(8) :: pt2_string character*(8) :: pt2_string
@ -55,6 +55,8 @@ subroutine print_summary(e_,pt2_,error_)
do k=1, N_states_p do k=1, N_states_p
print*,'State ',k print*,'State ',k
print *, 'Variance = ', variance_(k)
print *, 'PT norm = ', norm_(k)
print *, 'PT2 = ', pt2_(k) print *, 'PT2 = ', pt2_(k)
print *, 'E = ', e_(k) print *, 'E = ', e_(k)
print *, 'E+PT2'//pt2_string//' = ', e_(k)+pt2_(k), ' +/- ', error_(k) print *, 'E+PT2'//pt2_string//' = ', e_(k)+pt2_(k), ' +/- ', error_(k)