From 4a043229b7aad9f48a07bdc8275d484b97a1fec3 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 5 May 2017 11:32:17 +0200 Subject: [PATCH] Killed sort bottleneck in selection --- plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f | 1 + plugins/Full_CI_ZMQ/run_selection_slave.irp.f | 10 +-- plugins/Full_CI_ZMQ/selection_buffer.irp.f | 64 ++++++++++++++++++- plugins/Full_CI_ZMQ/zmq_selection.irp.f | 28 +++++--- 4 files changed, 83 insertions(+), 20 deletions(-) diff --git a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f index eb64fc2f..d3791832 100644 --- a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f +++ b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f @@ -105,6 +105,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error) call pt2_slave_inproc(i) endif !$OMP END PARALLEL + call delete_selection_buffer(b) call end_parallel_job(zmq_to_qp_run_socket, 'pt2') else diff --git a/plugins/Full_CI_ZMQ/run_selection_slave.irp.f b/plugins/Full_CI_ZMQ/run_selection_slave.irp.f index bfc099e2..82c14cc6 100644 --- a/plugins/Full_CI_ZMQ/run_selection_slave.irp.f +++ b/plugins/Full_CI_ZMQ/run_selection_slave.irp.f @@ -58,13 +58,9 @@ subroutine run_selection_slave(thread,iproc,energy) call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(i)) end do if(ctask > 0) then + call sort_selection_buffer(buf) call push_selection_results(zmq_socket_push, pt2, buf, task_id(1), ctask) - do i=1,buf%cur - call add_to_selection_buffer(buf2, buf%det(1,1,i), buf%val(i)) - if (buf2%cur == buf2%N) then - call sort_selection_buffer(buf2) - endif - enddo + call merge_selection_buffers(buf,buf2) buf%mini = buf2%mini pt2 = 0d0 buf%cur = 0 @@ -92,8 +88,6 @@ subroutine push_selection_results(zmq_socket_push, pt2, b, task_id, ntask) integer, intent(in) :: ntask, task_id(*) integer :: rc - call sort_selection_buffer(b) - rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE) if(rc /= 4) stop "push" rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE) diff --git a/plugins/Full_CI_ZMQ/selection_buffer.irp.f b/plugins/Full_CI_ZMQ/selection_buffer.irp.f index d0bc05dd..6b354cd4 100644 --- a/plugins/Full_CI_ZMQ/selection_buffer.irp.f +++ b/plugins/Full_CI_ZMQ/selection_buffer.irp.f @@ -15,6 +15,18 @@ subroutine create_selection_buffer(N, siz, res) res%cur = 0 end subroutine +subroutine delete_selection_buffer(b) + use selection_types + implicit none + type(selection_buffer), intent(inout) :: b + if (associated(b%det)) then + deallocate(b%det) + endif + if (associated(b%val)) then + deallocate(b%val) + endif +end + subroutine add_to_selection_buffer(b, det, val) use selection_types @@ -35,6 +47,55 @@ subroutine add_to_selection_buffer(b, det, val) end if end subroutine +subroutine merge_selection_buffers(b1, b2) + use selection_types + implicit none + BEGIN_DOC +! Merges the selection buffers b1 and b2 into b2 + END_DOC + type(selection_buffer), intent(in) :: b1 + type(selection_buffer), intent(inout) :: b2 + integer(bit_kind), pointer :: detmp(:,:,:) + double precision, pointer :: val(:) + integer :: i, i1, i2, k, nmwen + nmwen = min(b1%N, b1%cur+b2%cur) + allocate( val(size(b1%val)), detmp(N_int, 2, size(b1%det,3)) ) + i1=1 + i2=1 + do i=1,nmwen + if ( (i1 > b1%cur).and.(i2 > b2%cur) ) then + exit + else if (i1 > b1%cur) then + val(i) = b2%val(i2) + detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2) + detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2) + i2=i2+1 + else if (i2 > b2%cur) then + val(i) = b1%val(i1) + detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1) + detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1) + i1=i1+1 + else + if (b1%val(i1) <= b2%val(i2)) then + val(i) = b1%val(i1) + detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1) + detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1) + i1=i1+1 + else + val(i) = b2%val(i2) + detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2) + detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2) + i2=i2+1 + endif + endif + enddo + deallocate(b2%det, b2%val) + b2%det => detmp + b2%val => val + b2%mini = min(b2%mini,b2%val(b2%N)) + b2%cur = nmwen +end + subroutine sort_selection_buffer(b) use selection_types @@ -56,10 +117,9 @@ subroutine sort_selection_buffer(b) detmp(1:N_int,1,i) = b%det(1:N_int,1,iorder(i)) detmp(1:N_int,2,i) = b%det(1:N_int,2,iorder(i)) end do - deallocate(b%det) + deallocate(b%det,iorder) b%det => detmp b%mini = min(b%mini,b%val(b%N)) b%cur = nmwen - deallocate(iorder) end subroutine diff --git a/plugins/Full_CI_ZMQ/zmq_selection.irp.f b/plugins/Full_CI_ZMQ/zmq_selection.irp.f index 7ffb4a44..6b325828 100644 --- a/plugins/Full_CI_ZMQ/zmq_selection.irp.f +++ b/plugins/Full_CI_ZMQ/zmq_selection.irp.f @@ -45,7 +45,7 @@ subroutine ZMQ_selection(N_in, pt2) !$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2) PRIVATE(i) NUM_THREADS(nproc+1) i = omp_get_thread_num() if (i==0) then - call selection_collector(b, pt2) + call selection_collector(b, N, pt2) else call selection_slave_inproc(i) endif @@ -59,6 +59,7 @@ subroutine ZMQ_selection(N_in, pt2) endif call save_wavefunction endif + call delete_selection_buffer(b) end subroutine @@ -70,7 +71,7 @@ subroutine selection_slave_inproc(i) call run_selection_slave(1,i,pt2_e0_denominator) end -subroutine selection_collector(b, pt2) +subroutine selection_collector(b, N, pt2) use f77_zmq use selection_types use bitmasks @@ -78,6 +79,7 @@ subroutine selection_collector(b, pt2) type(selection_buffer), intent(inout) :: b + integer, intent(in) :: N double precision, intent(out) :: pt2(N_states) double precision :: pt2_mwen(N_states) integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket @@ -87,25 +89,30 @@ subroutine selection_collector(b, pt2) integer(ZMQ_PTR) :: zmq_socket_pull integer :: msg_size, rc, more - integer :: acc, i, j, robin, N, ntask - double precision, allocatable :: val(:) - integer(bit_kind), allocatable :: det(:,:,:) + integer :: acc, i, j, robin, ntask + double precision, pointer :: val(:) + integer(bit_kind), pointer :: det(:,:,:) integer, allocatable :: task_id(:) integer :: done real :: time, time0 + type(selection_buffer) :: b2 + zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_socket_pull = new_zmq_pull_socket() - allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators)) + call create_selection_buffer(N, N*2, b2) + allocate(task_id(N_det_generators)) done = 0 more = 1 pt2(:) = 0d0 call CPU_TIME(time0) do while (more == 1) - call pull_selection_results(zmq_socket_pull, pt2_mwen, val(1), det(1,1,1), N, task_id, ntask) + call pull_selection_results(zmq_socket_pull, pt2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask) + pt2 += pt2_mwen - do i=1, N - call add_to_selection_buffer(b, det(1,1,i), val(i)) - end do + call merge_selection_buffers(b2,b) +! do i=1, N +! call add_to_selection_buffer(b, det(1,1,i), val(i)) +! end do do i=1, ntask if(task_id(i) == 0) then @@ -119,6 +126,7 @@ subroutine selection_collector(b, pt2) end do + call delete_selection_buffer(b2) call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_pull_socket(zmq_socket_pull) call sort_selection_buffer(b)