10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-09-27 03:51:01 +02:00

Stochastic selection work

This commit is contained in:
Anthony Scemama 2019-01-06 16:14:45 +01:00
parent b8bd99d449
commit e3a0722796
6 changed files with 207 additions and 81 deletions

View File

@ -66,17 +66,19 @@ program fci
write(*,'(A)') '--------------------------------------------------------------------------------'
if (do_pt2) then
pt2 = 0.d0
variance = 0.d0
norm = 0.d0
threshold_generators = 1.d0
SOFT_TOUCH threshold_generators
call ZMQ_pt2(psi_energy_with_nucl_rep,pt2,relative_error,error, variance, norm) ! Stochastic PT2
threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_generators
endif
n_det_before = N_det
to_select = N_det
to_select = max(N_states_diag, to_select)
pt2 = 0.d0
variance = 0.d0
norm = 0.d0
threshold_generators = 1.d0
SOFT_TOUCH threshold_generators
call ZMQ_pt2(psi_energy_with_nucl_rep,pt2,relative_error,error, variance, &
norm, to_select) ! Stochastic PT2
threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_generators
correlation_energy_ratio = (psi_energy_with_nucl_rep(1) - hf_energy_ref) / &
(psi_energy_with_nucl_rep(1) + pt2(1) - hf_energy_ref)
@ -94,11 +96,7 @@ program fci
call print_extrapolated_energy(psi_energy_with_nucl_rep(1:N_states),rpt2)
N_iter += 1
n_det_before = N_det
to_select = N_det
to_select = max(N_states_diag, to_select)
! to_select = min(to_select, N_det_max-n_det_before)
call ZMQ_selection(to_select, pt2, variance, norm)
! call ZMQ_selection(to_select, pt2, variance, norm)
PROVIDE psi_coef
PROVIDE psi_det
@ -116,23 +114,17 @@ program fci
call ezfio_set_fci_energy_pt2(psi_energy_with_nucl_rep(1:N_states)+pt2)
endif
if (do_pt2) then
pt2 = 0.d0
variance = 0.d0
norm = 0.d0
threshold_generators = 1d0
SOFT_TOUCH threshold_generators
call ZMQ_pt2(psi_energy_with_nucl_rep, pt2,relative_error,error,variance,norm) ! Stochastic PT2
threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_generators
call ezfio_set_fci_energy(psi_energy_with_nucl_rep(1:N_states))
call ezfio_set_fci_energy_pt2(psi_energy_with_nucl_rep(1:N_states)+pt2)
endif
print *, 'N_det = ', N_det
print *, 'N_sop = ', N_occ_pattern
print *, 'N_states = ', N_states
print*, 'correlation_ratio = ', correlation_energy_ratio
pt2 = 0.d0
variance = 0.d0
norm = 0.d0
threshold_generators = 1d0
SOFT_TOUCH threshold_generators
call ZMQ_pt2(psi_energy_with_nucl_rep, pt2,relative_error,error,variance, &
norm,0) ! Stochastic PT2
threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_generators
call ezfio_set_fci_energy(psi_energy_with_nucl_rep(1:N_states))
call ezfio_set_fci_energy_pt2(psi_energy_with_nucl_rep(1:N_states)+pt2)
do k=1,N_states
rpt2(:) = pt2(:)/(1.d0 + norm(k))

View File

@ -29,7 +29,8 @@ subroutine run
E_CI_before(:) = psi_energy(:) + nuclear_repulsion
relative_error=PT2_relative_error
call ZMQ_pt2(psi_energy_with_nucl_rep,pt2,relative_error,error, variance, norm) ! Stochastic PT2
call ZMQ_pt2(psi_energy_with_nucl_rep,pt2,relative_error,error, variance, &
norm,0) ! Stochastic PT2
do k=1,N_states
rpt2(:) = pt2(:)/(1.d0 + norm(k))
enddo

View File

@ -95,34 +95,42 @@ end function
subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm, N_in)
use f77_zmq
use selection_types
implicit none
integer(ZMQ_PTR) :: zmq_to_qp_run_socket, zmq_socket_pull
integer, intent(in) :: N_in
integer, external :: omp_get_thread_num
double precision, intent(in) :: relative_error, E(N_states)
double precision, intent(out) :: pt2(N_states),error(N_states)
double precision, intent(out) :: variance(N_states),norm(N_states)
integer :: i
integer :: i, N
double precision, external :: omp_get_wtime
double precision :: state_average_weight_save(N_states), w(N_states,4)
integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket
type(selection_buffer) :: b
if (N_det < max(10,N_states)) then
pt2=0.d0
variance=0.d0
norm=0.d0
call ZMQ_selection(0, pt2, variance, norm)
call ZMQ_selection(N_in, pt2, variance, norm)
error(:) = 0.d0
else
N = max(N_in,1)
state_average_weight_save(:) = state_average_weight(:)
call create_selection_buffer(N, N*2, b)
ASSERT (associated(b%det))
ASSERT (associated(b%val))
do pt2_stoch_istate=1,N_states
state_average_weight(:) = 0.d0
state_average_weight(pt2_stoch_istate) = 1.d0
@ -159,9 +167,8 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
endif
integer, external :: add_task_to_taskserver
character(400000) :: task
character(300000) :: task
integer :: j,k,ipos,ifirst
ifirst=0
@ -178,9 +185,9 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
ipos=1
do i= 1, N_det_generators
do j=1,pt2_F(pt2_J(i))
write(task(ipos:ipos+20),'(I9,1X,I9,''|'')') j, pt2_J(i)
ipos += 20
if (ipos > 400000-20) then
write(task(ipos:ipos+30),'(I9,1X,I9,1X,I9,''|'')') j, pt2_J(i), N
ipos += 30
if (ipos > 300000-30) then
if (add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos))) == -1) then
stop 'Unable to add task to task server'
endif
@ -199,7 +206,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
stop 'Unable to add task to task server'
endif
endif
integer, external :: zmq_set_running
if (zmq_set_running(zmq_to_qp_run_socket) == -1) then
print *, irp_here, ': Failed in zmq_set_running'
@ -228,7 +235,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
i = omp_get_thread_num()
if (i==0) then
call pt2_collector(zmq_socket_pull, E(pt2_stoch_istate),relative_error, w(1,1), w(1,2), w(1,3), w(1,4))
call pt2_collector(zmq_socket_pull, E(pt2_stoch_istate),relative_error, w(1,1), w(1,2), w(1,3), w(1,4), b, N)
pt2(pt2_stoch_istate) = w(pt2_stoch_istate,1)
error(pt2_stoch_istate) = w(pt2_stoch_istate,2)
variance(pt2_stoch_istate) = w(pt2_stoch_istate,3)
@ -243,9 +250,18 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error, variance, norm)
print '(A)', '========== ================= =========== =============== =============== ================='
enddo
! call omp_set_nested(.false.)
FREE pt2_stoch_istate
if (N_in > 0) then
if (s2_eig) then
call make_selection_buffer_s2(b)
endif
call fill_H_apply_buffer_no_selection(b%cur,b%det,N_int,0)
call copy_H_apply_buffer_to_wf()
call save_wavefunction
endif
call delete_selection_buffer(b)
state_average_weight(:) = state_average_weight_save(:)
TOUCH state_average_weight
endif
@ -264,7 +280,8 @@ subroutine pt2_slave_inproc(i)
end
subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, variance, norm)
subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, &
variance, norm, b, N_)
use f77_zmq
use selection_types
use bitmasks
@ -272,9 +289,11 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision, intent(in) :: relative_error, E
double precision, intent(in) :: relative_error, E
double precision, intent(out) :: pt2(N_states), error(N_states)
double precision, intent(out) :: variance(N_states), norm(N_states)
type(selection_buffer), intent(inout) :: b
integer, intent(in) :: N_
double precision, allocatable :: eI(:,:), eI_task(:,:), S(:), S2(:)
@ -297,6 +316,7 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
integer, allocatable :: f(:)
logical, allocatable :: d(:)
logical :: do_exit
type(selection_buffer) :: b2
double precision :: rss
@ -319,6 +339,8 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
call create_selection_buffer(N_, N_*2, b2)
pt2(:) = -huge(1.)
error(:) = huge(1.)
@ -417,7 +439,7 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
else if(more == 0) then
exit
else
call pull_pt2_results(zmq_socket_pull, index, eI_task, vI_task, nI_task, task_id, n_tasks)
call pull_pt2_results(zmq_socket_pull, index, eI_task, vI_task, nI_task, task_id, n_tasks, b2)
if (zmq_delete_tasks(zmq_to_qp_run_socket,zmq_socket_pull,task_id,n_tasks,more) == -1) then
stop 'Unable to delete tasks'
endif
@ -427,9 +449,16 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error, varianc
nI(:, index(i)) += nI_task(:,i)
f(index(i)) -= 1
end do
do i=1, b2%cur
call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i))
if (b2%val(i) > b%mini) exit
end do
end if
end do
call delete_selection_buffer(b2)
call sort_selection_buffer(b)
call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket)
end subroutine

View File

@ -18,11 +18,11 @@ subroutine run_pt2_slave(thread,iproc,energy)
integer(ZMQ_PTR), external :: new_zmq_push_socket
integer(ZMQ_PTR) :: zmq_socket_push
type(selection_buffer) :: buf
logical :: done
type(selection_buffer) :: b, b2
logical :: done, buffer_ready
double precision,allocatable :: pt2(:,:), variance(:,:), norm(:,:)
integer :: n_tasks, k
integer :: n_tasks, k, N
integer, allocatable :: i_generator(:), subset(:)
double precision :: rss
@ -46,9 +46,9 @@ subroutine run_pt2_slave(thread,iproc,energy)
zmq_socket_push = new_zmq_push_socket(thread)
buf%N = 0
b%N = 0
buffer_ready = .False.
n_tasks = 1
call create_selection_buffer(0, 0, buf)
done = .False.
n_tasks = 1
@ -62,12 +62,22 @@ subroutine run_pt2_slave(thread,iproc,energy)
exit
endif
done = task_id(n_tasks) == 0
if (done) n_tasks = n_tasks-1
if (done) then
n_tasks = n_tasks-1
endif
if (n_tasks == 0) exit
do k=1,n_tasks
read (task(k),*) subset(k), i_generator(k)
read (task(k),*) subset(k), i_generator(k), N
enddo
if (b%N == 0) then
! Only first time
call create_selection_buffer(N, N*2, b)
call create_selection_buffer(N, N*2, b2)
buffer_ready = .True.
else
ASSERT (N == b%N)
endif
double precision :: time0, time1
call wall_time(time0)
@ -75,10 +85,10 @@ subroutine run_pt2_slave(thread,iproc,energy)
pt2(:,k) = 0.d0
variance(:,k) = 0.d0
norm(:,k) = 0.d0
buf%cur = 0
b%cur = 0
!double precision :: time2
!call wall_time(time2)
call select_connected(i_generator(k),energy,pt2(1,k),variance(1,k),norm(1,k),buf,subset(k),pt2_F(i_generator(k)))
call select_connected(i_generator(k),energy,pt2(1,k),variance(1,k),norm(1,k),b,subset(k),pt2_F(i_generator(k)))
!call wall_time(time1)
!print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1))
enddo
@ -89,7 +99,11 @@ subroutine run_pt2_slave(thread,iproc,energy)
if (tasks_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id,n_tasks) == -1) then
done = .true.
endif
call push_pt2_results(zmq_socket_push, i_generator, pt2, variance, norm, task_id, n_tasks)
call sort_selection_buffer(b)
call merge_selection_buffers(b,b2)
call push_pt2_results(zmq_socket_push, i_generator, pt2, variance, norm, b, task_id, n_tasks)
b%mini = b2%mini
b%cur=0
! Try to adjust n_tasks around nproc/8 seconds per job
n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/8) / (time1 - time0 + 1.d0)))
@ -104,11 +118,14 @@ subroutine run_pt2_slave(thread,iproc,energy)
call end_zmq_push_socket(zmq_socket_push,thread)
call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket)
call delete_selection_buffer(buf)
if (buffer_ready) then
call delete_selection_buffer(b)
call delete_selection_buffer(b2)
endif
end subroutine
subroutine push_pt2_results(zmq_socket_push, index, pt2, variance, norm, task_id, n_tasks)
subroutine push_pt2_results(zmq_socket_push, index, pt2, variance, norm, b, task_id, n_tasks)
use f77_zmq
use selection_types
implicit none
@ -118,45 +135,80 @@ subroutine push_pt2_results(zmq_socket_push, index, pt2, variance, norm, task_id
double precision, intent(in) :: variance(N_states,n_tasks)
double precision, intent(in) :: norm(N_states,n_tasks)
integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks)
type(selection_buffer), intent(inout) :: b
integer :: rc
rc = f77_zmq_send( zmq_socket_push, n_tasks, 4, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 4) then
stop 'push'
endif
if(rc /= 4) stop 'push'
rc = f77_zmq_send( zmq_socket_push, index, 4*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 4*n_tasks) then
stop 'push'
endif
if(rc /= 4*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 8*N_states*n_tasks) then
stop 'push'
endif
if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, variance, 8*N_states*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 8*N_states*n_tasks) then
stop 'push'
endif
if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, norm, 8*N_states*n_tasks, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 8*N_states*n_tasks) then
stop 'push'
endif
if(rc /= 8*N_states*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, task_id, n_tasks*4, 0)
rc = f77_zmq_send( zmq_socket_push, task_id, n_tasks*4, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 4*n_tasks) then
stop 'push'
endif
if(rc /= 4*n_tasks) stop 'push'
rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 4) then
stop 'push'
endif
rc = f77_zmq_send( zmq_socket_push, b%val, 8*b%cur, ZMQ_SNDMORE)
if (rc == -1) then
return
else if(rc /= 8*b%cur) then
stop 'push'
endif
rc = f77_zmq_send( zmq_socket_push, b%det, bit_kind*N_int*2*b%cur, 0)
if (rc == -1) then
return
else if(rc /= N_int*2*8*b%cur) then
stop 'push'
endif
! Activate is zmq_socket_push is a REQ
IRP_IF ZMQ_PUSH
@ -165,8 +217,7 @@ IRP_ELSE
rc = f77_zmq_recv( zmq_socket_push, ok, 2, 0)
if (rc == -1) then
return
endif
if ((rc /= 2).and.(ok(1:2) /= 'ok')) then
else if ((rc /= 2).and.(ok(1:2) /= 'ok')) then
print *, irp_here//': error in receiving ok'
stop -1
endif
@ -175,7 +226,7 @@ IRP_ENDIF
end subroutine
subroutine pull_pt2_results(zmq_socket_pull, index, pt2, variance, norm, task_id, n_tasks)
subroutine pull_pt2_results(zmq_socket_pull, index, pt2, variance, norm, task_id, n_tasks, b)
use f77_zmq
use selection_types
implicit none
@ -183,6 +234,7 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2, variance, norm, task_id
double precision, intent(inout) :: pt2(N_states,*)
double precision, intent(inout) :: variance(N_states,*)
double precision, intent(inout) :: norm(N_states,*)
type(selection_buffer), intent(inout) :: b
integer, intent(out) :: index(*)
integer, intent(out) :: n_tasks, task_id(*)
integer :: rc, rn, i
@ -191,43 +243,74 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2, variance, norm, task_id
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 4) then
stop 'pull'
endif
if(rc /= 4) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, index, 4*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 4*n_tasks) then
stop 'pull'
endif
if(rc /= 4*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, pt2, N_states*8*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 8*N_states*n_tasks) then
stop 'pull'
endif
if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, variance, N_states*8*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 8*N_states*n_tasks) then
stop 'pull'
endif
if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, norm, N_states*8*n_tasks, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 8*N_states*n_tasks) then
stop 'pull'
endif
if(rc /= 8*N_states*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 4*n_tasks) then
stop 'pull'
endif
if(rc /= 4*n_tasks) stop 'pull'
rc = f77_zmq_recv( zmq_socket_pull, b%cur, 4, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 4) then
stop 'pull'
endif
rc = f77_zmq_recv( zmq_socket_pull, b%val, 8*b%cur, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= 8*b%cur) then
stop 'pull'
endif
rc = f77_zmq_recv( zmq_socket_pull, b%det, bit_kind*N_int*2*b%cur, 0)
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
else if(rc /= N_int*2*8*b%cur) then
stop 'pull'
endif
! Activate is zmq_socket_pull is a REP
IRP_IF ZMQ_PUSH
@ -236,8 +319,7 @@ IRP_ELSE
if (rc == -1) then
n_tasks = 1
task_id(1) = 0
endif
if (rc /= 2) then
else if (rc /= 2) then
print *, irp_here//': error in sending ok'
stop -1
endif

View File

@ -1,5 +1,27 @@
use bitmasks
BEGIN_PROVIDER [ double precision, selection_weight, (N_states) ]
implicit none
BEGIN_DOC
! Weights in the state-average calculation of the density matrix
END_DOC
logical :: exists
selection_weight(:) = 1.d0
if (used_weight == 0) then
selection_weight(:) = c0_weight(:)
else if (used_weight == 1) then
selection_weight(:) = 1./N_states
else
call ezfio_has_determinants_state_average_weight(exists)
if (exists) then
call ezfio_get_determinants_state_average_weight(selection_weight)
endif
endif
selection_weight(:) = selection_weight(:)+1.d-31
selection_weight(:) = selection_weight(:)/(sum(selection_weight(:)))
END_PROVIDER
subroutine get_mask_phase(det1, pm, Nint)
use bitmasks
@ -719,9 +741,9 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
norm(istate) = norm(istate) + coef * coef
if (h0_type == "Variance") then
sum_e_pert = sum_e_pert - alpha_h_psi * alpha_h_psi * state_average_weight(istate)
sum_e_pert = sum_e_pert - alpha_h_psi * alpha_h_psi * selection_weight(istate)
else
sum_e_pert = sum_e_pert + e_pert * state_average_weight(istate)
sum_e_pert = sum_e_pert + e_pert * selection_weight(istate)
endif
end do

View File

@ -150,7 +150,7 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2, variance, norm)
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
type(selection_buffer), intent(inout) :: b
integer, intent(in) :: N
integer, intent(in) :: N
double precision, intent(inout) :: pt2(N_states)
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)