Revised patch: serialize pt2_type through flat double-precision buffers.
(Any text placed before the first "diff --git" header is ignored by git-apply.)

The patch text had lost its line breaks and indentation. It is restored here
with one diff line per line. Leading whitespace and the position of blank
context lines were reconstructed from the hunk line counts, so apply it with
"git apply --ignore-whitespace" and check blank-line placement if a hunk is
rejected. The first hunk of zmq_selection.irp.f is a whitespace-only change
whose exact spacing could not be recovered. The "index" post-image hashes are
the ones of the original patch.

Changes made to the added lines while reviewing:

 1. run_pt2_slave.irp.f, push_pt2_results_async_send:
    pt2_serialized was deallocated before size(pt2_serialized) was used in
    the return-code check. The deallocate now follows the check.
 2. run_pt2_slave.irp.f, pull_pt2_results:
    the receive length was 8*size(pt2_serialized)*n_tasks although size()
    already includes the n_tasks dimension. It is now 8*size(pt2_serialized).
 3. run_selection_slave.irp.f, push_selection_results / pull_selection_results:
    the bodies used n_tasks (the argument is ntask), an undeclared i, and
    indexed the scalar pt2_data as pt2_data(i). They now use a 1-D buffer and
    a single pt2_serialize / pt2_deserialize call. The receive buffer is
    zeroed first so that a failed receive contributes zeros, as the code did
    before this patch.
 4. pt2_type.irp.f, pt2_serialize / pt2_deserialize:
    "use" now precedes "implicit none", the unused local i is gone, and the
    buffer layout is documented.

diff --git a/src/cipsi/pt2_stoch_routines.irp.f b/src/cipsi/pt2_stoch_routines.irp.f
index fa1438dc..4b781dd8 100644
--- a/src/cipsi/pt2_stoch_routines.irp.f
+++ b/src/cipsi/pt2_stoch_routines.irp.f
@@ -474,10 +474,12 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2_data, b, N_)
         do p=pt2_N_teeth, 1, -1
           v = pt2_u_0 + pt2_W_T * (pt2_u(c) + dble(p-1))
           i = pt2_find_sample_lr(v, pt2_cW,pt2_n_0(p),pt2_n_0(p+1))
-          call pt2_add ( pt2_data_teeth, pt2_W_T / pt2_w(i), pt2_data_I(i) )
+          v = pt2_W_T / pt2_w(i)
+          call pt2_add ( pt2_data_teeth, v, pt2_data_I(i) )
           call pt2_add ( pt2_data_S(p), 1.d0, pt2_data_teeth )
           call pt2_add2( pt2_data_S2(p), 1.d0, pt2_data_teeth )
-        end do
+        enddo
+        call pt2_dealloc(pt2_data_teeth)
 
         avg = E0 + pt2_data_S(t) % pt2(pt2_stoch_istate) / dble(c)
         avg2 = v0 + pt2_data_S(t) % variance(pt2_stoch_istate) / dble(c)
@@ -546,14 +548,6 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2_data, b, N_)
           print*,'i,index(i),size(pt2_data_I,1) = ',i,index(i),size(pt2_data_I,1)
           stop -1
         endif
-!        print *, pt2_data_I(index(i))%pt2
-!        print *, pt2_data_I(index(i))%variance
-!        print *, pt2_data_I(index(i))%norm2
-!        print *, ''
-!        print *, pt2_data_task(i)%pt2
-!        print *, pt2_data_task(i)%variance
-!        print *, pt2_data_task(i)%norm2
-!        print *, ''
         call pt2_add(pt2_data_I(index(i)),1.d0,pt2_data_task(i))
         f(index(i)) -= 1
       end do
diff --git a/src/cipsi/pt2_type.irp.f b/src/cipsi/pt2_type.irp.f
index af8cf6a7..e6f31799 100644
--- a/src/cipsi/pt2_type.irp.f
+++ b/src/cipsi/pt2_type.irp.f
@@ -60,16 +60,33 @@ subroutine pt2_add(p1, w, p2)
   double precision, intent(in) :: w
   type(pt2_type), intent(in) :: p2
 
-  p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:)
-  p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:)
-  p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:)
-  p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:)
-  p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:)
-  p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:)
-  p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:)
-  p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:)
-  p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:)
-  p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:)
+  if (w == 1.d0) then
+
+    p1 % pt2(:) = p1 % pt2(:) + p2 % pt2(:)
+    p1 % pt2_err(:) = p1 % pt2_err(:) + p2 % pt2_err(:)
+    p1 % rpt2(:) = p1 % rpt2(:) + p2 % rpt2(:)
+    p1 % rpt2_err(:) = p1 % rpt2_err(:) + p2 % rpt2_err(:)
+    p1 % variance(:) = p1 % variance(:) + p2 % variance(:)
+    p1 % variance_err(:) = p1 % variance_err(:) + p2 % variance_err(:)
+    p1 % norm2(:) = p1 % norm2(:) + p2 % norm2(:)
+    p1 % norm2_err(:) = p1 % norm2_err(:) + p2 % norm2_err(:)
+    p1 % overlap(:,:) = p1 % overlap(:,:) + p2 % overlap(:,:)
+    p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + p2 % overlap_err(:,:)
+
+  else
+
+    p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:)
+    p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:)
+    p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:)
+    p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:)
+    p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:)
+    p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:)
+    p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:)
+    p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:)
+    p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:)
+    p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:)
+
+  endif
 
 end subroutine
 
@@ -84,17 +101,89 @@ subroutine pt2_add2(p1, w, p2)
   double precision, intent(in) :: w
   type(pt2_type), intent(in) :: p2
 
-  p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:) * p2 % pt2(:)
-  p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:) * p2 % pt2_err(:)
-  p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:) * p2 % rpt2(:)
-  p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:) * p2 % rpt2_err(:)
-  p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:) * p2 % variance(:)
-  p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:) * p2 % variance_err(:)
-  p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:) * p2 % norm2(:)
-  p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:) * p2 % norm2_err(:)
-  p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:) * p2 % overlap(:,:)
-  p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:) * p2 % overlap_err(:,:)
+  if (w == 1.d0) then
+
+    p1 % pt2(:) = p1 % pt2(:) + p2 % pt2(:) * p2 % pt2(:)
+    p1 % pt2_err(:) = p1 % pt2_err(:) + p2 % pt2_err(:) * p2 % pt2_err(:)
+    p1 % rpt2(:) = p1 % rpt2(:) + p2 % rpt2(:) * p2 % rpt2(:)
+    p1 % rpt2_err(:) = p1 % rpt2_err(:) + p2 % rpt2_err(:) * p2 % rpt2_err(:)
+    p1 % variance(:) = p1 % variance(:) + p2 % variance(:) * p2 % variance(:)
+    p1 % variance_err(:) = p1 % variance_err(:) + p2 % variance_err(:) * p2 % variance_err(:)
+    p1 % norm2(:) = p1 % norm2(:) + p2 % norm2(:) * p2 % norm2(:)
+    p1 % norm2_err(:) = p1 % norm2_err(:) + p2 % norm2_err(:) * p2 % norm2_err(:)
+    p1 % overlap(:,:) = p1 % overlap(:,:) + p2 % overlap(:,:) * p2 % overlap(:,:)
+    p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + p2 % overlap_err(:,:) * p2 % overlap_err(:,:)
+
+  else
+
+    p1 % pt2(:) = p1 % pt2(:) + w * p2 % pt2(:) * p2 % pt2(:)
+    p1 % pt2_err(:) = p1 % pt2_err(:) + w * p2 % pt2_err(:) * p2 % pt2_err(:)
+    p1 % rpt2(:) = p1 % rpt2(:) + w * p2 % rpt2(:) * p2 % rpt2(:)
+    p1 % rpt2_err(:) = p1 % rpt2_err(:) + w * p2 % rpt2_err(:) * p2 % rpt2_err(:)
+    p1 % variance(:) = p1 % variance(:) + w * p2 % variance(:) * p2 % variance(:)
+    p1 % variance_err(:) = p1 % variance_err(:) + w * p2 % variance_err(:) * p2 % variance_err(:)
+    p1 % norm2(:) = p1 % norm2(:) + w * p2 % norm2(:) * p2 % norm2(:)
+    p1 % norm2_err(:) = p1 % norm2_err(:) + w * p2 % norm2_err(:) * p2 % norm2_err(:)
+    p1 % overlap(:,:) = p1 % overlap(:,:) + w * p2 % overlap(:,:) * p2 % overlap(:,:)
+    p1 % overlap_err(:,:) = p1 % overlap_err(:,:) + w * p2 % overlap_err(:,:) * p2 % overlap_err(:,:)
+
+  endif
 
 end subroutine
 
 
+subroutine pt2_serialize(pt2_data, n, x)
+  use selection_types
+  implicit none
+  ! Packs the first n entries of every field of pt2_data into x, in the order
+  ! pt2, pt2_err, rpt2, rpt2_err, variance, variance_err, norm2, norm2_err
+  ! (n values each), then overlap and overlap_err (n*n values each).
+  ! x must hold at least 8*n + 2*n*n doubles.
+  type(pt2_type), intent(in) :: pt2_data
+  integer, intent(in) :: n
+  double precision, intent(out) :: x(*)
+
+  integer :: k,n2
+
+  n2 = n*n
+  x(1:n) = pt2_data % pt2(1:n)
+  x(n+1:2*n) = pt2_data % pt2_err(1:n)
+  x(2*n+1:3*n) = pt2_data % rpt2(1:n)
+  x(3*n+1:4*n) = pt2_data % rpt2_err(1:n)
+  x(4*n+1:5*n) = pt2_data % variance(1:n)
+  x(5*n+1:6*n) = pt2_data % variance_err(1:n)
+  x(6*n+1:7*n) = pt2_data % norm2(1:n)
+  x(7*n+1:8*n) = pt2_data % norm2_err(1:n)
+  k=8*n
+  x(k+1:k+n2) = reshape(pt2_data % overlap(1:n,1:n), (/ n2 /))
+  k=8*n+n2
+  x(k+1:k+n2) = reshape(pt2_data % overlap_err(1:n,1:n), (/ n2 /))
+
+end
+
+subroutine pt2_deserialize(pt2_data, n, x)
+  use selection_types
+  implicit none
+  ! Inverse of pt2_serialize: fills the first n entries of every field of
+  ! pt2_data from x, which must use the layout written by pt2_serialize.
+  type(pt2_type), intent(inout) :: pt2_data
+  integer, intent(in) :: n
+  double precision, intent(in) :: x(*)
+
+  integer :: k,n2
+
+  n2 = n*n
+  pt2_data % pt2(1:n) = x(1:n)
+  pt2_data % pt2_err(1:n) = x(n+1:2*n)
+  pt2_data % rpt2(1:n) = x(2*n+1:3*n)
+  pt2_data % rpt2_err(1:n) = x(3*n+1:4*n)
+  pt2_data % variance(1:n) = x(4*n+1:5*n)
+  pt2_data % variance_err(1:n) = x(5*n+1:6*n)
+  pt2_data % norm2(1:n) = x(6*n+1:7*n)
+  pt2_data % norm2_err(1:n) = x(7*n+1:8*n)
+  k=8*n
+  pt2_data % overlap(1:n,1:n) = reshape(x(k+1:k+n2), (/ n, n /))
+  k=8*n+n2
+  pt2_data % overlap_err(1:n,1:n) = reshape(x(k+1:k+n2), (/ n, n /))
+
+end
diff --git a/src/cipsi/run_pt2_slave.irp.f b/src/cipsi/run_pt2_slave.irp.f
index 1859ca88..3c72dac0 100644
--- a/src/cipsi/run_pt2_slave.irp.f
+++ b/src/cipsi/run_pt2_slave.irp.f
@@ -141,6 +141,7 @@ subroutine run_pt2_slave_small(thread,iproc,energy)
 
 !    ! Try to adjust n_tasks around nproc/2 seconds per job
     n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/2) / (time1 - time0 + 1.d0)))
+    n_tasks = min(n_tasks, pt2_n_tasks_max)
 !    n_tasks = 1
   end do
 
@@ -169,8 +170,8 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
   integer :: rc, i
 
   integer :: worker_id, ctask, ltask
-  character*(512), allocatable :: task(:)
-  integer, allocatable :: task_id(:)
+  character*(512) :: task
+  integer :: task_id
 
   integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
   integer(ZMQ_PTR) :: zmq_to_qp_run_socket
@@ -181,18 +182,15 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
   type(selection_buffer) :: b
   logical :: done, buffer_ready
 
-  type(pt2_type), allocatable :: pt2_data(:)
+  type(pt2_type) :: pt2_data
   integer :: n_tasks, k, N
-  integer, allocatable :: i_generator(:), subset(:)
+  integer :: i_generator, subset
 
   integer :: bsize ! Size of selection buffers
   logical :: sending
 
   PROVIDE global_selection_buffer global_selection_buffer_lock
 
-  allocate(task_id(pt2_n_tasks_max), task(pt2_n_tasks_max))
-  allocate(pt2_data(pt2_n_tasks_max), i_generator(pt2_n_tasks_max), subset(pt2_n_tasks_max))
-
   zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
 
   integer, external :: connect_to_taskserver
@@ -211,22 +209,17 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
   done = .False.
   do while (.not.done)
 
-    n_tasks = max(1,n_tasks)
-    n_tasks = min(pt2_n_tasks_max,n_tasks)
-
     integer, external :: get_tasks_from_taskserver
     if (get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task, n_tasks) == -1) then
       exit
     endif
-    done = task_id(n_tasks) == 0
+    done = task_id == 0
     if (done) then
       n_tasks = n_tasks-1
     endif
     if (n_tasks == 0) exit
 
-    do k=1,n_tasks
-      read (task(k),*) subset(k), i_generator(k), N
-    enddo
+    read (task,*) subset, i_generator, N
     if (b%N == 0) then
       ! Only first time
       bsize = min(N, (elec_alpha_num * (mo_num-elec_alpha_num))**2)
@@ -238,15 +231,13 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
 
     double precision :: time0, time1
     call wall_time(time0)
-    do k=1,n_tasks
-      call pt2_alloc(pt2_data(k),N_states)
-      b%cur = 0
+    call pt2_alloc(pt2_data,N_states)
+    b%cur = 0
 !double precision :: time2
 !call wall_time(time2)
-      call select_connected(i_generator(k),energy,pt2_data(k),b,subset(k),pt2_F(i_generator(k)))
+    call select_connected(i_generator,energy,pt2_data,b,subset,pt2_F(i_generator))
 !call wall_time(time1)
 !print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1))
-    enddo
     call wall_time(time1)
 !print *, '-->', i_generator(1), time1-time0, n_tasks
 
@@ -270,12 +261,7 @@ subroutine run_pt2_slave_large(thread,iproc,energy)
       call push_pt2_results_async_send(zmq_socket_push, i_generator, pt2_data, b, task_id, n_tasks,sending)
     endif
 
-    do k=1,n_tasks
-      call pt2_dealloc(pt2_data(k))
-    enddo
-!    ! Try to adjust n_tasks around nproc/2 seconds per job
-!    n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc/2) / (time1 - time0 + 1.d0)))
-    n_tasks = 1
+    call pt2_dealloc(pt2_data)
   end do
 
   call push_pt2_results_async_recv(zmq_socket_push,b%mini,sending)
@@ -322,8 +308,9 @@ subroutine push_pt2_results_async_send(zmq_socket_push, index, pt2_data, b, task
   integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks)
   type(selection_buffer), intent(inout) :: b
   logical, intent(inout) :: sending
-  integer :: rc
+  integer :: rc, i
   integer*8 :: rc8
+  double precision, allocatable :: pt2_serialized(:,:)
 
   if (sending) then
     print *, irp_here, ': sending is true'
@@ -351,12 +338,18 @@ subroutine push_pt2_results_async_send(zmq_socket_push, index, pt2_data, b, task
   endif
 
 
-  rc = f77_zmq_send( zmq_socket_push, pt2_data, pt2_type_size(N_states)*n_tasks, ZMQ_SNDMORE)
+  allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
+  do i=1,n_tasks
+    call pt2_serialize(pt2_data(i),N_states,pt2_serialized(1,i))
+  enddo
+
+  rc = f77_zmq_send( zmq_socket_push, pt2_serialized, size(pt2_serialized)*8, ZMQ_SNDMORE)
   if (rc == -1) then
     print *, irp_here, ': error sending result'
     stop 3
     return
-  else if(rc /= pt2_type_size(N_states)*n_tasks) then
+  else if(rc /= size(pt2_serialized)*8) then
     stop 'push'
   endif
+  deallocate(pt2_serialized)  ! freed only after size() was used in the check above
 
@@ -468,6 +461,7 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2_data, task_id, n_tasks,
   integer, intent(out) :: n_tasks, task_id(*)
   integer :: rc, rn, i
   integer*8 :: rc8
+  double precision, allocatable :: pt2_serialized(:,:)
 
   rc = f77_zmq_recv( zmq_socket_pull, n_tasks, 4, 0)
   if (rc == -1) then
@@ -485,14 +479,20 @@ subroutine pull_pt2_results(zmq_socket_pull, index, pt2_data, task_id, n_tasks,
     stop 'pull'
   endif
 
-  rc = f77_zmq_recv( zmq_socket_pull, pt2_data, pt2_type_size(N_states)*n_tasks, 0)
+  allocate(pt2_serialized (pt2_type_size(N_states),n_tasks) )
+  rc = f77_zmq_recv( zmq_socket_pull, pt2_serialized, 8*size(pt2_serialized), 0)  ! size() already includes n_tasks
   if (rc == -1) then
     n_tasks = 1
     task_id(1) = 0
-  else if(rc /= pt2_type_size(N_states)*n_tasks) then
+  else if(rc /= 8*size(pt2_serialized)) then
     stop 'pull'
   endif
 
+  do i=1,n_tasks
+    call pt2_deserialize(pt2_data(i),N_states,pt2_serialized(1,i))
+  enddo
+  deallocate(pt2_serialized)
+
   rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0)
   if (rc == -1) then
     n_tasks = 1
diff --git a/src/cipsi/run_selection_slave.irp.f b/src/cipsi/run_selection_slave.irp.f
index fe712c45..69a8a4c3 100644
--- a/src/cipsi/run_selection_slave.irp.f
+++ b/src/cipsi/run_selection_slave.irp.f
@@ -18,9 +18,7 @@ subroutine run_selection_slave(thread,iproc,energy)
 
   type(selection_buffer) :: buf, buf2
   logical :: done, buffer_ready
-  double precision :: pt2(N_states)
-  double precision :: variance(N_states)
-  double precision :: norm(N_states)
+  type(pt2_type) :: pt2_data
 
   PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique
   PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order
@@ -28,6 +26,7 @@ subroutine run_selection_slave(thread,iproc,energy)
   PROVIDE psi_bilinear_matrix_transp_order N_int pt2_F pseudo_sym
   PROVIDE psi_selectors_coef_transp psi_det_sorted weight_selection
 
+  call pt2_alloc(pt2_data,N_states)
 
   zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
 
@@ -42,9 +41,6 @@ subroutine run_selection_slave(thread,iproc,energy)
   buf%N = 0
   buffer_ready = .False.
   ctask = 1
-  pt2(:) = 0d0
-  variance(:) = 0d0
-  norm(:) = 0.d0
 
   do
     integer, external :: get_task_from_taskserver
@@ -69,7 +65,7 @@ subroutine run_selection_slave(thread,iproc,energy)
           stop '-1'
         end if
       end if
-      call select_connected(i_generator,energy,pt2,variance,norm,buf,subset,pt2_F(i_generator))
+      call select_connected(i_generator,energy,pt2_data,buf,subset,pt2_F(i_generator))
     endif
 
     integer, external :: task_done_to_taskserver
@@ -88,12 +84,10 @@ subroutine run_selection_slave(thread,iproc,energy)
       if(ctask > 0) then
         call sort_selection_buffer(buf)
 !        call merge_selection_buffers(buf,buf2)
-!print *, task_id(1), pt2(1), buf%cur, ctask
-        call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask)
+        call push_selection_results(zmq_socket_push, pt2_data, buf, task_id(1), ctask)
+        call pt2_dealloc(pt2_data)
+        call pt2_alloc(pt2_data,N_states)
 !        buf%mini = buf2%mini
-        pt2(:) = 0d0
-        variance(:) = 0d0
-        norm(:) = 0d0
         buf%cur = 0
       end if
       ctask = 0
@@ -106,11 +100,9 @@ subroutine run_selection_slave(thread,iproc,energy)
   if(ctask > 0) then
     call sort_selection_buffer(buf)
 !    call merge_selection_buffers(buf,buf2)
-    call push_selection_results(zmq_socket_push, pt2, variance, norm, buf, task_id(1), ctask)
+    call push_selection_results(zmq_socket_push, pt2_data, buf, task_id(1), ctask)
+    call pt2_dealloc(pt2_data)
 !    buf%mini = buf2%mini
-    pt2(:) = 0d0
-    variance(:) = 0d0
-    norm(:) = 0d0
     buf%cur = 0
   end if
   ctask = 0
@@ -129,18 +121,17 @@ subroutine run_selection_slave(thread,iproc,energy)
 end subroutine
 
 
-subroutine push_selection_results(zmq_socket_push, pt2, variance, norm, b, task_id, ntask)
+subroutine push_selection_results(zmq_socket_push, pt2_data, b, task_id, ntask)
   use f77_zmq
   use selection_types
   implicit none
 
   integer(ZMQ_PTR), intent(in) :: zmq_socket_push
-  double precision, intent(in) :: pt2(N_states)
-  double precision, intent(in) :: variance(N_states)
-  double precision, intent(in) :: norm(N_states)
+  type(pt2_type), intent(in) :: pt2_data
   type(selection_buffer), intent(inout) :: b
   integer, intent(in) :: ntask, task_id(*)
   integer :: rc
+  double precision, allocatable :: pt2_serialized(:)  ! one pt2_type per message
 
   rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE)
   if(rc /= 4) then
@@ -148,19 +139,17 @@ subroutine push_selection_results(zmq_socket_push, pt2, variance, norm, b, task_
   endif
 
 
-  rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)'
-  endif
+  allocate(pt2_serialized (pt2_type_size(N_states)) )
+  call pt2_serialize(pt2_data,N_states,pt2_serialized)
 
-  rc = f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_send( zmq_socket_push, variance, 8*N_states, ZMQ_SNDMORE)'
-  endif
-
-  rc = f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_send( zmq_socket_push, norm, 8*N_states, ZMQ_SNDMORE)'
+  rc = f77_zmq_send( zmq_socket_push, pt2_serialized, size(pt2_serialized)*8, ZMQ_SNDMORE)
+  if (rc == -1) then
+    print *, irp_here, ': error sending result'
+    stop 3
+    return
+  else if(rc /= size(pt2_serialized)*8) then
+    stop 'push'
   endif
+  deallocate(pt2_serialized)  ! freed only after size() was used in the check above
 
   if (b%cur > 0) then
@@ -201,42 +190,35 @@ IRP_ENDIF
 end subroutine
 
 
-subroutine pull_selection_results(zmq_socket_pull, pt2, variance, norm, val, det, N, task_id, ntask)
+subroutine pull_selection_results(zmq_socket_pull, pt2_data, val, det, N, task_id, ntask)
   use f77_zmq
   use selection_types
   implicit none
   integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
-  double precision, intent(inout) :: pt2(N_states)
-  double precision, intent(inout) :: variance(N_states)
-  double precision, intent(inout) :: norm(N_states)
+  type(pt2_type), intent(inout) :: pt2_data
   double precision, intent(out) :: val(*)
   integer(bit_kind), intent(out) :: det(N_int, 2, *)
   integer, intent(out) :: N, ntask, task_id(*)
   integer :: rc, rn, i
+  double precision, allocatable :: pt2_serialized(:)  ! one pt2_type per message
 
   rc = f77_zmq_recv( zmq_socket_pull, N, 4, 0)
   if(rc /= 4) then
     print *, 'f77_zmq_recv( zmq_socket_pull, N, 4, 0)'
   endif
 
-  pt2(:) = 0.d0
-  variance(:) = 0.d0
-  norm(:) = 0.d0
-
-  rc = f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_recv( zmq_socket_pull, pt2, N_states*8, 0)'
+  allocate(pt2_serialized (pt2_type_size(N_states)) )
+  pt2_serialized(:) = 0.d0  ! a failed receive then deserializes to zeros
+  rc = f77_zmq_recv( zmq_socket_pull, pt2_serialized, 8*size(pt2_serialized), 0)
+  if (rc == -1) then
+    ntask = 1
+    task_id(1) = 0
+  else if(rc /= 8*size(pt2_serialized)) then
+    stop 'pull'
   endif
 
-  rc = f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_recv( zmq_socket_pull, variance, N_states*8, 0)'
-  endif
-
-  rc = f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)
-  if(rc /= 8*N_states) then
-    print *, 'f77_zmq_recv( zmq_socket_pull, norm, N_states*8, 0)'
-  endif
+  call pt2_deserialize(pt2_data,N_states,pt2_serialized)
+  deallocate(pt2_serialized)
 
   if (N>0) then
     rc = f77_zmq_recv( zmq_socket_pull, val(1), 8*N, 0)
diff --git a/src/cipsi/selection_types.f90 b/src/cipsi/selection_types.f90
index 52b84cf1..eef57aa5 100644
--- a/src/cipsi/selection_types.f90
+++ b/src/cipsi/selection_types.f90
@@ -24,7 +24,7 @@ module selection_types
   integer function pt2_type_size(N)
     implicit none
     integer, intent(in) :: N
-    pt2_type_size = 8*(8*n + 2*n*n)
+    pt2_type_size = (8*n + 2*n*n)
   end function
 
 end module
diff --git a/src/cipsi/zmq_selection.irp.f b/src/cipsi/zmq_selection.irp.f
index 006e6578..789c7a26 100644
--- a/src/cipsi/zmq_selection.irp.f
+++ b/src/cipsi/zmq_selection.irp.f
@@ -129,12 +129,12 @@ subroutine ZMQ_selection(N_in, pt2_data)
   call delete_selection_buffer(b)
 
   do k=1,N_states
-     pt2_data % pt2(k) = pt2_data % pt2(k) * f(k)
-     pt2_data % variance(k) = pt2_data % variance(k) * f(k)
-     pt2_data % norm2(k) = pt2_data % norm2(k) * f(k)
+    pt2_data % pt2(k) = pt2_data % pt2(k) * f(k)
+    pt2_data % variance(k) = pt2_data % variance(k) * f(k)
+    pt2_data % norm2(k) = pt2_data % norm2(k) * f(k)
 
-     pt2_data % rpt2(k) = &
-         pt2_data % pt2(k)/(1.d0 + pt2_data % norm2(k))
+    pt2_data % rpt2(k) = &
+        pt2_data % pt2(k)/(1.d0 + pt2_data % norm2(k))
   enddo
 
   call update_pt2_and_variance_weights(pt2_data, N_states)
@@ -160,6 +160,7 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
   type(selection_buffer), intent(inout) :: b
   integer, intent(in) :: N
   type(pt2_type), intent(inout) :: pt2_data
+  type(pt2_type) :: pt2_data_tmp
 
   double precision :: pt2_mwen(N_states)
   double precision :: variance_mwen(N_states)
@@ -190,15 +191,11 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
   pt2_data % pt2(:) = 0d0
   pt2_data % variance(:) = 0.d0
   pt2_data % norm2(:) = 0.d0
-  pt2_mwen(:) = 0.d0
-  variance_mwen(:) = 0.d0
-  norm2_mwen(:) = 0.d0
+  call pt2_alloc(pt2_data_tmp,N_states)
   do while (more == 1)
-    call pull_selection_results(zmq_socket_pull, pt2_mwen, variance_mwen, norm2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
+    call pull_selection_results(zmq_socket_pull, pt2_data_tmp, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
 
-    pt2_data % pt2(:) += pt2_mwen(:)
-    pt2_data % variance(:) += variance_mwen(:)
-    pt2_data % norm2(:) += norm2_mwen(:)
+    call pt2_add(pt2_data, 1.d0, pt2_data_tmp)
     do i=1, b2%cur
       call add_to_selection_buffer(b, b2%det(1,1,i), b2%val(i))
       if (b2%val(i) > b%mini) exit
@@ -214,6 +211,7 @@ subroutine selection_collector(zmq_socket_pull, b, N, pt2_data)
       endif
     end do
   end do
+  call pt2_dealloc(pt2_data_tmp)
 
 
   call delete_selection_buffer(b2)