10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-12-22 20:35:19 +01:00

Killed sort bottleneck in selection

This commit is contained in:
Anthony Scemama 2017-05-05 11:32:17 +02:00
parent 7f32fab829
commit 4a043229b7
4 changed files with 83 additions and 20 deletions

View File

@ -105,6 +105,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error)
call pt2_slave_inproc(i) call pt2_slave_inproc(i)
endif endif
!$OMP END PARALLEL !$OMP END PARALLEL
call delete_selection_buffer(b)
call end_parallel_job(zmq_to_qp_run_socket, 'pt2') call end_parallel_job(zmq_to_qp_run_socket, 'pt2')
else else

View File

@ -58,13 +58,9 @@ subroutine run_selection_slave(thread,iproc,energy)
call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(i)) call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(i))
end do end do
if(ctask > 0) then if(ctask > 0) then
call sort_selection_buffer(buf)
call push_selection_results(zmq_socket_push, pt2, buf, task_id(1), ctask) call push_selection_results(zmq_socket_push, pt2, buf, task_id(1), ctask)
do i=1,buf%cur call merge_selection_buffers(buf,buf2)
call add_to_selection_buffer(buf2, buf%det(1,1,i), buf%val(i))
if (buf2%cur == buf2%N) then
call sort_selection_buffer(buf2)
endif
enddo
buf%mini = buf2%mini buf%mini = buf2%mini
pt2 = 0d0 pt2 = 0d0
buf%cur = 0 buf%cur = 0
@ -92,8 +88,6 @@ subroutine push_selection_results(zmq_socket_push, pt2, b, task_id, ntask)
integer, intent(in) :: ntask, task_id(*) integer, intent(in) :: ntask, task_id(*)
integer :: rc integer :: rc
call sort_selection_buffer(b)
rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE) rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE)
if(rc /= 4) stop "push" if(rc /= 4) stop "push"
rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE) rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)

View File

@ -15,6 +15,18 @@ subroutine create_selection_buffer(N, siz, res)
res%cur = 0 res%cur = 0
end subroutine end subroutine
subroutine delete_selection_buffer(b)
use selection_types
implicit none
type(selection_buffer), intent(inout) :: b
if (associated(b%det)) then
deallocate(b%det)
endif
if (associated(b%val)) then
deallocate(b%val)
endif
end
subroutine add_to_selection_buffer(b, det, val) subroutine add_to_selection_buffer(b, det, val)
use selection_types use selection_types
@ -35,6 +47,55 @@ subroutine add_to_selection_buffer(b, det, val)
end if end if
end subroutine end subroutine
subroutine merge_selection_buffers(b1, b2)
use selection_types
implicit none
BEGIN_DOC
! Merges the selection buffers b1 and b2 into b2
END_DOC
type(selection_buffer), intent(in) :: b1
type(selection_buffer), intent(inout) :: b2
integer(bit_kind), pointer :: detmp(:,:,:)
double precision, pointer :: val(:)
integer :: i, i1, i2, k, nmwen
nmwen = min(b1%N, b1%cur+b2%cur)
allocate( val(size(b1%val)), detmp(N_int, 2, size(b1%det,3)) )
i1=1
i2=1
do i=1,nmwen
if ( (i1 > b1%cur).and.(i2 > b2%cur) ) then
exit
else if (i1 > b1%cur) then
val(i) = b2%val(i2)
detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2)
detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2)
i2=i2+1
else if (i2 > b2%cur) then
val(i) = b1%val(i1)
detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1)
detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1)
i1=i1+1
else
if (b1%val(i1) <= b2%val(i2)) then
val(i) = b1%val(i1)
detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1)
detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1)
i1=i1+1
else
val(i) = b2%val(i2)
detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2)
detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2)
i2=i2+1
endif
endif
enddo
deallocate(b2%det, b2%val)
b2%det => detmp
b2%val => val
b2%mini = min(b2%mini,b2%val(b2%N))
b2%cur = nmwen
end
subroutine sort_selection_buffer(b) subroutine sort_selection_buffer(b)
use selection_types use selection_types
@ -56,10 +117,9 @@ subroutine sort_selection_buffer(b)
detmp(1:N_int,1,i) = b%det(1:N_int,1,iorder(i)) detmp(1:N_int,1,i) = b%det(1:N_int,1,iorder(i))
detmp(1:N_int,2,i) = b%det(1:N_int,2,iorder(i)) detmp(1:N_int,2,i) = b%det(1:N_int,2,iorder(i))
end do end do
deallocate(b%det) deallocate(b%det,iorder)
b%det => detmp b%det => detmp
b%mini = min(b%mini,b%val(b%N)) b%mini = min(b%mini,b%val(b%N))
b%cur = nmwen b%cur = nmwen
deallocate(iorder)
end subroutine end subroutine

View File

@ -45,7 +45,7 @@ subroutine ZMQ_selection(N_in, pt2)
!$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2) PRIVATE(i) NUM_THREADS(nproc+1) !$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2) PRIVATE(i) NUM_THREADS(nproc+1)
i = omp_get_thread_num() i = omp_get_thread_num()
if (i==0) then if (i==0) then
call selection_collector(b, pt2) call selection_collector(b, N, pt2)
else else
call selection_slave_inproc(i) call selection_slave_inproc(i)
endif endif
@ -59,6 +59,7 @@ subroutine ZMQ_selection(N_in, pt2)
endif endif
call save_wavefunction call save_wavefunction
endif endif
call delete_selection_buffer(b)
end subroutine end subroutine
@ -70,7 +71,7 @@ subroutine selection_slave_inproc(i)
call run_selection_slave(1,i,pt2_e0_denominator) call run_selection_slave(1,i,pt2_e0_denominator)
end end
subroutine selection_collector(b, pt2) subroutine selection_collector(b, N, pt2)
use f77_zmq use f77_zmq
use selection_types use selection_types
use bitmasks use bitmasks
@ -78,6 +79,7 @@ subroutine selection_collector(b, pt2)
type(selection_buffer), intent(inout) :: b type(selection_buffer), intent(inout) :: b
integer, intent(in) :: N
double precision, intent(out) :: pt2(N_states) double precision, intent(out) :: pt2(N_states)
double precision :: pt2_mwen(N_states) double precision :: pt2_mwen(N_states)
integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
@ -87,25 +89,30 @@ subroutine selection_collector(b, pt2)
integer(ZMQ_PTR) :: zmq_socket_pull integer(ZMQ_PTR) :: zmq_socket_pull
integer :: msg_size, rc, more integer :: msg_size, rc, more
integer :: acc, i, j, robin, N, ntask integer :: acc, i, j, robin, ntask
double precision, allocatable :: val(:) double precision, pointer :: val(:)
integer(bit_kind), allocatable :: det(:,:,:) integer(bit_kind), pointer :: det(:,:,:)
integer, allocatable :: task_id(:) integer, allocatable :: task_id(:)
integer :: done integer :: done
real :: time, time0 real :: time, time0
type(selection_buffer) :: b2
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
zmq_socket_pull = new_zmq_pull_socket() zmq_socket_pull = new_zmq_pull_socket()
allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators)) call create_selection_buffer(N, N*2, b2)
allocate(task_id(N_det_generators))
done = 0 done = 0
more = 1 more = 1
pt2(:) = 0d0 pt2(:) = 0d0
call CPU_TIME(time0) call CPU_TIME(time0)
do while (more == 1) do while (more == 1)
call pull_selection_results(zmq_socket_pull, pt2_mwen, val(1), det(1,1,1), N, task_id, ntask) call pull_selection_results(zmq_socket_pull, pt2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
pt2 += pt2_mwen pt2 += pt2_mwen
do i=1, N call merge_selection_buffers(b2,b)
call add_to_selection_buffer(b, det(1,1,i), val(i)) ! do i=1, N
end do ! call add_to_selection_buffer(b, det(1,1,i), val(i))
! end do
do i=1, ntask do i=1, ntask
if(task_id(i) == 0) then if(task_id(i) == 0) then
@ -119,6 +126,7 @@ subroutine selection_collector(b, pt2)
end do end do
call delete_selection_buffer(b2)
call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket)
call end_zmq_pull_socket(zmq_socket_pull) call end_zmq_pull_socket(zmq_socket_pull)
call sort_selection_buffer(b) call sort_selection_buffer(b)