From ad7398f9120e59270c939a02865d6311982f45b2 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Mon, 10 Sep 2018 15:15:19 +0200 Subject: [PATCH 1/2] Fixed Fragments --- plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f | 12 +++++++----- plugins/Full_CI_ZMQ/run_pt2_slave.irp.f | 18 +++++++++++------- plugins/Full_CI_ZMQ/selection.irp.f | 4 +++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f index 3c8e797b..4f9138bc 100644 --- a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f +++ b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f @@ -6,7 +6,6 @@ BEGIN_PROVIDER [ integer, pt2_stoch_istate ] pt2_stoch_istate = 1 END_PROVIDER - BEGIN_PROVIDER [ integer, pt2_N_teeth ] &BEGIN_PROVIDER [ integer, pt2_minDetInFirstTeeth ] &BEGIN_PROVIDER [ integer, pt2_n_tasks_max ] @@ -14,11 +13,14 @@ END_PROVIDER implicit none logical, external :: testTeethBuilding integer :: i - pt2_F(:) = 1 - pt2_n_tasks_max = N_det_generators/100 + 1 + integer :: e + e = elec_num - n_core_orb * 2 + pt2_n_tasks_max = min(1+(e*(e-1))/2, int(dsqrt(dble(N_det_generators)))) do i=1,N_det_generators - if (maxval(dabs(psi_coef_sorted_gen(i,:))) > 0.005d0) then - pt2_F(i) = max(1,min( ((elec_alpha_num-n_core_orb)**2)/4, pt2_n_tasks_max)) + if (maxval(dabs(psi_coef_sorted_gen(i,1:N_states))) > 0.0001d0) then + pt2_F(i) = pt2_n_tasks_max + else + pt2_F(i) = 1 endif enddo diff --git a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f index 6d8b6a8c..2be0bd88 100644 --- a/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f +++ b/plugins/Full_CI_ZMQ/run_pt2_slave.irp.f @@ -43,10 +43,11 @@ subroutine run_pt2_slave(thread,iproc,energy) call create_selection_buffer(0, 0, buf) done = .False. + n_tasks = 1 do while (.not.done) - n_tasks = max(1,n_tasks) - n_tasks = min(n_tasks,pt2_n_tasks_max) +! n_tasks = max(1,n_tasks) +! n_tasks = min(pt2_n_tasks_max,n_tasks) integer, external :: get_tasks_from_taskserver if (get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_id, task, n_tasks) == -1) then @@ -61,13 +62,17 @@ subroutine run_pt2_slave(thread,iproc,energy) enddo double precision :: time0, time1 - call wall_time(time0) +! call wall_time(time0) do k=1,n_tasks pt2(:,k) = 0.d0 buf%cur = 0 +!double precision :: time2 +!call wall_time(time2) call select_connected(i_generator(k),energy,pt2(1,k),buf,subset(k),pt2_F(i_generator(k))) +!call wall_time(time1) +!print *, i_generator(1), time1-time2, n_tasks, pt2_F(i_generator(1)) enddo - call wall_time(time1) +! call wall_time(time1) integer, external :: tasks_done_to_taskserver if (tasks_done_to_taskserver(zmq_to_qp_run_socket,worker_id,task_id,n_tasks) == -1) then @@ -75,9 +80,8 @@ subroutine run_pt2_slave(thread,iproc,energy) endif call push_pt2_results(zmq_socket_push, i_generator, pt2, task_id, n_tasks) - ! Try to adjust n_tasks around 1 second per job - n_tasks = min(n_tasks,int( 1.d0*dble(n_tasks) / (time1 - time0 + 1.d-9)))+1 -! n_tasks = n_tasks+1 + ! Try to adjust n_tasks around nproc seconds per job +! n_tasks = min(2*n_tasks,int( dble(n_tasks * nproc) / (time1 - time0 + 1.d0))) end do integer, external :: disconnect_from_taskserver diff --git a/plugins/Full_CI_ZMQ/selection.irp.f b/plugins/Full_CI_ZMQ/selection.irp.f index 047a0b26..81dea087 100644 --- a/plugins/Full_CI_ZMQ/selection.irp.f +++ b/plugins/Full_CI_ZMQ/selection.irp.f @@ -409,9 +409,11 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d allocate(banned(mo_tot_num, mo_tot_num,2), bannedOrb(mo_tot_num, 2)) allocate (mat(N_states, mo_tot_num, mo_tot_num)) maskInd = -1 - integer :: nb_count + integer :: nb_count, maskInd_save + logical :: found do s1=1,2 do i1=N_holes(s1),1,-1 ! Generate low excitations first + h1 = hole_list(i1,s1) call apply_hole(psi_det_generators(1,1,i_generator), s1,h1, pmask, ok, N_int) From 06ffc784eb5e9cd8703a575b90a6ea8341ef177c Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 14 Sep 2018 15:45:38 +0200 Subject: [PATCH 2/2] Fixed sBk --- plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f | 20 ++- plugins/Full_CI_ZMQ/selection.irp.f | 27 ++- plugins/dress_zmq/dress_slave.irp.f | 6 - plugins/dress_zmq/dress_stoch_routines.irp.f | 89 ++++++--- plugins/dress_zmq/run_dress_slave.irp.f | 75 ++++---- plugins/shiftedbk/shifted_bk_slave.irp.f | 179 ++++++++++++++++++- src/ZMQ/put_get.irp.f | 82 ++++++++- 7 files changed, 388 insertions(+), 90 deletions(-) diff --git a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f index 4f9138bc..3c944cbc 100644 --- a/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f +++ b/plugins/Full_CI_ZMQ/pt2_stoch_routines.irp.f @@ -15,9 +15,9 @@ END_PROVIDER integer :: i integer :: e e = elec_num - n_core_orb * 2 - pt2_n_tasks_max = min(1+(e*(e-1))/2, int(dsqrt(dble(N_det_generators)))) + pt2_n_tasks_max = 1+min((e*(e-1))/2, int(dsqrt(dble(N_det_generators)))/10) do i=1,N_det_generators - if (maxval(dabs(psi_coef_sorted_gen(i,1:N_states))) > 0.0001d0) then + if (maxval(dabs(psi_coef_sorted_gen(i,1:N_states))) > 0.001d0) then pt2_F(i) = pt2_n_tasks_max else pt2_F(i) = 1 @@ -158,9 +158,19 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error) endif + integer, external :: add_task_to_taskserver - character(len=64000) :: task + character(len=64000000) :: task integer :: j,k,ipos + + ipos=0 + do i=1,N_det_generators + if (pt2_F(i) > 1) then + ipos += 1 + endif + enddo + call write_int(6,ipos,'Number of fragmented tasks') + ipos=1 task = ' ' @@ -168,7 +178,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, error) do j=1,pt2_F(pt2_J(i)) write(task(ipos:ipos+20),'(I9,1X,I9,''|'')') j, pt2_J(i) ipos += 20 - if (ipos > 63980) then + if (ipos > len(task)-20) then if (add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos))) == -1) then stop 'Unable to add task to task server' endif @@ -328,7 +338,7 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2, error) print '(G10.3, 2X, F16.10, 2X, G16.3, 2X, F16.4, A20)', c, avg+E, eqt, time-time0, '' if( dabs(error(pt2_stoch_istate) / pt2(pt2_stoch_istate)) < relative_error) then if (zmq_abort(zmq_to_qp_run_socket) == -1) then - call sleep(1) + call sleep(10) if (zmq_abort(zmq_to_qp_run_socket) == -1) then print *, irp_here, ': Error in sending abort signal (2)' endif diff --git a/plugins/Full_CI_ZMQ/selection.irp.f b/plugins/Full_CI_ZMQ/selection.irp.f index 81dea087..588790cc 100644 --- a/plugins/Full_CI_ZMQ/selection.irp.f +++ b/plugins/Full_CI_ZMQ/selection.irp.f @@ -357,6 +357,7 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d endif enddo enddo + deallocate(exc_degree) nmax=k-1 @@ -404,16 +405,36 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d end do deallocate(indices) - allocate(minilist(N_int, 2, N_det_selectors), fullminilist(N_int, 2, N_det)) allocate(banned(mo_tot_num, mo_tot_num,2), bannedOrb(mo_tot_num, 2)) allocate (mat(N_states, mo_tot_num, mo_tot_num)) maskInd = -1 - integer :: nb_count, maskInd_save + + integer :: nb_count, maskInd_save, monoBdo_save logical :: found + do s1=1,2 do i1=N_holes(s1),1,-1 ! Generate low excitations first + monoBdo_save = monoBdo + maskInd_save = maskInd + do s2=s1,2 + ib = 1 + if(s1 == s2) ib = i1+1 + do i2=N_holes(s2),ib,-1 + maskInd += 1 + if(mod(maskInd, csubset) == (subset-1)) then + found = .True. + end if + enddo + if(s1 /= s2) monoBdo = .false. + enddo + + if (.not.found) cycle + monoBdo = monoBdo_save + maskInd = maskInd_save + + h1 = hole_list(i1,s1) call apply_hole(psi_det_generators(1,1,i_generator), s1,h1, pmask, ok, N_int) @@ -526,8 +547,6 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d enddo end if end do - - do s2=s1,2 sp = s1 diff --git a/plugins/dress_zmq/dress_slave.irp.f b/plugins/dress_zmq/dress_slave.irp.f index 33238df2..5e575901 100644 --- a/plugins/dress_zmq/dress_slave.irp.f +++ b/plugins/dress_zmq/dress_slave.irp.f @@ -64,11 +64,7 @@ subroutine run_wf PROVIDE psi_bilinear_matrix_rows psi_det_sorted_gen_order psi_bilinear_matrix_order PROVIDE psi_bilinear_matrix_transp_rows_loc psi_bilinear_matrix_transp_columns PROVIDE psi_bilinear_matrix_transp_order - !!$OMP PARALLEL PRIVATE(i) - !i = omp_get_thread_num() -! call dress_slave_tcp(i+1, energy) call dress_slave_tcp(0, energy) - !!$OMP END PARALLEL endif end do end @@ -77,8 +73,6 @@ subroutine dress_slave_tcp(i,energy) implicit none double precision, intent(in) :: energy(N_states_diag) integer, intent(in) :: i - logical :: lstop - lstop = .False. call run_dress_slave(0,i,energy) end diff --git a/plugins/dress_zmq/dress_stoch_routines.irp.f b/plugins/dress_zmq/dress_stoch_routines.irp.f index 3b9d128d..c07c3110 100644 --- a/plugins/dress_zmq/dress_stoch_routines.irp.f +++ b/plugins/dress_zmq/dress_stoch_routines.irp.f @@ -13,20 +13,24 @@ END_PROVIDER implicit none logical, external :: testTeethBuilding integer :: i - pt2_F(:) = 1 - pt2_n_tasks_max = 20 -! do i=1,N_det_generators -! if (maxval(dabs(psi_coef_sorted_gen(i,:))) > 0.001d0) then -! pt2_F(i) = max(1,min( (elec_alpha_num-n_core_orb)**2, pt2_n_tasks_max)) -! endif -! enddo + integer :: e + e = elec_num - n_core_orb * 2 + pt2_n_tasks_max = 1 + min((e*(e-1))/2, int(dsqrt(dble(N_det_generators)))/10) + do i=1,N_det_generators + if (maxval(dabs(psi_coef_sorted_gen(i,1:N_states))) > 0.001d0) then + pt2_F(i) = pt2_n_tasks_max + else + pt2_F(i) = 1 + endif + enddo + if(N_det_generators < 1024) then pt2_minDetInFirstTeeth = 1 pt2_N_teeth = 1 else pt2_minDetInFirstTeeth = min(5, N_det_generators) - do pt2_N_teeth=100,2,-1 + do pt2_N_teeth=20,2,-1 if(testTeethBuilding(pt2_minDetInFirstTeeth, pt2_N_teeth)) exit end do end if @@ -219,7 +223,7 @@ subroutine ZMQ_dress(E, dress, delta_out, delta_s2_out, relative_error) implicit none - character(len=64000) :: task + character(len=64000000) :: task integer(ZMQ_PTR) :: zmq_to_qp_run_socket, zmq_socket_pull integer, external :: omp_get_thread_num double precision, intent(in) :: E(N_states), relative_error @@ -232,8 +236,8 @@ subroutine ZMQ_dress(E, dress, delta_out, delta_s2_out, relative_error) integer :: i, j, k, Ncp - integer, external :: add_task_to_taskserver double precision :: state_average_weight_save(N_states) + PROVIDE Nproc task(:) = CHAR(0) allocate(delta(N_states,N_det), delta_s2(N_states, N_det)) state_average_weight_save(:) = state_average_weight(:) @@ -254,7 +258,7 @@ subroutine ZMQ_dress(E, dress, delta_out, delta_s2_out, relative_error) integer, external :: zmq_put_N_det_generators integer, external :: zmq_put_N_det_selectors integer, external :: zmq_put_dvector - integer, external :: zmq_set_running + integer, external :: zmq_put_int if (zmq_put_psi(zmq_to_qp_run_socket,1) == -1) then stop 'Unable to put psi on ZMQ server' @@ -271,25 +275,59 @@ subroutine ZMQ_dress(E, dress, delta_out, delta_s2_out, relative_error) if (zmq_put_dvector(zmq_to_qp_run_socket,1,"state_average_weight",state_average_weight,N_states) == -1) then stop 'Unable to put state_average_weight on ZMQ server' endif - if (zmq_put_dvector(zmq_to_qp_run_socket,1,"dress_stoch_istate",real(dress_stoch_istate,8),1) == -1) then + if (zmq_put_int(zmq_to_qp_run_socket,1,"dress_stoch_istate",dress_stoch_istate) == -1) then stop 'Unable to put dress_stoch_istate on ZMQ server' endif + if (zmq_put_dvector(zmq_to_qp_run_socket,1,'threshold_selectors',threshold_selectors,1) == -1) then + stop 'Unable to put threshold_selectors on ZMQ server' + endif + if (zmq_put_dvector(zmq_to_qp_run_socket,1,'threshold_generators',threshold_generators,1) == -1) then + stop 'Unable to put threshold_generators on ZMQ server' + endif - integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket + + call write_int(6,pt2_n_tasks_max,'Max number of task fragments') - do i=1,N_det_generators - do j=1,pt2_F(pt2_J(i)) - write(task(1:20),'(I9,1X,I9''|'')') j, pt2_J(i) - if (add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:20))) == -1) then + integer, external :: add_task_to_taskserver + integer :: ipos + ipos=0 + do i=1,N_det_generators + if (pt2_F(i) > 1) then + ipos += 1 + endif + enddo + call write_int(6,ipos,'Number of fragmented tasks') + + + ipos=1 + task = ' ' + + do i= 1, N_det_generators + do j=1,pt2_F(pt2_J(i)) + write(task(ipos:ipos+20),'(I9,1X,I9,''|'')') j, pt2_J(i) + ipos += 20 + if (ipos > len(task)-20) then + if (add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos))) == -1) then + stop 'Unable to add task to task server' + endif + ipos=1 + endif + end do + enddo + if (ipos > 1) then + if (add_task_to_taskserver(zmq_to_qp_run_socket,trim(task(1:ipos))) == -1) then stop 'Unable to add task to task server' endif - end do - end do - if (zmq_set_running(zmq_to_qp_run_socket) == -1) then - print *, irp_here, ': Failed in zmq_set_running' - endif + endif + + integer, external :: zmq_set_running + if (zmq_set_running(zmq_to_qp_run_socket) == -1) then + print *, irp_here, ': Failed in zmq_set_running' + endif + + integer :: nproc_target nproc_target = nproc @@ -495,14 +533,14 @@ subroutine dress_collector(zmq_socket_pull, E, relative_error, delta, delta_s2, m += 1 if(dabs(error / avg) <= relative_error) then integer, external :: zmq_put_dvector - i= zmq_put_dvector(zmq_to_qp_run_socket, worker_id, "ending", dble(m-1), 1) + integer, external :: zmq_put_int + i= zmq_put_int(zmq_to_qp_run_socket, worker_id, "ending", (m-1)) found = .true. end if else do call pull_dress_results(zmq_socket_pull, m_task, f, edI_task, edI_index, breve_delta_m, task_id, n_tasks) if(time0 == -1d0) then - print *, "first pull", omp_get_wtime()-time time0 = omp_get_wtime() end if if(m_task == 0) then @@ -516,14 +554,13 @@ subroutine dress_collector(zmq_socket_pull, E, relative_error, delta, delta_s2, end if end do do i=1,n_tasks - if(edI(edI_index(i)) /= 0d0) stop "NIN M" edI(edI_index(i)) += edI_task(i) end do dot_f(m_task) -= f end if end do if (zmq_abort(zmq_to_qp_run_socket) == -1) then - call sleep(1) + call sleep(10) if (zmq_abort(zmq_to_qp_run_socket) == -1) then print *, irp_here, ': Error in sending abort signal (2)' endif diff --git a/plugins/dress_zmq/run_dress_slave.irp.f b/plugins/dress_zmq/run_dress_slave.irp.f index 0d3201bc..b9d73cb9 100644 --- a/plugins/dress_zmq/run_dress_slave.irp.f +++ b/plugins/dress_zmq/run_dress_slave.irp.f @@ -33,15 +33,14 @@ subroutine run_dress_slave(thread,iproce,energy) integer :: cp_max(Nproc) integer :: will_send, task_id, purge_task_id, ntask_buf integer, allocatable :: task_buf(:) - integer(kind=OMP_LOCK_KIND) :: lck_det(0:pt2_N_teeth+1) - integer(kind=OMP_LOCK_KIND) :: lck_sto(0:dress_N_cp+1), sending, getting_task +! integer(kind=OMP_LOCK_KIND) :: lck_det(0:pt2_N_teeth+1) +! integer(kind=OMP_LOCK_KIND) :: lck_sto(dress_N_cp) double precision :: fac - double precision :: ending(1) - integer, external :: zmq_get_dvector + integer :: ending + integer, external :: zmq_get_dvector, zmq_get_int ! double precision, external :: omp_get_wtime double precision :: time, time0 integer :: ntask_tbd, task_tbd(Nproc), i_gen_tbd(Nproc), subset_tbd(Nproc) -! if(iproce /= 0) stop "RUN DRESS SLAVE is OMP" allocate(delta_det(N_states, N_det, 0:pt2_N_teeth+1, 2)) allocate(cp(N_states, N_det, dress_N_cp, 2)) @@ -53,14 +52,12 @@ subroutine run_dress_slave(thread,iproce,energy) cp = 0d0 task = CHAR(0) - call omp_init_lock(sending) - call omp_init_lock(getting_task) - do i=0,dress_N_cp+1 - call omp_init_lock(lck_sto(i)) - end do - do i=0,pt2_N_teeth+1 - call omp_init_lock(lck_det(i)) - end do +! do i=1,dress_N_cp +! call omp_init_lock(lck_sto(i)) +! end do +! do i=0,pt2_N_teeth+1 +! call omp_init_lock(lck_det(i)) +! end do cp_done = 0 cp_sent = 0 @@ -69,7 +66,7 @@ subroutine run_dress_slave(thread,iproce,energy) double precision :: hij, sij, tmp purge_task_id = 0 provide psi_energy - ending(1) = dble(dress_N_cp+1) + ending = dress_N_cp+1 ntask_tbd = 0 !$OMP PARALLEL DEFAULT(SHARED) & !$OMP PRIVATE(breve_delta_m, task_id) & @@ -86,7 +83,7 @@ subroutine run_dress_slave(thread,iproce,energy) stop "WORKER -1" end if iproc = omp_get_thread_num()+1 - allocate(breve_delta_m(N_states,N_det,2)) + allocate(breve_delta_m(N_states,N_det,2)) allocate(task_buf(pt2_n_tasks_max)) ntask_buf = 0 @@ -94,8 +91,9 @@ subroutine run_dress_slave(thread,iproce,energy) call push_dress_results(zmq_socket_push, 0, 0, edI_task, edI_index, breve_delta_m, task_buf, ntask_buf) end if + cp_max(:) = 0 do while(cp_done > cp_sent .or. m /= dress_N_cp+1) - call omp_set_lock(getting_task) + !$OMP CRITICAL (send) if(ntask_tbd == 0) then ntask_tbd = size(task_tbd) call get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id, task_tbd, task, ntask_tbd) @@ -113,13 +111,13 @@ subroutine run_dress_slave(thread,iproce,energy) ntask_tbd -= 1 else m = dress_N_cp + 1 - i= zmq_get_dvector(zmq_to_qp_run_socket, worker_id, "ending", ending, 1) + i= zmq_get_int(zmq_to_qp_run_socket, worker_id, "ending", ending) end if - call omp_unset_lock(getting_task) will_send = 0 - !$OMP CRITICAL cp_max(iproc) = m +! print *, cp_max(:) +! print *, '' cp_done = minval(cp_max)-1 if(cp_done > cp_sent) then will_send = cp_sent + 1 @@ -132,10 +130,8 @@ subroutine run_dress_slave(thread,iproce,energy) ntask_buf += 1 task_buf(ntask_buf) = task_id end if - !$OMP END CRITICAL - if(will_send /= 0 .and. will_send <= int(ending(1))) then - call omp_set_lock(sending) + if(will_send /= 0 .and. will_send <= ending) then n_tasks = 0 sum_f = 0 do i=1,N_det_generators @@ -146,9 +142,10 @@ subroutine run_dress_slave(thread,iproce,energy) edI_index(n_tasks) = i end if end do - call push_dress_results(zmq_socket_push, will_send, sum_f, edI_task, edI_index, breve_delta_m, 0, n_tasks) - call omp_unset_lock(sending) + call push_dress_results(zmq_socket_push, will_send, sum_f, edI_task, edI_index, & + breve_delta_m, 0, n_tasks) end if + !$OMP END CRITICAL (send) if(m /= dress_N_cp+1) then !UPDATE i_generator @@ -158,29 +155,29 @@ subroutine run_dress_slave(thread,iproce,energy) time0 = omp_get_wtime() call alpha_callback(breve_delta_m, i_generator, subset, pt2_F(i_generator), iproc) time = omp_get_wtime() - !print '(I0.11, I4, A12, F12.3)', i_generator, subset, "GREPMETIME", time-time0 +!print '(I0.11, I4, A12, F12.3)', i_generator, subset, "GREPMETIME", time-time0 t = dress_T(i_generator) - call omp_set_lock(lck_det(t)) + !$OMP CRITICAL(t_crit) do j=1,N_det do i=1,N_states delta_det(i,j,t, 1) = delta_det(i,j,t, 1) + breve_delta_m(i,j,1) delta_det(i,j,t, 2) = delta_det(i,j,t, 2) + breve_delta_m(i,j,2) enddo enddo - call omp_unset_lock(lck_det(t)) + !$OMP END CRITICAL(t_crit) do p=1,dress_N_cp if(dress_e(i_generator, p) /= 0d0) then fac = dress_e(i_generator, p) - call omp_set_lock(lck_sto(p)) + !$OMP CRITICAL(p_crit) do j=1,N_det do i=1,N_states cp(i,j,p,1) = cp(i,j,p,1) + breve_delta_m(i,j,1) * fac cp(i,j,p,2) = cp(i,j,p,2) + breve_delta_m(i,j,2) * fac enddo enddo - call omp_unset_lock(lck_sto(p)) + !$OMP END CRITICAL(p_crit) end if end do @@ -198,7 +195,9 @@ subroutine run_dress_slave(thread,iproce,energy) ntask_buf = 0 end if end if + !$OMP FLUSH end do + !$OMP BARRIER if(ntask_buf /= 0) then call push_dress_results(zmq_socket_push, 0, 0, edI_task, edI_index, breve_delta_m, task_buf, ntask_buf) @@ -206,12 +205,12 @@ subroutine run_dress_slave(thread,iproce,energy) end if !$OMP SINGLE if(purge_task_id /= 0) then - do while(int(ending(1)) == dress_N_cp+1) + do while(ending == dress_N_cp+1) call sleep(1) - i= zmq_get_dvector(zmq_to_qp_run_socket, worker_id, "ending", ending, 1) + i= zmq_get_int(zmq_to_qp_run_socket, worker_id, "ending", ending) end do - will_send = int(ending(1)) + will_send = ending breve_delta_m = 0d0 do l=will_send, 1,-1 @@ -238,12 +237,12 @@ subroutine run_dress_slave(thread,iproce,energy) call end_zmq_to_qp_run_socket(zmq_to_qp_run_socket) call end_zmq_push_socket(zmq_socket_push,thread) !$OMP END PARALLEL - do i=0,dress_N_cp+1 - call omp_destroy_lock(lck_sto(i)) - end do - do i=0,pt2_N_teeth+1 - call omp_destroy_lock(lck_det(i)) - end do +! do i=0,dress_N_cp+1 +! call omp_destroy_lock(lck_sto(i)) +! end do +! do i=0,pt2_N_teeth+1 +! call omp_destroy_lock(lck_det(i)) +! end do end subroutine diff --git a/plugins/shiftedbk/shifted_bk_slave.irp.f b/plugins/shiftedbk/shifted_bk_slave.irp.f index 901940ed..83d95847 100644 --- a/plugins/shiftedbk/shifted_bk_slave.irp.f +++ b/plugins/shiftedbk/shifted_bk_slave.irp.f @@ -1,15 +1,176 @@ -program shifted_bk +program shifted_bk_slave implicit none BEGIN_DOC -! Helper subroutine to compute the dress in distributed mode. +! Helper program to compute the dress in distributed mode. END_DOC - - PROVIDE psi_bilinear_matrix_columns_loc psi_det_alpha_unique psi_det_beta_unique - PROVIDE psi_bilinear_matrix_rows psi_det_sorted_gen_order psi_bilinear_matrix_order - PROVIDE psi_bilinear_matrix_transp_rows_loc psi_bilinear_matrix_transp_columns - PROVIDE psi_bilinear_matrix_transp_order - !call diagonalize_CI() - call dress_slave() + read_wf = .False. + distributed_davidson = .False. + SOFT_TOUCH read_wf distributed_davidson + call provide_all + call switch_qp_run_to_master + call run_w end +subroutine provide_all + PROVIDE H_apply_buffer_allocated mo_bielec_integrals_in_map psi_det_generators psi_coef_generators psi_det_sorted_bit psi_selectors n_det_generators n_states generators_bitmask zmq_context N_states_diag + PROVIDE dress_e0_denominator mo_tot_num N_int ci_energy mpi_master zmq_state zmq_context + PROVIDE psi_det psi_coef threshold_generators threshold_selectors state_average_weight + PROVIDE N_det_selectors dress_stoch_istate N_det +end + +subroutine run_w + use f77_zmq + + implicit none + IRP_IF MPI + include 'mpif.h' + IRP_ENDIF + + integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket + integer(ZMQ_PTR) :: zmq_to_qp_run_socket + double precision :: energy(N_states) + character*(64) :: states(3) + character*(64) :: old_state + integer :: rc, i, ierr + double precision :: t0, t1 + + integer, external :: zmq_get_dvector, zmq_get_N_det_generators + integer, external :: zmq_get_ivector + integer, external :: zmq_get_psi, zmq_get_N_det_selectors, zmq_get_int + integer, external :: zmq_get_N_states_diag + + zmq_context = f77_zmq_ctx_new () + states(1) = 'selection' + states(2) = 'davidson' + states(3) = 'dress' + old_state = 'Waiting' + + zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() + + PROVIDE psi_det psi_coef threshold_generators threshold_selectors state_average_weight mpi_master + PROVIDE zmq_state N_det_selectors dress_stoch_istate N_det dress_e0_denominator + PROVIDE N_det_generators N_states N_states_diag + call MPI_BARRIER(MPI_COMM_WORLD, ierr) + do + + if (mpi_master) then + call wait_for_states(states,zmq_state,size(states)) + if (zmq_state(1:64) == old_state(1:64)) then + call sleep(1) + cycle + else + old_state(1:64) = zmq_state(1:64) + endif + print *, trim(zmq_state) + endif + + IRP_IF MPI_DEBUG + print *, irp_here, mpi_rank + call MPI_BARRIER(MPI_COMM_WORLD, ierr) + IRP_ENDIF + IRP_IF MPI + call MPI_BCAST (zmq_state, 128, MPI_CHARACTER, 0, MPI_COMM_WORLD, ierr) + if (ierr /= MPI_SUCCESS) then + print *, irp_here, 'error in broadcast of zmq_state' + endif + IRP_ENDIF + + if(zmq_state(1:7) == 'Stopped') then + exit + endif + + + if (zmq_state(1:8) == 'davidson') then + + ! Davidson + ! -------- + + call wall_time(t0) + if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle + if (zmq_get_N_states_diag(zmq_to_qp_run_socket,1) == -1) cycle + if (zmq_get_dvector(zmq_to_qp_run_socket,1,'energy',energy,N_states_diag) == -1) cycle + + call wall_time(t1) + if (mpi_master) then + call write_double(6,(t1-t0),'Broadcast time') + endif + + call omp_set_nested(.True.) + call davidson_slave_tcp(0) + call omp_set_nested(.False.) + print *, 'Davidson done' + IRP_IF MPI + call MPI_BARRIER(MPI_COMM_WORLD, ierr) + if (ierr /= MPI_SUCCESS) then + print *, irp_here, 'error in barrier' + endif + IRP_ENDIF + print *, 'All Davidson done' + + else if (zmq_state(1:5) == 'dress') then + + ! Dress + ! --- + + call wall_time(t0) + if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle + print *, 'if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle', mpi_rank + + if (zmq_get_N_det_generators (zmq_to_qp_run_socket, 1) == -1) cycle + print *, 'if (zmq_get_N_det_generators (zmq_to_qp_run_socket, 1) == -1) cycle', mpi_rank + + if (zmq_get_N_det_selectors(zmq_to_qp_run_socket, 1) == -1) cycle + print *, 'if (zmq_get_N_det_selectors(zmq_to_qp_run_socket, 1) == -1) cycle', mpi_rank + + if (zmq_get_dvector(zmq_to_qp_run_socket,1,'threshold_generators',threshold_generators,1) == -1) cycle + print *, 'if (zmq_get_dvector(zmq_to_qp_run_socket,1,threshold_generators,threshold_generators,1) == -1) cycle', mpi_rank + + if (zmq_get_dvector(zmq_to_qp_run_socket,1,'threshold_selectors',threshold_selectors,1) == -1) cycle + print *, 'if (zmq_get_dvector(zmq_to_qp_run_socket,1,threshold_selectors,threshold_selectors,1) == -1) cycle', mpi_rank + + if (zmq_get_dvector(zmq_to_qp_run_socket,1,'energy',energy,N_states) == -1) cycle + print *, 'if (zmq_get_dvector(zmq_to_qp_run_socket,1,energy,energy,N_states) == -1) cycle', mpi_rank + + if (zmq_get_int(zmq_to_qp_run_socket,1,'dress_stoch_istate',dress_stoch_istate) == -1) cycle + print *, 'if (zmq_get_int(zmq_to_qp_run_socket,1,dress_stoch_istate,dress_stoch_istate) == -1) cycle', mpi_rank + + if (zmq_get_dvector(zmq_to_qp_run_socket,1,'state_average_weight',state_average_weight,N_states) == -1) cycle + print *, 'if (zmq_get_dvector(zmq_to_qp_run_socket,1,state_average_weight,state_average_weight,N_states) == -1) cycle', mpi_rank + + psi_energy(1:N_states) = energy(1:N_states) + TOUCH psi_energy state_average_weight dress_stoch_istate threshold_selectors threshold_generators + if (mpi_master) then + print *, 'N_det', N_det + print *, 'N_det_generators', N_det_generators + print *, 'N_det_selectors', N_det_selectors + print *, 'psi_energy', psi_energy + print *, 'dress_stoch_istate', dress_stoch_istate + print *, 'state_average_weight', state_average_weight + endif + + call wall_time(t1) + call write_double(6,(t1-t0),'Broadcast time') + + call dress_slave_tcp(0, energy) + + + IRP_IF MPI + call MPI_BARRIER(MPI_COMM_WORLD, ierr) + if (ierr /= MPI_SUCCESS) then + print *, irp_here, 'error in barrier' + endif + IRP_ENDIF + print *, 'All dress done' + + endif + + end do + IRP_IF MPI + call MPI_finalize(ierr) + IRP_ENDIF +end + + + + diff --git a/src/ZMQ/put_get.irp.f b/src/ZMQ/put_get.irp.f index 207cb0ae..e86a6daf 100644 --- a/src/ZMQ/put_get.irp.f +++ b/src/ZMQ/put_get.irp.f @@ -8,7 +8,7 @@ integer function zmq_put_dvector(zmq_to_qp_run_socket, worker_id, name, x, size_ integer, intent(in) :: worker_id character*(*) :: name integer, intent(in) :: size_x - double precision, intent(out) :: x(size_x) + double precision, intent(in) :: x(size_x) integer :: rc character*(256) :: msg @@ -111,7 +111,7 @@ integer function zmq_put_ivector(zmq_to_qp_run_socket, worker_id, name, x, size_ integer, intent(in) :: worker_id character*(*) :: name integer, intent(in) :: size_x - integer, intent(out) :: x(size_x) + integer, intent(in) :: x(size_x) integer :: rc character*(256) :: msg @@ -201,3 +201,81 @@ end +integer function zmq_put_int(zmq_to_qp_run_socket, worker_id, name, x) + use f77_zmq + implicit none + BEGIN_DOC +! Put a vector of integers on the qp_run scheduler + END_DOC + integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket + integer, intent(in) :: worker_id + character*(*) :: name + integer, intent(in) :: x + integer :: rc + character*(256) :: msg + + zmq_put_int = 0 + + write(msg,'(A,1X,I8,1X,A200)') 'put_data '//trim(zmq_state), worker_id, name + rc = f77_zmq_send(zmq_to_qp_run_socket,trim(msg),len(trim(msg)),ZMQ_SNDMORE) + if (rc /= len(trim(msg))) then + zmq_put_int = -1 + return + endif + + rc = f77_zmq_send(zmq_to_qp_run_socket,x,4,0) + if (rc /= 4) then + zmq_put_int = -1 + return + endif + + rc = f77_zmq_recv(zmq_to_qp_run_socket,msg,len(msg),0) + if (msg(1:rc) /= 'put_data_reply ok') then + zmq_put_int = -1 + return + endif + +end + +integer function zmq_get_int(zmq_to_qp_run_socket, worker_id, name, x) + use f77_zmq + implicit none + BEGIN_DOC +! Get a vector of integers from the qp_run scheduler + END_DOC + integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket + integer, intent(in) :: worker_id + character*(*), intent(in) :: name + integer, intent(out) :: x + integer :: rc + character*(256) :: msg + + PROVIDE zmq_state + ! Success + zmq_get_int = 0 + + if (mpi_master) then + write(msg,'(A,1X,I8,1X,A200)') 'get_data '//trim(zmq_state), worker_id, name + rc = f77_zmq_send(zmq_to_qp_run_socket,trim(msg),len(trim(msg)),0) + if (rc /= len(trim(msg))) then + zmq_get_int = -1 + go to 10 + endif + + rc = f77_zmq_recv(zmq_to_qp_run_socket,msg,len(msg),0) + if (msg(1:14) /= 'get_data_reply') then + zmq_get_int = -1 + go to 10 + endif + + rc = f77_zmq_recv(zmq_to_qp_run_socket,x,4,0) + if (rc /= 4) then + zmq_get_int = -1 + go to 10 + endif + endif + + 10 continue + +end +