9
1
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-12-21 11:03:29 +01:00

Fixed never stopping FCI

This commit is contained in:
Anthony Scemama 2020-04-06 00:03:59 +02:00
parent 9c9b219aba
commit a0e55498da
4 changed files with 38 additions and 26 deletions

View File

@ -32,7 +32,7 @@ OPENMP : 1 ; Append OpenMP flags
# #
[OPT] [OPT]
FC : -traceback FC : -traceback
FCFLAGS : -mavx -axAVX -O2 -ip -ftz -g FCFLAGS : -mavx -O2 -ip -ftz -g
# Profiling flags # Profiling flags
################# #################

View File

@ -11,7 +11,7 @@ END_PROVIDER
implicit none implicit none
logical, external :: testTeethBuilding logical, external :: testTeethBuilding
integer :: i,j integer :: i,j
pt2_n_tasks_max = elec_alpha_num*elec_alpha_num + elec_alpha_num*elec_beta_num - n_core_orb*2 pt2_n_tasks_max = elec_alpha_num*elec_alpha_num + elec_alpha_num*elec_beta_num - n_core_orb*2
pt2_n_tasks_max = min(pt2_n_tasks_max,1+N_det_generators/10000) pt2_n_tasks_max = min(pt2_n_tasks_max,1+N_det_generators/10000)
call write_int(6,pt2_n_tasks_max,'pt2_n_tasks_max') call write_int(6,pt2_n_tasks_max,'pt2_n_tasks_max')
@ -88,7 +88,7 @@ logical function testTeethBuilding(minF, N)
do do
u0 = tilde_cW(n0) u0 = tilde_cW(n0)
r = tilde_cW(n0 + minF) r = tilde_cW(n0 + minF)
Wt = (1d0 - u0) * f Wt = (1d0 - u0) * f
if (dabs(Wt) <= 1.d-3) then if (dabs(Wt) <= 1.d-3) then
exit exit
endif endif
@ -115,6 +115,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm, N_in)
integer(ZMQ_PTR) :: zmq_to_qp_run_socket, zmq_socket_pull integer(ZMQ_PTR) :: zmq_to_qp_run_socket, zmq_socket_pull
integer, intent(in) :: N_in integer, intent(in) :: N_in
! integer, intent(inout) :: N_in
double precision, intent(in) :: relative_error, E(N_states) double precision, intent(in) :: relative_error, E(N_states)
double precision, intent(out) :: pt2(N_states),error(N_states) double precision, intent(out) :: pt2(N_states),error(N_states)
double precision, intent(out) :: variance(N_states),norm(N_states) double precision, intent(out) :: variance(N_states),norm(N_states)
@ -136,7 +137,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm, N_in)
PROVIDE psi_occ_pattern_hii det_to_occ_pattern PROVIDE psi_occ_pattern_hii det_to_occ_pattern
endif endif
if (N_det <= max(4,N_states)) then if (N_det <= max(4,N_states) .or. pt2_N_teeth < 2) then
pt2=0.d0 pt2=0.d0
variance=0.d0 variance=0.d0
norm=0.d0 norm=0.d0
@ -296,7 +297,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm, N_in)
print '(A)', ' Samples Energy Stat. Err Variance Norm Seconds ' print '(A)', ' Samples Energy Stat. Err Variance Norm Seconds '
print '(A)', '========== ================= =========== =============== =============== =================' print '(A)', '========== ================= =========== =============== =============== ================='
PROVIDE global_selection_buffer PROVIDE global_selection_buffer
!$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc_target+1) & !$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc_target+1) &
!$OMP PRIVATE(i) !$OMP PRIVATE(i)
i = omp_get_thread_num() i = omp_get_thread_num()
@ -346,7 +347,7 @@ subroutine pt2_slave_inproc(i)
implicit none implicit none
integer, intent(in) :: i integer, intent(in) :: i
PROVIDE global_selection_buffer PROVIDE global_selection_buffer
call run_pt2_slave(1,i,pt2_e0_denominator) call run_pt2_slave(1,i,pt2_e0_denominator)
end end
@ -528,8 +529,8 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
print*,'PB !!!' print*,'PB !!!'
print*,'If you see this, send an email to Anthony scemama with the following content' print*,'If you see this, send an email to Anthony scemama with the following content'
print*,irp_here print*,irp_here
print*,'n_tasks,pt2_n_tasks_max = ',n_tasks,pt2_n_tasks_max print*,'n_tasks,pt2_n_tasks_max = ',n_tasks,pt2_n_tasks_max
stop -1 stop -1
endif endif
if (zmq_delete_tasks_async_send(zmq_to_qp_run_socket,task_id,n_tasks,sending) == -1) then if (zmq_delete_tasks_async_send(zmq_to_qp_run_socket,task_id,n_tasks,sending) == -1) then
stop 'PT2: Unable to delete tasks (send)' stop 'PT2: Unable to delete tasks (send)'
@ -540,7 +541,7 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
print*,'If you see this, send an email to Anthony scemama with the following content' print*,'If you see this, send an email to Anthony scemama with the following content'
print*,irp_here print*,irp_here
print*,'i,index(i),size(ei,2) = ',i,index(i),size(ei,2) print*,'i,index(i),size(ei,2) = ',i,index(i),size(ei,2)
stop -1 stop -1
endif endif
eI(1:N_states, index(i)) += eI_task(1:N_states,i) eI(1:N_states, index(i)) += eI_task(1:N_states,i)
vI(1:N_states, index(i)) += vI_task(1:N_states,i) vI(1:N_states, index(i)) += vI_task(1:N_states,i)
@ -731,39 +732,39 @@ END_PROVIDER
double precision, allocatable :: tilde_w(:), tilde_cW(:) double precision, allocatable :: tilde_w(:), tilde_cW(:)
double precision :: r, tooth_width double precision :: r, tooth_width
integer, external :: pt2_find_sample integer, external :: pt2_find_sample
double precision :: rss double precision :: rss
double precision, external :: memory_of_double, memory_of_int double precision, external :: memory_of_double, memory_of_int
rss = memory_of_double(2*N_det_generators+1) rss = memory_of_double(2*N_det_generators+1)
call check_mem(rss,irp_here) call check_mem(rss,irp_here)
if (N_det_generators == 1) then if (N_det_generators == 1) then
pt2_w(1) = 1.d0 pt2_w(1) = 1.d0
pt2_cw(1) = 1.d0 pt2_cw(1) = 1.d0
pt2_u_0 = 1.d0 pt2_u_0 = 1.d0
pt2_W_T = 0.d0 pt2_W_T = 0.d0
pt2_n_0(1) = 0 pt2_n_0(1) = 0
pt2_n_0(2) = 1 pt2_n_0(2) = 1
else else
allocate(tilde_w(N_det_generators), tilde_cW(0:N_det_generators)) allocate(tilde_w(N_det_generators), tilde_cW(0:N_det_generators))
tilde_cW(0) = 0d0 tilde_cW(0) = 0d0
do i=1,N_det_generators do i=1,N_det_generators
tilde_w(i) = psi_coef_sorted_gen(i,pt2_stoch_istate)**2 !+ 1.d-20 tilde_w(i) = psi_coef_sorted_gen(i,pt2_stoch_istate)**2 !+ 1.d-20
enddo enddo
double precision :: norm double precision :: norm
norm = 0.d0 norm = 0.d0
do i=N_det_generators,1,-1 do i=N_det_generators,1,-1
norm += tilde_w(i) norm += tilde_w(i)
enddo enddo
tilde_w(:) = tilde_w(:) / norm tilde_w(:) = tilde_w(:) / norm
tilde_cW(0) = -1.d0 tilde_cW(0) = -1.d0
do i=1,N_det_generators do i=1,N_det_generators
tilde_cW(i) = tilde_cW(i-1) + tilde_w(i) tilde_cW(i) = tilde_cW(i-1) + tilde_w(i)
@ -784,13 +785,13 @@ END_PROVIDER
stop -1 stop -1
end if end if
end do end do
do t=2, pt2_N_teeth do t=2, pt2_N_teeth
r = pt2_u_0 + pt2_W_T * dble(t-1) r = pt2_u_0 + pt2_W_T * dble(t-1)
pt2_n_0(t) = pt2_find_sample(r, tilde_cW) pt2_n_0(t) = pt2_find_sample(r, tilde_cW)
end do end do
pt2_n_0(pt2_N_teeth+1) = N_det_generators pt2_n_0(pt2_N_teeth+1) = N_det_generators
pt2_w(:pt2_n_0(1)) = tilde_w(:pt2_n_0(1)) pt2_w(:pt2_n_0(1)) = tilde_w(:pt2_n_0(1))
do t=1, pt2_N_teeth do t=1, pt2_N_teeth
tooth_width = tilde_cW(pt2_n_0(t+1)) - tilde_cW(pt2_n_0(t)) tooth_width = tilde_cW(pt2_n_0(t+1)) - tilde_cW(pt2_n_0(t))
@ -802,7 +803,7 @@ END_PROVIDER
pt2_w(i) = tilde_w(i) * pt2_W_T / tooth_width pt2_w(i) = tilde_w(i) * pt2_W_T / tooth_width
end do end do
end do end do
pt2_cW(0) = 0d0 pt2_cW(0) = 0d0
do i=1,N_det_generators do i=1,N_det_generators
pt2_cW(i) = pt2_cW(i-1) + pt2_w(i) pt2_cW(i) = pt2_cW(i-1) + pt2_w(i)

View File

@ -99,6 +99,17 @@ subroutine run_selection_slave(thread,iproc,energy)
ctask = ctask + 1 ctask = ctask + 1
end do end do
if(ctask > 0) then
call sort_selection_buffer(buf)
! call merge_selection_buffers(buf,buf2)
call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask)
! buf%mini = buf2%mini
pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0d0
buf%cur = 0
end if
ctask = 0
integer, external :: disconnect_from_taskserver integer, external :: disconnect_from_taskserver
if (disconnect_from_taskserver(zmq_to_qp_run_socket,worker_id) == -1) then if (disconnect_from_taskserver(zmq_to_qp_run_socket,worker_id) == -1) then

View File

@ -52,7 +52,7 @@ subroutine update_pt2_and_variance_weights(pt2, variance, norm, N_st)
rpt2(k) = pt2(k)/(1.d0 + norm(k)) rpt2(k) = pt2(k)/(1.d0 + norm(k))
enddo enddo
avg = sum(rpt2(1:N_st)) / dble(N_st) avg = sum(rpt2(1:N_st)) / dble(N_st) - 1.d-32 ! Avoid future division by zero
do k=1,N_st do k=1,N_st
element = exp(dt*(rpt2(k)/avg -1.d0)) element = exp(dt*(rpt2(k)/avg -1.d0))
element = min(1.5d0 , element) element = min(1.5d0 , element)
@ -61,7 +61,7 @@ subroutine update_pt2_and_variance_weights(pt2, variance, norm, N_st)
pt2_match_weight(k) = product(memo_pt2(k,:)) pt2_match_weight(k) = product(memo_pt2(k,:))
enddo enddo
avg = sum(variance(1:N_st)) / dble(N_st) avg = sum(variance(1:N_st)) / dble(N_st) + 1.d-32 ! Avoid future division by zero
do k=1,N_st do k=1,N_st
element = exp(dt*(variance(k)/avg -1.d0)) element = exp(dt*(variance(k)/avg -1.d0))
element = min(1.5d0 , element) element = min(1.5d0 , element)
@ -325,7 +325,7 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
i = psi_bilinear_matrix_rows(l_a) i = psi_bilinear_matrix_rows(l_a)
if (nt + exc_degree(i) <= 4) then if (nt + exc_degree(i) <= 4) then
idx = psi_det_sorted_order(psi_bilinear_matrix_order(l_a)) idx = psi_det_sorted_order(psi_bilinear_matrix_order(l_a))
if (psi_average_norm_contrib_sorted(idx) > 1.d-12) then if (psi_average_norm_contrib_sorted(idx) > 0.d0) then
indices(k) = idx indices(k) = idx
k=k+1 k=k+1
endif endif
@ -349,7 +349,7 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
idx = psi_det_sorted_order( & idx = psi_det_sorted_order( &
psi_bilinear_matrix_order( & psi_bilinear_matrix_order( &
psi_bilinear_matrix_transp_order(l_a))) psi_bilinear_matrix_transp_order(l_a)))
if (psi_average_norm_contrib_sorted(idx) > 1.d-12) then if (psi_average_norm_contrib_sorted(idx) > 0.d0) then
indices(k) = idx indices(k) = idx
k=k+1 k=k+1
endif endif