diff --git a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f index 70ce056f..41d62eca 100644 --- a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f +++ b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f @@ -63,40 +63,52 @@ subroutine ZMQ_pt2(pt2,relative_error) integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket integer :: ipos + logical :: tasks + tasks = .False. ipos=1 + do i=1,tbc(0) if(tbc(i) > fragment_first) then - write(task(ipos:ipos+20),'(I9,X,I9,''|'')') 0, i + write(task(ipos:ipos+20),'(I9,X,I9,''|'')') 0, tbc(i) ipos += 20 if (ipos > 64000) then call add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos-20))) ipos=1 + tasks = .True. endif else do j=1,fragment_count - write(task(ipos:ipos+20),'(I9,X,I9,''|'')') j, i + write(task(ipos:ipos+20),'(I9,X,I9,''|'')') j, tbc(i) ipos += 20 if (ipos > 64000) then call add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos-20))) ipos=1 + tasks = .True. endif end do end if end do if (ipos > 1) then call add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos-20))) + tasks = .True. endif - call zmq_set_running(zmq_to_qp_run_socket) - !$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc+1) & - !$OMP PRIVATE(i) - i = omp_get_thread_num() - if (i==0) then - call pt2_collector(b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, sum2above, Nabove, relative_error, pt2) - else - call pt2_slave_inproc(i) - endif - !$OMP END PARALLEL + if (tasks) then + call zmq_set_running(zmq_to_qp_run_socket) + + !$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc+1) & + !$OMP PRIVATE(i) + i = omp_get_thread_num() + if (i==0) then + call pt2_collector(b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, sum2above, Nabove, relative_error, pt2) + else + call pt2_slave_inproc(i) + endif + !$OMP END PARALLEL + + else + pt2(1) = sum(pt2_detail(1,:)) + endif call end_parallel_job(zmq_to_qp_run_socket, 'pt2') tbc(0) = 0 @@ -105,6 +117,7 @@ subroutine ZMQ_pt2(pt2,relative_error) endif end do + end subroutine @@ -160,7 +173,7 @@ subroutine pt2_collector(b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, su type(selection_buffer), intent(inout) :: b - double precision :: pt2_mwen(N_states, N_det_generators) + double precision, allocatable :: pt2_mwen(:,:) integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket @@ -181,7 +194,8 @@ subroutine pt2_collector(b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, su integer, allocatable :: parts_to_get(:) logical, allocatable :: actually_computed(:) - allocate(actually_computed(N_det_generators), parts_to_get(N_det_generators)) + allocate(actually_computed(N_det_generators), parts_to_get(N_det_generators), & + pt2_mwen(N_states, N_det_generators) ) actually_computed(:) = computed(:) parts_to_get(:) = 1 @@ -198,7 +212,7 @@ subroutine pt2_collector(b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, su zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_socket_pull = new_zmq_pull_socket() - allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators), index(N_det_generators)) + allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators), index(1)) more = 1 if (time0 < 0.d0) then time0 = omp_get_wtime() diff --git a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f index f6f41ab3..4dd4374c 100644 --- a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f +++ b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f @@ -22,10 +22,10 @@ subroutine run_pt2_slave(thread,iproc,energy) double precision :: pt2(N_states) double precision,allocatable :: pt2_detail(:,:) - integer,allocatable :: index(:) + integer :: index integer :: Nindex - allocate(pt2_detail(N_states, N_det), index(N_det)) + allocate(pt2_detail(N_states, N_det)) zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_socket_push = new_zmq_push_socket(thread) call connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) @@ -37,9 +37,9 @@ subroutine run_pt2_slave(thread,iproc,energy) end if buf%N = 0 ctask = 1 + Nindex=1 pt2 = 0d0 pt2_detail = 0d0 - Nindex=1 do call get_task_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id(ctask), task) @@ -48,8 +48,7 @@ subroutine run_pt2_slave(thread,iproc,energy) ctask = ctask - 1 else integer :: i_generator, i_i_generator, N, subset - read (task,*) Nindex - read (task,*) Nindex, subset, index(:Nindex) + read (task,*) subset, index !!!!! N=1 @@ -62,7 +61,7 @@ subroutine run_pt2_slave(thread,iproc,energy) if(N /= buf%N) stop "N changed... wtf man??" end if do i_i_generator=1, Nindex - i_generator = index(i_i_generator) + i_generator = index call select_connected(i_generator,energy,pt2_detail(1, i_i_generator),buf,subset) pt2(:) += pt2_detail(:, i_generator) enddo @@ -75,7 +74,6 @@ subroutine run_pt2_slave(thread,iproc,energy) end do if(ctask > 0) then call push_pt2_results(zmq_socket_push, Nindex, index, pt2_detail, task_id(1), ctask) - !print *, "pushed ", index(:Nindex) do i=1,buf%cur call add_to_selection_buffer(buf2, buf%det(1,1,i), buf%val(i)) enddo @@ -104,14 +102,14 @@ subroutine push_pt2_results(zmq_socket_push, N, index, pt2_detail, task_id, ntas integer(ZMQ_PTR), intent(in) :: zmq_socket_push double precision, intent(in) :: pt2_detail(N_states, N_det) - integer, intent(in) :: ntask, N, index(N), task_id(*) + integer, intent(in) :: ntask, N, index, task_id(*) integer :: rc rc = f77_zmq_send( zmq_socket_push, N, 4, ZMQ_SNDMORE) if(rc /= 4) stop "push" - rc = f77_zmq_send( zmq_socket_push, index, 4*N, ZMQ_SNDMORE) + rc = f77_zmq_send( zmq_socket_push, index, 4, ZMQ_SNDMORE) if(rc /= 4*N) stop "push" @@ -121,7 +119,7 @@ subroutine push_pt2_results(zmq_socket_push, N, index, pt2_detail, task_id, ntas rc = f77_zmq_send( zmq_socket_push, ntask, 4, ZMQ_SNDMORE) if(rc /= 4) stop "push" - rc = f77_zmq_send( zmq_socket_push, task_id(1), ntask*4, 0) + rc = f77_zmq_send( zmq_socket_push, task_id, ntask*4, 0) if(rc /= 4*ntask) stop "push" ! Activate is zmq_socket_push is a REQ @@ -136,14 +134,14 @@ subroutine pull_pt2_results(zmq_socket_pull, N, index, pt2_detail, task_id, ntas implicit none integer(ZMQ_PTR), intent(in) :: zmq_socket_pull double precision, intent(inout) :: pt2_detail(N_states, N_det) - integer, intent(out) :: index(N_det) + integer, intent(out) :: index integer, intent(out) :: N, ntask, task_id(*) integer :: rc, rn, i rc = f77_zmq_recv( zmq_socket_pull, N, 4, 0) if(rc /= 4) stop "pull" - rc = f77_zmq_recv( zmq_socket_pull, index, 4*N, 0) + rc = f77_zmq_recv( zmq_socket_pull, index, 4, 0) if(rc /= 4*N) stop "pull" rc = f77_zmq_recv( zmq_socket_pull, pt2_detail, N_states*8*N, 0)