From a9aeda49584272c726b11a7ebc0a5cf24af483c0 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Mon, 27 Nov 2017 19:44:29 +0100 Subject: [PATCH] Packed tasks for PT2 stoch --- config/gfortran_debug.cfg | 2 +- ocaml/TaskServer.ml | 4 +- plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f | 33 ++--- plugins/Full_CI_ZMQ/run_pt2_slave.irp.f | 134 ++++++++----------- plugins/Full_CI_ZMQ/selection.irp.f | 1 + src/ZMQ/utils.irp.f | 85 ++++++------ 6 files changed, 123 insertions(+), 136 deletions(-) diff --git a/config/gfortran_debug.cfg b/config/gfortran_debug.cfg index 8a7ce9b8..dbd99539 100644 --- a/config/gfortran_debug.cfg +++ b/config/gfortran_debug.cfg @@ -13,7 +13,7 @@ FC : gfortran -g -ffree-line-length-none -I . LAPACK_LIB : -lblas -llapack IRPF90 : irpf90 -IRPF90_FLAGS : --ninja --align=32 --assert +IRPF90_FLAGS : --ninja --align=32 --assert # Global options ################ diff --git a/ocaml/TaskServer.ml b/ocaml/TaskServer.ml index 04c3c8a0..a98efd66 100644 --- a/ocaml/TaskServer.ml +++ b/ocaml/TaskServer.ml @@ -427,7 +427,7 @@ let get_tasks msg program_state rep_socket pair_socket = and success () = let rec build_list accu queue = function - | 0 -> queue, accu + | 0 -> queue, (List.rev accu) | n -> let new_queue, task_id, task = Queuing_system.pop_task ~client_id queue @@ -435,7 +435,7 @@ let get_tasks msg program_state rep_socket pair_socket = match (task_id, task) with | Some task_id, Some task -> build_list ( (Some task_id, task)::accu ) new_queue (n-1) - | _ -> queue, (None, "terminate")::accu + | _ -> build_list ( (None, "terminate")::accu ) queue 0 in let new_queue, result = diff --git a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f index 52dc024c..d08fb3f0 100644 --- a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f +++ b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f @@ -188,25 +188,26 @@ subroutine pt2_collector(E, b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, integer(ZMQ_PTR) :: zmq_socket_pull integer :: msg_size, rc, more - integer :: acc, i, j, robin, N, ntask + integer :: acc, i, j, robin, N, n_tasks double precision, allocatable :: val(:) integer(bit_kind), allocatable :: det(:,:,:) integer, allocatable :: task_id(:) - integer :: Nindex integer, allocatable :: index(:) double precision :: time0 double precision :: time, timeLast, Nabove_old double precision, external :: omp_get_wtime - integer :: tooth, firstTBDcomb, orgTBDcomb + integer :: tooth, firstTBDcomb, orgTBDcomb, n_tasks_max integer, allocatable :: parts_to_get(:) logical, allocatable :: actually_computed(:) double precision :: eqt character*(512) :: task Nabove_old = -1.d0 + n_tasks_max = N_det_generators/100+1 allocate(actually_computed(N_det_generators), parts_to_get(N_det_generators), & - pt2_mwen(N_states, N_det_generators) ) - pt2_mwen(1:N_states, 1:N_det_generators) =0.d0 + pt2_mwen(N_states, n_tasks_max) ) + + pt2_mwen(1:N_states, 1:n_tasks_max) = 0.d0 do i=1,N_det_generators actually_computed(i) = computed(i) enddo @@ -227,7 +228,7 @@ subroutine pt2_collector(E, b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_socket_pull = new_zmq_pull_socket() - allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators), index(1)) + allocate(val(b%N), det(N_int, 2, b%N), task_id(n_tasks_max), index(n_tasks_max)) more = 1 call wall_time(time0) timeLast = time0 @@ -235,26 +236,28 @@ subroutine pt2_collector(E, b, tbc, comb, Ncomb, computed, pt2_detail, sumabove, call get_first_tooth(actually_computed, tooth) Nabove_old = Nabove(tooth) - pullLoop : do while (more == 1) + logical :: loop + loop = .True. + pullLoop : do while (loop) - call pull_pt2_results(zmq_socket_pull, Nindex, index, pt2_mwen, task_id, ntask) - do i=1,Nindex + call pull_pt2_results(zmq_socket_pull, index, pt2_mwen, task_id, n_tasks) + do i=1,n_tasks pt2_detail(1:N_states, index(i)) += pt2_mwen(1:N_states,i) parts_to_get(index(i)) -= 1 if(parts_to_get(index(i)) < 0) then - print *, i, index(i), parts_to_get(index(i)), Nindex + print *, i, index(i), parts_to_get(index(i)) print *, "PARTS ??" print *, parts_to_get stop "PARTS ??" end if if(parts_to_get(index(i)) == 0) actually_computed(index(i)) = .true. - end do + enddo - do i=1, ntask - if(task_id(i) == 0) then - print *, "Error in collector" - endif + do i=1, n_tasks call zmq_delete_task(zmq_to_qp_run_socket,zmq_socket_pull,task_id(i),more) + if (more /= 1) then + loop = .False. + endif end do time = omp_get_wtime() diff --git a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f index 86ebcacf..2a743dad 100644 --- a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f +++ b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f @@ -8,8 +8,9 @@ subroutine run_pt2_slave(thread,iproc,energy) integer, intent(in) :: thread, iproc integer :: rc, i - integer :: worker_id, task_id(1), ctask, ltask - character*(512) :: task + integer :: worker_id, ctask, ltask + character*(512), allocatable :: task(:) + integer, allocatable :: task_id(:) integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket @@ -20,102 +21,78 @@ subroutine run_pt2_slave(thread,iproc,energy) type(selection_buffer) :: buf logical :: done - double precision :: pt2(N_states) - double precision,allocatable :: pt2_detail(:,:) - integer :: index - integer :: Nindex - logical :: buffer_ready + double precision,allocatable :: pt2(:,:) + integer :: n_tasks, k, n_tasks_max + integer, allocatable :: i_generator(:), subset(:) - allocate(pt2_detail(N_states, N_det_generators)) + n_tasks_max = N_det_generators/100+1 + allocate(task_id(n_tasks_max), task(n_tasks_max)) + allocate(pt2(N_states,n_tasks_max), i_generator(n_tasks_max), subset(n_tasks_max)) zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_socket_push = new_zmq_push_socket(thread) call connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) if(worker_id == -1) then - print *, 'WORKER -1' call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_push_socket(zmq_socket_push,thread) return end if buf%N = 0 - ctask = 1 - Nindex=1 - pt2 = 0d0 - pt2_detail = 0d0 - buffer_ready = .False. - do - call get_task_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id(ctask), task) + n_tasks = 1 + call create_selection_buffer(1, 2, buf) - done = task_id(ctask) == 0 - if (done) then - ctask = ctask - 1 - else - integer :: i_generator, i_i_generator, subset - read (task,*) subset, index - - if(buf%N == 0) then - ! Only first time - call create_selection_buffer(1, 2, buf) - buffer_ready = .True. - end if - do i_i_generator=1, Nindex - i_generator = index - call select_connected(i_generator,energy,pt2_detail(1, i_i_generator),buf,subset) - pt2(:) += pt2_detail(:, i_generator) - enddo - endif + done = .False. + do while (.not.done) - if(done .or. (ctask == size(task_id)) ) then - if(buf%N == 0 .and. ctask > 0) stop 'uninitialized selection_buffer' - do i=1, ctask - call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(i)) - end do - if(ctask > 0) then - call push_pt2_results(zmq_socket_push, Nindex, index, pt2_detail, task_id(1), ctask) - pt2 = 0d0 - pt2_detail(:,:Nindex) = 0d0 + n_tasks = min(n_tasks+1,n_tasks_max) + + call get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task, n_tasks) + done = task_id(n_tasks) == 0 + if (done) n_tasks = n_tasks-1 + + do k=1,n_tasks + read (task(k),*) subset(k), i_generator(k) + enddo + + do k=1,n_tasks + pt2(:,k) = 0.d0 buf%cur = 0 - end if - ctask = 0 - end if - - if(done) exit - ctask = ctask + 1 + call select_connected(i_generator(k),energy,pt2(1,k),buf,subset(k)) + enddo + do k=1,n_tasks + call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(k)) + enddo + call push_pt2_results(zmq_socket_push, i_generator, pt2, task_id, n_tasks) end do call disconnect_from_taskserver(zmq_to_qp_run_socket,zmq_socket_push,worker_id) call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_push_socket(zmq_socket_push,thread) - if (buffer_ready) then - call delete_selection_buffer(buf) - endif + call delete_selection_buffer(buf) end subroutine -subroutine push_pt2_results(zmq_socket_push, N, index, pt2_detail, task_id, ntask) +subroutine push_pt2_results(zmq_socket_push, index, pt2, task_id, n_tasks) use f77_zmq use selection_types implicit none integer(ZMQ_PTR), intent(in) :: zmq_socket_push - double precision, intent(in) :: pt2_detail(N_states, N_det_generators) - integer, intent(in) :: ntask, N, index, task_id(*) + double precision, intent(in) :: pt2(N_states,n_tasks) + integer, intent(in) :: n_tasks, index(n_tasks), task_id(n_tasks) integer :: rc - - rc = f77_zmq_send( zmq_socket_push, N, 4, ZMQ_SNDMORE) + rc = f77_zmq_send( zmq_socket_push, n_tasks, 4, ZMQ_SNDMORE) if(rc /= 4) stop 'push' - rc = f77_zmq_send( zmq_socket_push, index, 4, ZMQ_SNDMORE) - if(rc /= 4*N) stop 'push' + + rc = f77_zmq_send( zmq_socket_push, index, 4*n_tasks, ZMQ_SNDMORE) + if(rc /= 4*n_tasks) stop 'push' - rc = f77_zmq_send( zmq_socket_push, pt2_detail, 8*N_states*N, ZMQ_SNDMORE) - if(rc /= 8*N_states*N) stop 'push' + rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states*n_tasks, ZMQ_SNDMORE) + if(rc /= 8*N_states*n_tasks) stop 'push' - rc = f77_zmq_send( zmq_socket_push, ntask, 4, ZMQ_SNDMORE) - if(rc /= 4) stop 'push' - - rc = f77_zmq_send( zmq_socket_push, task_id, ntask*4, 0) - if(rc /= 4*ntask) stop 'push' + rc = f77_zmq_send( zmq_socket_push, task_id, n_tasks*4, 0) + if(rc /= 4*n_tasks) stop 'push' ! Activate is zmq_socket_push is a REQ IRP_IF ZMQ_PUSH @@ -131,30 +108,27 @@ IRP_ENDIF end subroutine -subroutine pull_pt2_results(zmq_socket_pull, N, index, pt2_detail, task_id, ntask) +subroutine pull_pt2_results(zmq_socket_pull, index, pt2, task_id, n_tasks) use f77_zmq use selection_types implicit none integer(ZMQ_PTR), intent(in) :: zmq_socket_pull - double precision, intent(inout) :: pt2_detail(N_states, N_det_generators) - integer, intent(out) :: index - integer, intent(out) :: N, ntask, task_id(*) + double precision, intent(inout) :: pt2(N_states,*) + integer, intent(out) :: index(*) + integer, intent(out) :: n_tasks, task_id(*) integer :: rc, rn, i - rc = f77_zmq_recv( zmq_socket_pull, N, 4, 0) + rc = f77_zmq_recv( zmq_socket_pull, n_tasks, 4, 0) if(rc /= 4) stop 'pull' - rc = f77_zmq_recv( zmq_socket_pull, index, 4, 0) - if(rc /= 4*N) stop 'pull' + rc = f77_zmq_recv( zmq_socket_pull, index, 4*n_tasks, 0) + if(rc /= 4*n_tasks) stop 'pull' - rc = f77_zmq_recv( zmq_socket_pull, pt2_detail, N_states*8*N, 0) - if(rc /= 8*N_states*N) stop 'pull' + rc = f77_zmq_recv( zmq_socket_pull, pt2, N_states*8*n_tasks, 0) + if(rc /= 8*N_states*n_tasks) stop 'pull' - rc = f77_zmq_recv( zmq_socket_pull, ntask, 4, 0) - if(rc /= 4) stop 'pull' - - rc = f77_zmq_recv( zmq_socket_pull, task_id, ntask*4, 0) - if(rc /= 4*ntask) stop 'pull' + rc = f77_zmq_recv( zmq_socket_pull, task_id, n_tasks*4, 0) + if(rc /= 4*n_tasks) stop 'pull' ! Activate is zmq_socket_pull is a REP IRP_IF ZMQ_PUSH diff --git a/plugins/Full_CI_ZMQ/selection.irp.f b/plugins/Full_CI_ZMQ/selection.irp.f index af898941..cf0cfe18 100644 --- a/plugins/Full_CI_ZMQ/selection.irp.f +++ b/plugins/Full_CI_ZMQ/selection.irp.f @@ -6,6 +6,7 @@ BEGIN_PROVIDER [ integer, fragment_count ] ! Number of fragments for the deterministic part END_DOC fragment_count = (elec_alpha_num-n_core_orb)**2 +! fragment_count = mo_tot_num*mo_tot_num END_PROVIDER diff --git a/src/ZMQ/utils.irp.f b/src/ZMQ/utils.irp.f index d9a619f3..06cff585 100644 --- a/src/ZMQ/utils.irp.f +++ b/src/ZMQ/utils.irp.f @@ -679,6 +679,10 @@ subroutine disconnect_from_taskserver(zmq_to_qp_run_socket, & sze = len(trim(message)) rc = f77_zmq_send(zmq_to_qp_run_socket, trim(message), sze, 0) + if (rc == -1) then + return + endif + if (rc /= sze) then print *, rc, sze print *, irp_here, 'f77_zmq_send(zmq_to_qp_run_socket, trim(message), sze, 0)' @@ -849,42 +853,41 @@ subroutine get_task_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task) character*(64) :: reply integer :: rc, sze - call get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task,1) -! write(message,*) 'get_task '//trim(zmq_state), worker_id -! -! sze = len(trim(message)) -! rc = f77_zmq_send(zmq_to_qp_run_socket, message, sze, 0) -! if (rc /= sze) then -! print *, irp_here, ':f77_zmq_send(zmq_to_qp_run_socket, trim(message), sze, 0)' -! stop 'error' -! endif -! -! message = repeat(' ',512) -! rc = f77_zmq_recv(zmq_to_qp_run_socket, message, 1024, 0) -! rc = min(1024,rc) -! read(message(1:rc),*) reply -! if (trim(reply) == 'get_task_reply') then -! read(message(1:rc),*) reply, task_id -! rc = 15 -! do while (message(rc:rc) == ' ') -! rc += 1 -! enddo -! do while (message(rc:rc) /= ' ') -! rc += 1 -! enddo -! rc += 1 -! task = message(rc:) -! else if (trim(reply) == 'terminate') then -! task_id = 0 -! task = 'terminate' -! else if (trim(message) == 'error No job is running') then -! task_id = 0 -! task = 'terminate' -! else -! print *, 'Unable to get the next task' -! print *, trim(message) -! stop -1 -! endif + write(message,*) 'get_task '//trim(zmq_state), worker_id + + sze = len(trim(message)) + rc = f77_zmq_send(zmq_to_qp_run_socket, message, sze, 0) + if (rc /= sze) then + print *, irp_here, ':f77_zmq_send(zmq_to_qp_run_socket, trim(message), sze, 0)' + stop 'error' + endif + + message = repeat(' ',512) + rc = f77_zmq_recv(zmq_to_qp_run_socket, message, 1024, 0) + rc = min(1024,rc) + read(message(1:rc),*) reply + if (trim(reply) == 'get_task_reply') then + read(message(1:rc),*) reply, task_id + rc = 15 + do while (message(rc:rc) == ' ') + rc += 1 + enddo + do while (message(rc:rc) /= ' ') + rc += 1 + enddo + rc += 1 + task = message(rc:) + else if (trim(reply) == 'terminate') then + task_id = 0 + task = 'terminate' + else if (trim(message) == 'error No job is running') then + task_id = 0 + task = 'terminate' + else + print *, 'Unable to get the next task' + print *, trim(message) + stop -1 + endif end @@ -897,7 +900,7 @@ subroutine get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task END_DOC integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket integer, intent(in) :: worker_id - integer, intent(in) :: n_tasks + integer, intent(inout) :: n_tasks integer, intent(out) :: task_id(n_tasks) character*(512), intent(out) :: task(n_tasks) @@ -931,12 +934,18 @@ subroutine get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task print *, ':'//trim(message)//':' stop -1 endif - + + task(:) = repeat(' ',512) do i=1,n_tasks message = repeat(' ',512) rc = f77_zmq_recv(zmq_to_qp_run_socket, message, 1024, 0) rc = min(1024,rc) read(message(1:rc),*) task_id(i) + if (task_id(i) == 0) then + task(i) = 'terminate' + n_tasks = i + exit + endif rc = 1 do while (message(rc:rc) == ' ') rc += 1