From 727c9a84cd04a1aaa631406ff04945570843a1b4 Mon Sep 17 00:00:00 2001 From: Yann Garniron Date: Tue, 1 May 2018 17:43:46 +0200 Subject: [PATCH] improved synchronization --- plugins/dress_zmq/dress_stoch_routines.irp.f | 4 +- plugins/dress_zmq/run_dress_slave.irp.f | 94 ++++++++++++-------- 2 files changed, 60 insertions(+), 38 deletions(-) diff --git a/plugins/dress_zmq/dress_stoch_routines.irp.f b/plugins/dress_zmq/dress_stoch_routines.irp.f index d26a4d8c..25bec079 100644 --- a/plugins/dress_zmq/dress_stoch_routines.irp.f +++ b/plugins/dress_zmq/dress_stoch_routines.irp.f @@ -285,7 +285,7 @@ subroutine dress_collector(zmq_socket_pull, E, relative_error, delta, delta_s2, call wall_time(time) print '(2X, F16.7, 2X, G16.3, 2X, F16.4, A20)', avg+E(istate)+E0, eqt, time-time0, '' - if ((dabs(eqt) < relative_error .and. cps_N(cur_cp) >= 30) .or. cur_cp == N_cp) then + if ((dabs(eqt) < relative_error .and. cps_N(cur_cp) >= 30) .or. cur_cp == N_cp-4) then ! Termination print *, "TERMINATE" if (zmq_abort(zmq_to_qp_run_socket) == -1) then @@ -294,7 +294,7 @@ subroutine dress_collector(zmq_socket_pull, E, relative_error, delta, delta_s2, print *, irp_here, ': Error in sending abort signal (2)' endif endif - !exit pullLoop + exit pullLoop endif end if end do pullLoop diff --git a/plugins/dress_zmq/run_dress_slave.irp.f b/plugins/dress_zmq/run_dress_slave.irp.f index 84d9af6c..c38b2c90 100644 --- a/plugins/dress_zmq/run_dress_slave.irp.f +++ b/plugins/dress_zmq/run_dress_slave.irp.f @@ -11,6 +11,7 @@ END_PROVIDER subroutine run_dress_slave(thread,iproce,energy) use f77_zmq + use omp_lib implicit none double precision, intent(in) :: energy(N_states_diag) @@ -32,7 +33,6 @@ subroutine run_dress_slave(thread,iproce,energy) integer :: ind double precision,allocatable :: delta_ij_loc(:,:,:) - double precision :: div(N_states) integer :: h,p,n,i_state logical :: ok @@ -41,7 +41,7 @@ subroutine run_dress_slave(thread,iproce,energy) integer(bit_kind), allocatable :: det_buf(:,:,:) integer :: N_buf(3) logical :: last - integer, external :: omp_get_thread_num + !integer, external :: omp_get_thread_num double precision, allocatable :: delta_det(:,:,:,:), cp(:,:,:,:) integer :: toothMwen logical :: fracted @@ -60,24 +60,22 @@ subroutine run_dress_slave(thread,iproce,energy) task(:) = CHAR(0) - zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() - zmq_socket_push = new_zmq_push_socket(thread) - call connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) - if(worker_id == -1) then - print *, "WORKER -1" - call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) - call end_zmq_push_socket(zmq_socket_push,thread) - return - end if - do i=1,N_states - div(i) = psi_coef(dressed_column_idx(i), i) - end do - + integer :: iproc, cur_cp, done_for(0:N_cp) integer, allocatable :: tasks(:) integer :: lastCp(Nproc) integer :: lastSent, lastSendable logical :: send + integer(kind=OMP_LOCK_KIND) :: lck_det(0:comb_teeth+1) + integer(kind=OMP_LOCK_KIND) :: lck_sto(0:N_cp+1) + + do i=0,N_cp+1 + call omp_init_lock(lck_sto(i)) + end do + do i=0,comb_teeth+1 + call omp_init_lock(lck_det(i)) + end do + lastCp = 0 lastSent = 0 send = .false. @@ -85,17 +83,30 @@ subroutine run_dress_slave(thread,iproce,energy) !$OMP PARALLEL DEFAULT(SHARED) & !$OMP PRIVATE(int_buf, double_buf, det_buf, delta_ij_loc, task, task_id) & - !$OMP PRIVATE(toothMwen, fracted, fac) & - !$OMP PRIVATE(send, i_generator, subset, iproc, N_buf) + !$OMP PRIVATE(lastSendable, toothMwen, fracted, fac) & + !$OMP PRIVATE(i, cur_cp, send, i_generator, subset, iproc, N_buf) & + !$OMP PRIVATE(zmq_to_qp_run_socket, zmq_socket_push, worker_id) + + zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() + zmq_socket_push = new_zmq_push_socket(thread) + call connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) + if(worker_id == -1) then + print *, "WORKER -1" + call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) + call end_zmq_push_socket(zmq_socket_push,thread) + stop "WORKER -1" + end if + + iproc = omp_get_thread_num()+1 allocate(int_buf(N_dress_int_buffer)) allocate(double_buf(N_dress_double_buffer)) allocate(det_buf(N_int, 2, N_dress_det_buffer)) allocate(delta_ij_loc(N_states,N_det,2)) do - !$OMP CRITICAL (SENDAGE) + !!1$OMP CRITICAL (SENDAGE) call get_task_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task) - !$OMP END CRITICAL (SENDAGE) + !!1$OMP END CRITICAL (SENDAGE) task = task//" 0" if(task_id == 0) then print *, "DONEDONE" @@ -109,7 +120,7 @@ subroutine run_dress_slave(thread,iproce,energy) send = .false. lastSendable = N_cp*2 do i=1,Nproc - lastSendable = min(lastCp(iproc), lastSendable) + lastSendable = min(lastCp(i), lastSendable) end do lastSendable -= 1 if(lastSendable > lastSent) then @@ -119,7 +130,7 @@ subroutine run_dress_slave(thread,iproce,energy) !$OMP END CRITICAL if(send) then - !$OMP CRITICAL + !!1$OMP CRITICAL N_buf = (/0,1,0/) delta_ij_loc = 0d0 @@ -131,12 +142,12 @@ subroutine run_dress_slave(thread,iproce,energy) delta_ij_loc(:,:,:) = delta_ij_loc(:,:,:) / cps_N(cur_cp) do i=cp_first_tooth(cur_cp)-1,0,-1 - delta_ij_loc(:,:,:) = delta_ij_loc(:,:,:) +delta_det(:,:,i,:) + delta_ij_loc(:,:,:) = delta_ij_loc(:,:,:) +delta_det(:,:,i,:) end do - !$OMP END CRITICAL - !$OMP CRITICAL (SENDAGE) + !!1$OMP END CRITICAL + !!1$OMP CRITICAL (SENDAGE) call push_dress_results(zmq_socket_push, done_for(cur_cp), cur_cp, delta_ij_loc, int_buf, double_buf, det_buf, N_buf, -1) - !$OMP END CRITICAL (SENDAGE) + !!1$OMP END CRITICAL (SENDAGE) end if @@ -148,13 +159,14 @@ subroutine run_dress_slave(thread,iproce,energy) call alpha_callback(delta_ij_loc, i_generator, subset, iproc) call generator_done(i_generator, int_buf, double_buf, det_buf, N_buf, iproc) - !if(.false.) then - !$OMP CRITICAL + !!1$OMP CRITICAL do i=1,N_cp fac = cps(i_generator, i) * dress_weight_inv(i_generator) * comb_step if(fac == 0d0) cycle + call omp_set_lock(lck_sto(i)) cp(:,:,i,1) += (delta_ij_loc(:,:,1) * fac) cp(:,:,i,2) += (delta_ij_loc(:,:,2) * fac) + call omp_unset_lock(lck_sto(i)) end do @@ -162,31 +174,41 @@ subroutine run_dress_slave(thread,iproce,energy) fracted = (toothMwen /= 0) if(fracted) fracted = (i_generator == first_det_of_teeth(toothMwen)) if(fracted) then + call omp_set_lock(lck_det(toothMwen)) + call omp_set_lock(lck_det(toothMwen-1)) delta_det(:,:,toothMwen-1, 1) += delta_ij_loc(:,:,1) * (1d0-fractage(toothMwen)) delta_det(:,:,toothMwen-1, 2) += delta_ij_loc(:,:,2) * (1d0-fractage(toothMwen)) delta_det(:,:,toothMwen , 1) += delta_ij_loc(:,:,1) * (fractage(toothMwen)) delta_det(:,:,toothMwen , 2) += delta_ij_loc(:,:,2) * (fractage(toothMwen)) + call omp_unset_lock(lck_det(toothMwen)) + call omp_unset_lock(lck_det(toothMwen-1)) else + call omp_set_lock(lck_det(toothMwen)) delta_det(:,:,toothMwen , 1) += delta_ij_loc(:,:,1) delta_det(:,:,toothMwen , 2) += delta_ij_loc(:,:,2) - end if + call omp_unset_lock(lck_det(toothMwen)) + end if + !!!&$OMP END CRITICAL - - !$OMP END CRITICAL - !end if - - !$OMP CRITICAL (SENDAGE) + !!1$OMP CRITICAL (SENDAGE) call push_dress_results(zmq_socket_push, i_generator, -1, delta_ij_loc, int_buf, double_buf, det_buf, N_buf, task_id) call task_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id) - !$OMP END CRITICAL (SENDAGE) + !!1$OMP END CRITICAL (SENDAGE) lastCp(iproc) = done_cp_at_det(i_generator) end do - !$OMP END PARALLEL call sleep(10) call disconnect_from_taskserver(zmq_to_qp_run_socket,zmq_socket_push,worker_id) call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_push_socket(zmq_socket_push,thread) + !$OMP END PARALLEL + + do i=0,N_cp+1 + call omp_destroy_lock(lck_sto(i)) + end do + do i=0,comb_teeth+1 + call omp_destroy_lock(lck_det(i)) + end do end subroutine @@ -233,7 +255,7 @@ subroutine push_dress_results(zmq_socket_push, ind, cur_cp, delta_loc, int_buf, if(rc /= 8*N_states) stop "push" N_buf = N_bufi - N_buf = (/0,1,0/) + !N_buf = (/0,1,0/) rc = f77_zmq_send( zmq_socket_push, N_buf, 4*3, ZMQ_SNDMORE) if(rc /= 4*3) stop "push5"