10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2025-01-10 13:08:19 +01:00

Fixed PT2

This commit is contained in:
Anthony Scemama 2020-08-30 22:16:39 +02:00
parent 061e7100ca
commit 93fc49000c
6 changed files with 185 additions and 125 deletions

View File

@ -474,10 +474,12 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2_data, b, N_)
do p=pt2_N_teeth, 1, -1 do p=pt2_N_teeth, 1, -1
v = pt2_u_0 + pt2_W_T * (pt2_u(c) + dble(p-1)) v = pt2_u_0 + pt2_W_T * (pt2_u(c) + dble(p-1))
i = pt2_find_sample_lr(v, pt2_cW,pt2_n_0(p),pt2_n_0(p+1)) i = pt2_find_sample_lr(v, pt2_cW,pt2_n_0(p),pt2_n_0(p+1))
call pt2_add ( pt2_data_teeth, pt2_W_T / pt2_w(i), pt2_data_I(i) ) v = pt2_W_T / pt2_w(i)
call pt2_add ( pt2_data_teeth, v, pt2_data_I(i) )
call pt2_add ( pt2_data_S(p), 1.d0, pt2_data_teeth ) call pt2_add ( pt2_data_S(p), 1.d0, pt2_data_teeth )
call pt2_add2( pt2_data_S2(p), 1.d0, pt2_data_teeth ) call pt2_add2( pt2_data_S2(p), 1.d0, pt2_data_teeth )
end do enddo
call pt2_dealloc(pt2_data_teeth) call pt2_dealloc(pt2_data_teeth)
avg = E0 + pt2_data_S(t) % pt2(pt2_stoch_istate) / dble(c) avg = E0 + pt2_data_S(t) % pt2(pt2_stoch_istate) / dble(c)
avg2 = v0 + pt2_data_S(t) % variance(pt2_stoch_istate) / dble(c) avg2 = v0 + pt2_data_S(t) % variance(pt2_stoch_istate) / dble(c)
@ -546,14 +548,6 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2_data, b, N_)
print*,'i,index(i),size(pt2_data_I,1) = ',i,index(i),size(pt2_data_I,1) print*,'i,index(i),size(pt2_data_I,1) = ',i,index(i),size(pt2_data_I,1)
stop -1 stop -1
endif endif
! print *, pt2_data_I(index(i))%pt2
! print *, pt2_data_I(index(i))%variance
! print *, pt2_data_I(index(i))%norm2
! print *, ''
! print *, pt2_data_task(i)%pt2
! print *, pt2_data_task(i)%variance
! print *, pt2_data_task(i)%norm2
! print *, ''
call pt2_add(pt2_data_I(index(i)),1.d0,pt2_data_task(i)) call pt2_add(pt2_data_I(index(i)),1.d0,pt2_data_task(i))
f(index(i)) -= 1 f(index(i)) -= 1
end do end do

View File

@ -60,16 +60,33 @@ subroutine pt2_add(p1, w, p2)
double precision, intent(in) :: w double precision, intent(in) :: w
type(pt2_type), intent(in) :: p2 type(pt2_type), intent(in) :: p2
p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:) if (w == 1.d0) then
p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:)
p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:) p1 % pt2(:) = p1 % pt2(:) + p2 % pt2(:)
p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:) p1 % pt2_err(:) = p1 % pt2_err(:) + p2 % pt2_err(:)
p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:) p1 % rpt2(:) = p1 % rpt2(:) + p2 % rpt2(:)
p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:) p1 % rpt2_err(:) = p1 % rpt2_err(:) + p2 % rpt2_err(:)
p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:) p1 % variance(:) = p1 % variance(:) + p2 % variance(:)
p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:) p1 % variance_err(:) = p1 % variance_err(:) + p2 % variance_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:) p1 % norm2(:) = p1 % norm2(:) + p2 % norm2(:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:) p1 % norm2_err(:) = p1 % norm2_err(:) + p2 % norm2_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + p2 % overlap(:,:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + p2 % overlap_err(:,:)
else
p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:)
p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:)
p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:)
p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:)
p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:)
p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:)
p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:)
p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:)
endif
end subroutine end subroutine
@ -84,17 +101,83 @@ subroutine pt2_add2(p1, w, p2)
double precision, intent(in) :: w double precision, intent(in) :: w
type(pt2_type), intent(in) :: p2 type(pt2_type), intent(in) :: p2
p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:) * p2 % pt2(:) if (w == 1.d0) then
p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:) * p2 % pt2_err(:)
p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:) * p2 % rpt2(:) p1 % pt2(:) = p1 % pt2(:) + p2 % pt2(:) * p2 % pt2(:)
p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:) * p2 % rpt2_err(:) p1 % pt2_err(:) = p1 % pt2_err(:) + p2 % pt2_err(:) * p2 % pt2_err(:)
p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:) * p2 % variance(:) p1 % rpt2(:) = p1 % rpt2(:) + p2 % rpt2(:) * p2 % rpt2(:)
p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:) * p2 % variance_err(:) p1 % rpt2_err(:) = p1 % rpt2_err(:) + p2 % rpt2_err(:) * p2 % rpt2_err(:)
p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:) * p2 % norm2(:) p1 % variance(:) = p1 % variance(:) + p2 % variance(:) * p2 % variance(:)
p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:) * p2 % norm2_err(:) p1 % variance_err(:) = p1 % variance_err(:) + p2 % variance_err(:) * p2 % variance_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:) * p2 % overlap(:,:) p1 % norm2(:) = p1 % norm2(:) + p2 % norm2(:) * p2 % norm2(:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:) * p2 % overlap_err(:,:) p1 % norm2_err(:) = p1 % norm2_err(:) + p2 % norm2_err(:) * p2 % norm2_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + p2 % overlap(:,:) * p2 % overlap(:,:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + p2 % overlap_err(:,:) * p2 % overlap_err(:,:)
else
p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:) * p2 % pt2(:)
p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:) * p2 % pt2_err(:)
p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:) * p2 % rpt2(:)
p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:) * p2 % rpt2_err(:)
p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:) * p2 % variance(:)
p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:) * p2 % variance_err(:)
p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:) * p2 % norm2(:)
p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:) * p2 % norm2_err(:)
p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:) * p2 % overlap(:,:)
p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:) * p2 % overlap_err(:,:)
endif
end subroutine end subroutine
subroutine pt2_serialize(pt2_data, n, x)
implicit none
use selection_types
type(pt2_type), intent(in) :: pt2_data
integer, intent(in) :: n
double precision, intent(out) :: x(*)
integer :: i,k,n2
n2 = n*n
x(1:n) = pt2_data % pt2(1:n)
x(n+1:2*n) = pt2_data % pt2_err(1:n)
x(2*n+1:3*n) = pt2_data % rpt2(1:n)
x(3*n+1:4*n) = pt2_data % rpt2_err(1:n)
x(4*n+1:5*n) = pt2_data % variance(1:n)
x(5*n+1:6*n) = pt2_data % variance_err(1:n)
x(6*n+1:7*n) = pt2_data % norm2(1:n)
x(7*n+1:8*n) = pt2_data % norm2_err(1:n)
k=8*n
x(k+1:k+n2) = reshape(pt2_data % overlap(1:n,1:n), (/ n2 /))
k=8*n+n2
x(k+1:k+n2) = reshape(pt2_data % overlap_err(1:n,1:n), (/ n2 /))
end
subroutine pt2_deserialize(pt2_data, n, x)
implicit none
use selection_types
type(pt2_type), intent(inout) :: pt2_data
integer, intent(in) :: n
double precision, intent(in) :: x(*)
integer :: i,k,n2
n2 = n*n
pt2_data % pt2(1:n) = x(1:n)
pt2_data % pt2_err(1:n) = x(n+1:2*n)
pt2_data % rpt2(1:n) = x(2*n+1:3*n)
pt2_data % rpt2_err(1:n) = x(3*n+1:4*n)
pt2_data % variance(1:n) = x(4*n+1:5*n)
pt2_data % variance_err(1:n) = x(5*n+1:6*n)
pt2_data % norm2(1:n) = x(6*n+1:7*n)
pt2_data % norm2_err(1:n) = x(7*n+1:8*n)
k=8*n
pt2_data % overlap(1:n,1:n) = reshape(x(k+1:k+n2), (/ n, n /))
k=8*n+n2
pt2_data % overlap_err(1:n,1:n) = reshape(x(k+1:k+n2), (/ n, n /))
end

View File

@ -141,6 +141,7 @@ subroutine run_pt2_slave_small(thread,iproc,energy)
! ! Try to adjust n_tasks around nproc/2 seconds per job ! ! Try to adjust n_tasks around nproc/2 seconds per job
n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/2) / (time1 - time0 + 1.d0))) n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/2) / (time1 - time0 + 1.d0)))
n_tasks = min(n_tasks, pt2_n_tasks_max)
! n_tasks = 1 ! n_tasks = 1
end do end do
@ -169,8 +170,8 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
integer :: rc, i integer :: rc, i
integer :: worker_id, ctask, ltask integer :: worker_id, ctask, ltask
character*(512), allocatable :: task(:) character*(512) :: task
integer, allocatable :: task_id(:) integer :: task_id
integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
integer(ZMQ_PTR) :: zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket
@ -181,18 +182,15 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
type(selection_buffer) :: b type(selection_buffer) :: b
logical :: done, buffer_ready logical :: done, buffer_ready
type(pt2_type), allocatable :: pt2_data(:) type(pt2_type) :: pt2_data
integer :: n_tasks, k, N integer :: n_tasks, k, N
integer, allocatable :: i_generator(:), subset(:) integer :: i_generator, subset
integer :: bsize ! Size of selection buffers integer :: bsize ! Size of selection buffers
logical :: sending logical :: sending
PROVIDE global_selection_buffer global_selection_buffer_lock PROVIDE global_selection_buffer global_selection_buffer_lock
allocate(task_id(pt2_n_tasks_max), task(pt2_n_tasks_max))
allocate(pt2_data(pt2_n_tasks_max), i_generator(pt2_n_tasks_max), subset(pt2_n_tasks_max))
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
integer, external :: connect_to_taskserver integer, external :: connect_to_taskserver
@ -211,22 +209,17 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
done = .False. done = .False.
do while (.not.done) do while (.not.done)
n_tasks = max(1,n_tasks)
n_tasks = min(pt2_n_tasks_max,n_tasks)
integer, external :: get_tasks_from_taskserver integer, external :: get_tasks_from_taskserver
if (get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task, n_tasks) == -1) then if (get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task, n_tasks) == -1) then
exit exit
endif endif
done = task_id(n_tasks) == 0 done = task_id == 0
if (done) then if (done) then
n_tasks = n_tasks-1 n_tasks = n_tasks-1
endif endif
if (n_tasks == 0) exit if (n_tasks == 0) exit
do k=1,n_tasks read (task,*) subset, i_generator, N
read (task(k),*) subset(k), i_generator(k), N
enddo
if (b%N == 0) then if (b%N == 0) then
! Only first time ! Only first time
bsize = min(N, (elec_alpha_num * (mo_num-elec_alpha_num))**2) bsize = min(N, (elec_alpha_num * (mo_num-elec_alpha_num))**2)
@ -238,15 +231,13 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
double precision :: time0, time1 double precision :: time0, time1
call wall_time(time0) call wall_time(time0)
do k=1,n_tasks call pt2_alloc(pt2_data,N_states)
call pt2_alloc(pt2_data(k),N_states) b%cur = 0
b%cur = 0
!double precision :: time2 !double precision :: time2
!call wall_time(time2) !call wall_time(time2)
call select_connected(i_generator(k),energy,pt2_data(k),b,subset(k),pt2_F(i_generator(k))) call select_connected(i_generator,energy,pt2_data,b,subset,pt2_F(i_generator))
!call wall_time(time1) !call wall_time(time1)
!print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1)) !print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1))
enddo
call wall_time(time1) call wall_time(time1)
!print *, '-->', i_generator(1), time1-time0, n_tasks !print *, '-->', i_generator(1), time1-time0, n_tasks
@ -270,12 +261,7 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
call push_pt2_results_async_send(zmq_socket_push, i_generator, pt2_data, b, task_id, n_tasks,sending) call push_pt2_results_async_send(zmq_socket_push, i_generator, pt2_data, b, task_id, n_tasks,sending)
endif endif
do k=1,n_tasks call pt2_dealloc(pt2_data)
call pt2_dealloc(pt2_data(k))
enddo
! ! Try to adjust n_tasks around nproc/2 seconds per job
! n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/2) / (time1 - time0 + 1.d0)))
n_tasks = 1
end do end do
call push_pt2_results_async_recv(zmq_socket_push,b%mini,sending) call push_pt2_results_async_recv(zmq_socket_push,b%mini,sending)
@ -322,8 +308,9 @@ subroutine push_pt2_results_async_send(zmq_socket_push, index, pt2_data, b, task
integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks) integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks)
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
logical, intent(inout) :: sending logical, intent(inout) :: sending
integer :: rc integer :: rc, i
integer*8 :: rc8 integer*8 :: rc8
double precision, allocatable :: pt2_serialized(:,:)
if (sending) then if (sending) then
print *, irp_here, ': sending is true' print *, irp_here, ': sending is true'
@ -351,12 +338,18 @@ subroutine push_pt2_results_async_send(zmq_socket_push, index, pt2_data, b, task
endif endif
rc = f77_zmq_send( zmq_socket_push, pt2_data, pt2_type_size(N_states)*n_tasks, ZMQ_SNDMORE) allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
do i=1,n_tasks
call pt2_serialize(pt2_data(i),N_states,pt2_serialized(1,i))
enddo
rc = f77_zmq_send( zmq_socket_push, pt2_serialized, size(pt2_serialized)*8, ZMQ_SNDMORE)
deallocate(pt2_serialized)
if (rc == -1) then if (rc == -1) then
print *, irp_here, ': error sending result' print *, irp_here, ': error sending result'
stop 3 stop 3
return return
else if(rc /= pt2_type_size(N_states)*n_tasks) then else if(rc /= size(pt2_serialized)*8) then
stop 'push' stop 'push'
endif endif
@ -468,6 +461,7 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2_data, task_id, n_tasks,
integer, intent(out) :: n_tasks, task_id(*) integer, intent(out) :: n_tasks, task_id(*)
integer :: rc, rn, i integer :: rc, rn, i
integer*8 :: rc8 integer*8 :: rc8
double precision, allocatable :: pt2_serialized(:,:)
rc = f77_zmq_recv( zmq_socket_pull, n_tasks, 4, 0) rc = f77_zmq_recv( zmq_socket_pull, n_tasks, 4, 0)
if (rc == -1) then if (rc == -1) then
@ -485,14 +479,20 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2_data, task_id, n_tasks,
stop 'pull' stop 'pull'
endif endif
rc = f77_zmq_recv( zmq_socket_pull, pt2_data, pt2_type_size(N_states)*n_tasks, 0) allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
rc = f77_zmq_recv( zmq_socket_pull, pt2_serialized, 8*size(pt2_serialized)*n_tasks, 0)
if (rc == -1) then if (rc == -1) then
n_tasks = 1 n_tasks = 1
task_id(1) = 0 task_id(1) = 0
else if(rc /= pt2_type_size(N_states)*n_tasks) then else if(rc /= 8*size(pt2_serialized)) then
stop 'pull' stop 'pull'
endif endif
do i=1,n_tasks
call pt2_deserialize(pt2_data(i),N_states,pt2_serialized(1,i))
enddo
deallocate(pt2_serialized)
rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0) rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0)
if (rc == -1) then if (rc == -1) then
n_tasks = 1 n_tasks = 1

View File

@ -18,9 +18,7 @@ subroutine run_selection_slave(thread,iproc,energy)
type(selection_buffer) :: buf, buf2 type(selection_buffer) :: buf, buf2
logical :: done, buffer_ready logical :: done, buffer_ready
double precision :: pt2(N_states) type(pt2_type) :: pt2_data
double precision :: variance(N_states)
double precision :: norm(N_states)
PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique
PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order
@ -28,6 +26,7 @@ subroutine run_selection_slave(thread,iproc,energy)
PROVIDE psi_bilinear_matrix_transp_order N_int pt2_F pseudo_sym PROVIDE psi_bilinear_matrix_transp_order N_int pt2_F pseudo_sym
PROVIDE psi_selectors_coef_transp psi_det_sorted weight_selection PROVIDE psi_selectors_coef_transp psi_det_sorted weight_selection
call pt2_alloc(pt2_data,N_states)
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
@ -42,9 +41,6 @@ subroutine run_selection_slave(thread,iproc,energy)
buf%N = 0 buf%N = 0
buffer_ready = .False. buffer_ready = .False.
ctask = 1 ctask = 1
pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0.d0
do do
integer, external :: get_task_from_taskserver integer, external :: get_task_from_taskserver
@ -69,7 +65,7 @@ subroutine run_selection_slave(thread,iproc,energy)
stop '-1' stop '-1'
end if end if
end if end if
call select_connected(i_generator,energy,pt2,variance,norm,buf,subset,pt2_F(i_generator)) call select_connected(i_generator,energy,pt2_data,buf,subset,pt2_F(i_generator))
endif endif
integer, external :: task_done_to_taskserver integer, external :: task_done_to_taskserver
@ -88,12 +84,10 @@ subroutine run_selection_slave(thread,iproc,energy)
if(ctask > 0) then if(ctask > 0) then
call sort_selection_buffer(buf) call sort_selection_buffer(buf)
! call merge_selection_buffers(buf,buf2) ! call merge_selection_buffers(buf,buf2)
!print *, task_id(1), pt2(1), buf%cur, ctask call push_selection_results(zmq_socket_push, pt2_data, buf, task_id(1), ctask)
call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask) call pt2_dealloc(pt2_data)
call pt2_alloc(pt2_data,N_states)
! buf%mini = buf2%mini ! buf%mini = buf2%mini
pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0d0
buf%cur = 0 buf%cur = 0
end if end if
ctask = 0 ctask = 0
@ -106,11 +100,9 @@ subroutine run_selection_slave(thread,iproc,energy)
if(ctask > 0) then if(ctask > 0) then
call sort_selection_buffer(buf) call sort_selection_buffer(buf)
! call merge_selection_buffers(buf,buf2) ! call merge_selection_buffers(buf,buf2)
call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask) call push_selection_results(zmq_socket_push, pt2_data, buf, task_id(1), ctask)
call pt2_dealloc(pt2_data)
! buf%mini = buf2%mini ! buf%mini = buf2%mini
pt2(:) = 0d0
variance(:) = 0d0
norm(:) = 0d0
buf%cur = 0 buf%cur = 0
end if end if
ctask = 0 ctask = 0
@ -129,18 +121,17 @@ subroutine run_selection_slave(thread,iproc,energy)
end subroutine end subroutine
subroutine push_selection_results(zmq_socket_push, pt2, variance, norm, b, task_id, ntask) subroutine push_selection_results(zmq_socket_push, pt2_data, b, task_id, ntask)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_push integer(ZMQ_PTR), intent(in) :: zmq_socket_push
double precision, intent(in) :: pt2(N_states) type(pt2_type), intent(in) :: pt2_data
double precision, intent(in) :: variance(N_states)
double precision, intent(in) :: norm(N_states)
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
integer, intent(in) :: ntask, task_id(*) integer, intent(in) :: ntask, task_id(*)
integer :: rc integer :: rc
double precision, allocatable :: pt2_serialized(:,:)
rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE) rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE)
if(rc /= 4) then if(rc /= 4) then
@ -148,19 +139,19 @@ subroutine push_selection_results(zmq_socket_push, pt2, variance, norm, b, task_
endif endif
rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE) allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
if(rc /= 8*N_states) then do i=1,n_tasks
print *, 'f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)' call pt2_serialize(pt2_data(i),N_states,pt2_serialized(1,i))
endif enddo
rc = f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE) rc = f77_zmq_send( zmq_socket_push, pt2_serialized, size(pt2_serialized)*8, ZMQ_SNDMORE)
if(rc /= 8*N_states) then deallocate(pt2_serialized)
print *, 'f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE)' if (rc == -1) then
endif print *, irp_here, ': error sending result'
stop 3
rc = f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE) return
if(rc /= 8*N_states) then else if(rc /= size(pt2_serialized)*8) then
print *, 'f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE)' stop 'push'
endif endif
if (b%cur > 0) then if (b%cur > 0) then
@ -201,42 +192,36 @@ IRP_ENDIF
end subroutine end subroutine
subroutine pull_selection_results(zmq_socket_pull, pt2, variance, norm, val, det, N, task_id, ntask) subroutine pull_selection_results(zmq_socket_pull, pt2_data, val, det, N, task_id, ntask)
use f77_zmq use f77_zmq
use selection_types use selection_types
implicit none implicit none
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision, intent(inout) :: pt2(N_states) type(pt2_type), intent(inout) :: pt2_data
double precision, intent(inout) :: variance(N_states)
double precision, intent(inout) :: norm(N_states)
double precision, intent(out) :: val(*) double precision, intent(out) :: val(*)
integer(bit_kind), intent(out) :: det(N_int, 2, *) integer(bit_kind), intent(out) :: det(N_int, 2, *)
integer, intent(out) :: N, ntask, task_id(*) integer, intent(out) :: N, ntask, task_id(*)
integer :: rc, rn, i integer :: rc, rn, i
double precision, allocatable :: pt2_serialized(:,:)
rc = f77_zmq_recv( zmq_socket_pull, N, 4, 0) rc = f77_zmq_recv( zmq_socket_pull, N, 4, 0)
if(rc /= 4) then if(rc /= 4) then
print *, 'f77_zmq_recv( zmq_socket_pull, N, 4, 0)' print *, 'f77_zmq_recv( zmq_socket_pull, N, 4, 0)'
endif endif
pt2(:) = 0.d0 allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
variance(:) = 0.d0 rc = f77_zmq_recv( zmq_socket_pull, pt2_serialized, 8*size(pt2_serialized)*n_tasks, 0)
norm(:) = 0.d0 if (rc == -1) then
n_tasks = 1
rc = f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0) task_id(1) = 0
if(rc /= 8*N_states) then else if(rc /= 8*size(pt2_serialized)) then
print *, 'f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0)' stop 'pull'
endif endif
rc = f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0) do i=1,n_tasks
if(rc /= 8*N_states) then call pt2_deserialize(pt2_data(i),N_states,pt2_serialized(1,i))
print *, 'f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0)' enddo
endif deallocate(pt2_serialized)
rc = f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)
if(rc /= 8*N_states) then
print *, 'f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)'
endif
if (N>0) then if (N>0) then
rc = f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0) rc = f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0)

View File

@ -24,7 +24,7 @@ module selection_types
integer function pt2_type_size(N) integer function pt2_type_size(N)
implicit none implicit none
integer, intent(in) :: N integer, intent(in) :: N
pt2_type_size = 8*(8*n + 2*n*n) pt2_type_size = (8*n + 2*n*n)
end function end function
end module end module

View File

@ -129,12 +129,12 @@ subroutine ZMQ_selection(N_in, pt2_data)
call delete_selection_buffer(b) call delete_selection_buffer(b)
do k=1,N_states do k=1,N_states
pt2_data % pt2(k) = pt2_data % pt2(k) * f(k) pt2_data % pt2(k) = pt2_data % pt2(k) * f(k)
pt2_data % variance(k) = pt2_data % variance(k) * f(k) pt2_data % variance(k) = pt2_data % variance(k) * f(k)
pt2_data % norm2(k) = pt2_data % norm2(k) * f(k) pt2_data % norm2(k) = pt2_data % norm2(k) * f(k)
pt2_data % rpt2(k) = & pt2_data % rpt2(k) = &
pt2_data % pt2(k)/(1.d0 + pt2_data % norm2(k)) pt2_data % pt2(k)/(1.d0 + pt2_data % norm2(k))
enddo enddo
call update_pt2_and_variance_weights(pt2_data, N_states) call update_pt2_and_variance_weights(pt2_data, N_states)
@ -160,6 +160,7 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
integer, intent(in) :: N integer, intent(in) :: N
type(pt2_type), intent(inout) :: pt2_data type(pt2_type), intent(inout) :: pt2_data
type(pt2_type) :: pt2_data_tmp
double precision :: pt2_mwen(N_states) double precision :: pt2_mwen(N_states)
double precision :: variance_mwen(N_states) double precision :: variance_mwen(N_states)
@ -190,15 +191,11 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
pt2_data % pt2(:) = 0d0 pt2_data % pt2(:) = 0d0
pt2_data % variance(:) = 0.d0 pt2_data % variance(:) = 0.d0
pt2_data % norm2(:) = 0.d0 pt2_data % norm2(:) = 0.d0
pt2_mwen(:) = 0.d0 call pt2_alloc(pt2_data_tmp,N_states)
variance_mwen(:) = 0.d0
norm2_mwen(:) = 0.d0
do while (more == 1) do while (more == 1)
call pull_selection_results(zmq_socket_pull, pt2_mwen, variance_mwen, norm2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask) call pull_selection_results(zmq_socket_pull, pt2_data_tmp, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
pt2_data % pt2(:) += pt2_mwen(:) call pt2_add(pt2_data, 1.d0, pt2_data_tmp)
pt2_data % variance(:) += variance_mwen(:)
pt2_data % norm2(:) += norm2_mwen(:)
do i=1, b2%cur do i=1, b2%cur
call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i)) call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i))
if (b2%val(i) > b%mini) exit if (b2%val(i) > b%mini) exit
@ -214,6 +211,7 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
endif endif
end do end do
end do end do
call pt2_dealloc(pt2_data_tmp)
call delete_selection_buffer(b2) call delete_selection_buffer(b2)