10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-12-22 12:23:48 +01:00

Killed sort bottleneck in selection

This commit is contained in:
Anthony Scemama 2017-05-05 11:32:17 +02:00
parent 7f32fab829
commit 4a043229b7
4 changed files with 83 additions and 20 deletions

View File

@ -105,6 +105,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error)
call pt2_slave_inproc(i)
endif
!$OMP END PARALLEL
call delete_selection_buffer(b)
call end_parallel_job(zmq_to_qp_run_socket, 'pt2')
else

View File

@ -58,13 +58,9 @@ subroutine run_selection_slave(thread,iproc,energy)
call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id(i))
end do
if(ctask > 0) then
call sort_selection_buffer(buf)
call push_selection_results(zmq_socket_push, pt2, buf, task_id(1), ctask)
do i=1,buf%cur
call add_to_selection_buffer(buf2, buf%det(1,1,i), buf%val(i))
if (buf2%cur == buf2%N) then
call sort_selection_buffer(buf2)
endif
enddo
call merge_selection_buffers(buf,buf2)
buf%mini = buf2%mini
pt2 = 0d0
buf%cur = 0
@ -92,8 +88,6 @@ subroutine push_selection_results(zmq_socket_push, pt2, b, task_id, ntask)
integer, intent(in) :: ntask, task_id(*)
integer :: rc
call sort_selection_buffer(b)
rc = f77_zmq_send( zmq_socket_push, b%cur, 4, ZMQ_SNDMORE)
if(rc /= 4) stop "push"
rc = f77_zmq_send( zmq_socket_push, pt2, 8*N_states, ZMQ_SNDMORE)

View File

@ -15,6 +15,18 @@ subroutine create_selection_buffer(N, siz, res)
res%cur = 0
end subroutine
subroutine delete_selection_buffer(b)
use selection_types
implicit none
type(selection_buffer), intent(inout) :: b
if (associated(b%det)) then
deallocate(b%det)
endif
if (associated(b%val)) then
deallocate(b%val)
endif
end
subroutine add_to_selection_buffer(b, det, val)
use selection_types
@ -35,6 +47,55 @@ subroutine add_to_selection_buffer(b, det, val)
end if
end subroutine
subroutine merge_selection_buffers(b1, b2)
use selection_types
implicit none
BEGIN_DOC
! Merges the selection buffers b1 and b2 into b2
END_DOC
type(selection_buffer), intent(in) :: b1
type(selection_buffer), intent(inout) :: b2
integer(bit_kind), pointer :: detmp(:,:,:)
double precision, pointer :: val(:)
integer :: i, i1, i2, k, nmwen
nmwen = min(b1%N, b1%cur+b2%cur)
allocate( val(size(b1%val)), detmp(N_int, 2, size(b1%det,3)) )
i1=1
i2=1
do i=1,nmwen
if ( (i1 > b1%cur).and.(i2 > b2%cur) ) then
exit
else if (i1 > b1%cur) then
val(i) = b2%val(i2)
detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2)
detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2)
i2=i2+1
else if (i2 > b2%cur) then
val(i) = b1%val(i1)
detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1)
detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1)
i1=i1+1
else
if (b1%val(i1) <= b2%val(i2)) then
val(i) = b1%val(i1)
detmp(1:N_int,1,i) = b1%det(1:N_int,1,i1)
detmp(1:N_int,2,i) = b1%det(1:N_int,2,i1)
i1=i1+1
else
val(i) = b2%val(i2)
detmp(1:N_int,1,i) = b2%det(1:N_int,1,i2)
detmp(1:N_int,2,i) = b2%det(1:N_int,2,i2)
i2=i2+1
endif
endif
enddo
deallocate(b2%det, b2%val)
b2%det => detmp
b2%val => val
b2%mini = min(b2%mini,b2%val(b2%N))
b2%cur = nmwen
end
subroutine sort_selection_buffer(b)
use selection_types
@ -56,10 +117,9 @@ subroutine sort_selection_buffer(b)
detmp(1:N_int,1,i) = b%det(1:N_int,1,iorder(i))
detmp(1:N_int,2,i) = b%det(1:N_int,2,iorder(i))
end do
deallocate(b%det)
deallocate(b%det,iorder)
b%det => detmp
b%mini = min(b%mini,b%val(b%N))
b%cur = nmwen
deallocate(iorder)
end subroutine

View File

@ -45,7 +45,7 @@ subroutine ZMQ_selection(N_in, pt2)
!$OMP PARALLEL DEFAULT(shared) SHARED(b, pt2) PRIVATE(i) NUM_THREADS(nproc+1)
i = omp_get_thread_num()
if (i==0) then
call selection_collector(b, pt2)
call selection_collector(b, N, pt2)
else
call selection_slave_inproc(i)
endif
@ -59,6 +59,7 @@ subroutine ZMQ_selection(N_in, pt2)
endif
call save_wavefunction
endif
call delete_selection_buffer(b)
end subroutine
@ -70,7 +71,7 @@ subroutine selection_slave_inproc(i)
call run_selection_slave(1,i,pt2_e0_denominator)
end
subroutine selection_collector(b, pt2)
subroutine selection_collector(b, N, pt2)
use f77_zmq
use selection_types
use bitmasks
@ -78,6 +79,7 @@ subroutine selection_collector(b, pt2)
type(selection_buffer), intent(inout) :: b
integer, intent(in) :: N
double precision, intent(out) :: pt2(N_states)
double precision :: pt2_mwen(N_states)
integer(ZMQ_PTR),external :: new_zmq_to_qp_run_socket
@ -87,25 +89,30 @@ subroutine selection_collector(b, pt2)
integer(ZMQ_PTR) :: zmq_socket_pull
integer :: msg_size, rc, more
integer :: acc, i, j, robin, N, ntask
double precision, allocatable :: val(:)
integer(bit_kind), allocatable :: det(:,:,:)
integer :: acc, i, j, robin, ntask
double precision, pointer :: val(:)
integer(bit_kind), pointer :: det(:,:,:)
integer, allocatable :: task_id(:)
integer :: done
real :: time, time0
type(selection_buffer) :: b2
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
zmq_socket_pull = new_zmq_pull_socket()
allocate(val(b%N), det(N_int, 2, b%N), task_id(N_det_generators))
call create_selection_buffer(N, N*2, b2)
allocate(task_id(N_det_generators))
done = 0
more = 1
pt2(:) = 0d0
call CPU_TIME(time0)
do while (more == 1)
call pull_selection_results(zmq_socket_pull, pt2_mwen, val(1), det(1,1,1), N, task_id, ntask)
call pull_selection_results(zmq_socket_pull, pt2_mwen, b2%val(1), b2%det(1,1,1), b2%cur, task_id, ntask)
pt2 += pt2_mwen
do i=1, N
call add_to_selection_buffer(b, det(1,1,i), val(i))
end do
call merge_selection_buffers(b2,b)
! do i=1, N
! call add_to_selection_buffer(b, det(1,1,i), val(i))
! end do
do i=1, ntask
if(task_id(i) == 0) then
@ -119,6 +126,7 @@ subroutine selection_collector(b, pt2)
end do
call delete_selection_buffer(b2)
call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket)
call end_zmq_pull_socket(zmq_socket_pull)
call sort_selection_buffer(b)