From 2fd46841aed16b4750d674d7b2e507ece04d6e72 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Tue, 23 May 2017 19:30:51 +0200 Subject: [PATCH] Fixed selection CASSD slave --- .../CAS_SD_ZMQ/selection_cassd_slave.irp.f | 68 +++++++++---------- plugins/Generators_CAS/generators.irp.f | 12 ++-- plugins/Selectors_CASSD/selectors.irp.f | 5 +- plugins/Selectors_CASSD/zmq.irp.f | 9 +-- 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/plugins/CAS_SD_ZMQ/selection_cassd_slave.irp.f b/plugins/CAS_SD_ZMQ/selection_cassd_slave.irp.f index 657ad63c..b5f2371d 100644 --- a/plugins/CAS_SD_ZMQ/selection_cassd_slave.irp.f +++ b/plugins/CAS_SD_ZMQ/selection_cassd_slave.irp.f @@ -1,11 +1,12 @@ -program selection_slave +program prog_selection_slave implicit none BEGIN_DOC ! Helper program to compute the PT2 in distributed mode. END_DOC read_wf = .False. - SOFT_TOUCH read_wf + distributed_davidson = .False. + SOFT_TOUCH read_wf distributed_davidson call provide_everything call switch_qp_run_to_master call run_wf @@ -23,19 +24,21 @@ subroutine run_wf integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket integer(ZMQ_PTR) :: zmq_to_qp_run_socket double precision :: energy(N_states) - character*(64) :: states(1) + character*(64) :: states(4) integer :: rc, i call provide_everything zmq_context = f77_zmq_ctx_new () states(1) = 'selection' + states(2) = 'davidson' + states(3) = 'pt2' zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() do - call wait_for_states(states,zmq_state,1) + call wait_for_states(states,zmq_state,3) if(trim(zmq_state) == 'Stopped') then @@ -51,43 +54,40 @@ subroutine run_wf !$OMP PARALLEL PRIVATE(i) i = omp_get_thread_num() - call selection_slave_tcp(i, energy) + call run_selection_slave(0, i, energy) !$OMP END PARALLEL print *, 'Selection done' + else if (trim(zmq_state) == 'davidson') then + + ! Davidson + ! -------- + + print *, 'Davidson' + call zmq_get_psi(zmq_to_qp_run_socket,1,energy,N_states) + call omp_set_nested(.True.) + call davidson_slave_tcp(0) + call omp_set_nested(.False.) + print *, 'Davidson done' + + else if (trim(zmq_state) == 'pt2') then + + ! PT2 + ! --- + + print *, 'PT2' + call zmq_get_psi(zmq_to_qp_run_socket,1,energy,N_states) + + !$OMP PARALLEL PRIVATE(i) + i = omp_get_thread_num() + call run_selection_slave(0, i, energy) + !$OMP END PARALLEL + print *, 'PT2 done' + endif end do end -subroutine update_energy(energy) - implicit none - double precision, intent(in) :: energy(N_states) - BEGIN_DOC -! Update energy when it is received from ZMQ - END_DOC - integer :: j,k - do j=1,N_states - do k=1,N_det - CI_eigenvectors(k,j) = psi_coef(k,j) - enddo - enddo - call u_0_S2_u_0(CI_eigenvectors_s2,CI_eigenvectors,N_det,psi_det,N_int) - if (.True.) then - do k=1,N_states - ci_electronic_energy(k) = energy(k) - enddo - TOUCH ci_electronic_energy CI_eigenvectors_s2 CI_eigenvectors - endif - call write_double(6,ci_energy,'Energy') -end - -subroutine selection_slave_tcp(i,energy) - implicit none - double precision, intent(in) :: energy(N_states) - integer, intent(in) :: i - - call run_selection_slave(0,i,energy) -end diff --git a/plugins/Generators_CAS/generators.irp.f b/plugins/Generators_CAS/generators.irp.f index f47341de..4e2fcd58 100644 --- a/plugins/Generators_CAS/generators.irp.f +++ b/plugins/Generators_CAS/generators.irp.f @@ -14,9 +14,9 @@ BEGIN_PROVIDER [ integer, N_det_generators ] good = .True. do k=1,N_int good = good .and. ( & - iand(not(cas_bitmask(k,1,l)), psi_det(k,1,i)) == & + iand(not(cas_bitmask(k,1,l)), psi_det_sorted(k,1,i)) == & iand(not(cas_bitmask(k,1,l)), HF_bitmask(k,1)) ) .and. ( & - iand(not(cas_bitmask(k,2,l)), psi_det(k,2,i)) == & + iand(not(cas_bitmask(k,2,l)), psi_det_sorted(k,2,i)) == & iand(not(cas_bitmask(k,2,l)), HF_bitmask(k,2)) ) enddo if (good) then @@ -46,9 +46,9 @@ END_PROVIDER good = .True. do k=1,N_int good = good .and. ( & - iand(not(cas_bitmask(k,1,l)), psi_det(k,1,i)) == & + iand(not(cas_bitmask(k,1,l)), psi_det_sorted(k,1,i)) == & iand(not(cas_bitmask(k,1,l)), HF_bitmask(k,1)) .and. ( & - iand(not(cas_bitmask(k,2,l)), psi_det(k,2,i)) == & + iand(not(cas_bitmask(k,2,l)), psi_det_sorted(k,2,i)) == & iand(not(cas_bitmask(k,2,l)), HF_bitmask(k,2) )) ) enddo if (good) then @@ -58,8 +58,8 @@ END_PROVIDER if (good) then m = m+1 do k=1,N_int - psi_det_generators(k,1,m) = psi_det(k,1,i) - psi_det_generators(k,2,m) = psi_det(k,2,i) + psi_det_generators(k,1,m) = psi_det_sorted(k,1,i) + psi_det_generators(k,2,m) = psi_det_sorted(k,2,i) enddo psi_coef_generators(m,:) = psi_coef(m,:) endif diff --git a/plugins/Selectors_CASSD/selectors.irp.f b/plugins/Selectors_CASSD/selectors.irp.f index ab36527d..3c0e2b5f 100644 --- a/plugins/Selectors_CASSD/selectors.irp.f +++ b/plugins/Selectors_CASSD/selectors.irp.f @@ -61,7 +61,10 @@ END_PROVIDER endif enddo if (N_det /= m) then - print *, N_det, m + print *, 'N_det = ', N_det + print *, 'm = ', m + print *, 'N_det_generators = ', N_det_generators + print *, 'psi_det_size = ', psi_det_size stop 'N_det /= m' endif END_PROVIDER diff --git a/plugins/Selectors_CASSD/zmq.irp.f b/plugins/Selectors_CASSD/zmq.irp.f index 37dd29de..f7f0c4b0 100644 --- a/plugins/Selectors_CASSD/zmq.irp.f +++ b/plugins/Selectors_CASSD/zmq.irp.f @@ -84,23 +84,24 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy) N_states = N_states_read N_det = N_det_read psi_det_size = psi_det_size_read + TOUCH psi_det_size N_det N_states - rc = f77_zmq_recv(zmq_to_qp_run_socket,psi_det,N_int*2*N_det*bit_kind,ZMQ_SNDMORE) + rc = f77_zmq_recv(zmq_to_qp_run_socket,psi_det,N_int*2*N_det*bit_kind,0) if (rc /= N_int*2*N_det*bit_kind) then print *, 'f77_zmq_recv(zmq_to_qp_run_socket,psi_det,N_int*2*N_det*bit_kind,ZMQ_SNDMORE)' stop 'error' endif - rc = f77_zmq_recv(zmq_to_qp_run_socket,psi_coef,psi_det_size*N_states*8,ZMQ_SNDMORE) + rc = f77_zmq_recv(zmq_to_qp_run_socket,psi_coef,psi_det_size*N_states*8,0) if (rc /= psi_det_size*N_states*8) then print *, '77_zmq_recv(zmq_to_qp_run_socket,psi_coef,psi_det_size*N_states*8,ZMQ_SNDMORE)' stop 'error' endif - TOUCH psi_det_size N_det N_states psi_det psi_coef + TOUCH psi_det psi_coef rc = f77_zmq_recv(zmq_to_qp_run_socket,energy,size_energy*8,0) if (rc /= size_energy*8) then - print *, 'f77_zmq_recv(zmq_to_qp_run_socket,energy,size_energy*8,0)' + print *, '77_zmq_recv(zmq_to_qp_run_socket,energy,size_energy*8,0)' stop 'error' endif