10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-06-26 07:02:14 +02:00

Corrected N_states N_states_diag i parallel davidson

This commit is contained in:
Anthony Scemama 2016-10-06 14:35:43 +02:00
parent e0e1b22a51
commit 311df93d20
4 changed files with 143 additions and 55 deletions

View File

@ -12,8 +12,8 @@ subroutine davidson_process(blockb, blocke, N, idx, vt, st)
integer , intent(in) :: blockb, blocke
integer , intent(inout) :: N
integer , intent(inout) :: idx(dav_size)
double precision , intent(inout) :: vt(N_states, dav_size)
double precision , intent(inout) :: st(N_states, dav_size)
double precision , intent(inout) :: vt(N_states_diag, dav_size)
double precision , intent(inout) :: st(N_states_diag, dav_size)
integer :: i, j, sh, sh2, exa, ext, org_i, org_j, istate, ni, endi
integer(bit_kind) :: sorted_i(N_int)
@ -63,7 +63,7 @@ subroutine davidson_process(blockb, blocke, N, idx, vt, st)
vt (:,org_j) = 0d0
st (:,org_j) = 0d0
end if
do istate=1,N_states
do istate=1,N_states_diag
vt (istate,org_i) += hij*dav_ut(istate,org_j)
st (istate,org_i) += s2*dav_ut(istate,org_j)
vt (istate,org_j) += hij*dav_ut(istate,org_i)
@ -79,7 +79,7 @@ subroutine davidson_process(blockb, blocke, N, idx, vt, st)
do i=1, dav_size
if(wrotten(i)) then
N = N+1
do istate=1,N_states
do istate=1,N_states_diag
vt (istate,N) = vt (istate,i)
st (istate,N) = st (istate,i)
idx(N) = i
@ -98,10 +98,10 @@ subroutine davidson_collect(blockb, blocke, N, idx, vt, st , v0, s0)
integer , intent(in) :: blockb, blocke
integer , intent(in) :: N
integer , intent(in) :: idx(N)
double precision , intent(in) :: vt(N_states, N)
double precision , intent(in) :: st(N_states, N)
double precision , intent(inout) :: v0(dav_size, N_states)
double precision , intent(inout) :: s0(dav_size, N_states)
double precision , intent(in) :: vt(N_states_diag, N)
double precision , intent(in) :: st(N_states_diag, N)
double precision , intent(inout) :: v0(dav_size, N_states_diag)
double precision , intent(inout) :: s0(dav_size, N_states_diag)
integer :: i
@ -210,8 +210,8 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, worker_id)
allocate(idx(dav_size))
allocate(vt(N_states, dav_size))
allocate(st(N_states, dav_size))
allocate(vt(N_states_diag, dav_size))
allocate(st(N_states_diag, dav_size))
do
@ -239,8 +239,8 @@ subroutine davidson_push_results(zmq_socket_push, blockb, blocke, N, idx, vt, st
integer ,intent(in) :: blockb, blocke
integer ,intent(in) :: N
integer ,intent(in) :: idx(N)
double precision ,intent(in) :: vt(N_states, N)
double precision ,intent(in) :: st(N_states, N)
double precision ,intent(in) :: vt(N_states_diag, N)
double precision ,intent(in) :: st(N_states_diag, N)
integer :: rc
rc = f77_zmq_send( zmq_socket_push, blockb, 4, ZMQ_SNDMORE)
@ -255,11 +255,11 @@ subroutine davidson_push_results(zmq_socket_push, blockb, blocke, N, idx, vt, st
rc = f77_zmq_send( zmq_socket_push, idx, 4*N, ZMQ_SNDMORE)
if(rc /= 4*N) stop "davidson_push_results failed to push idx"
rc = f77_zmq_send( zmq_socket_push, vt, 8*N_states* N, ZMQ_SNDMORE)
if(rc /= 8*N_states* N) stop "davidson_push_results failed to push vt"
rc = f77_zmq_send( zmq_socket_push, vt, 8*N_states_diag* N, ZMQ_SNDMORE)
if(rc /= 8*N_states_diag* N) stop "davidson_push_results failed to push vt"
rc = f77_zmq_send( zmq_socket_push, st, 8*N_states* N, ZMQ_SNDMORE)
if(rc /= 8*N_states* N) stop "davidson_push_results failed to push st"
rc = f77_zmq_send( zmq_socket_push, st, 8*N_states_diag* N, ZMQ_SNDMORE)
if(rc /= 8*N_states_diag* N) stop "davidson_push_results failed to push st"
rc = f77_zmq_send( zmq_socket_push, task_id, 4, 0)
if(rc /= 4) stop "davidson_push_results failed to push task_id"
@ -276,8 +276,8 @@ subroutine davidson_pull_results(zmq_socket_pull, blockb, blocke, N, idx, vt, st
integer ,intent(out) :: blockb, blocke
integer ,intent(out) :: N
integer ,intent(out) :: idx(dav_size)
double precision ,intent(out) :: vt(N_states, dav_size)
double precision ,intent(out) :: st(N_states, dav_size)
double precision ,intent(out) :: vt(N_states_diag, dav_size)
double precision ,intent(out) :: st(N_states_diag, dav_size)
integer :: rc
@ -293,11 +293,11 @@ subroutine davidson_pull_results(zmq_socket_pull, blockb, blocke, N, idx, vt, st
rc = f77_zmq_recv( zmq_socket_pull, idx, 4*N, 0)
if(rc /= 4*N) stop "davidson_push_results failed to pull idx"
rc = f77_zmq_recv( zmq_socket_pull, vt, 8*N_states* N, 0)
if(rc /= 8*N_states* N) stop "davidson_push_results failed to pull vt"
rc = f77_zmq_recv( zmq_socket_pull, vt, 8*N_states_diag* N, 0)
if(rc /= 8*N_states_diag* N) stop "davidson_push_results failed to pull vt"
rc = f77_zmq_recv( zmq_socket_pull, st, 8*N_states* N, 0)
if(rc /= 8*N_states* N) stop "davidson_push_results failed to pull st"
rc = f77_zmq_recv( zmq_socket_pull, st, 8*N_states_diag* N, 0)
if(rc /= 8*N_states_diag* N) stop "davidson_push_results failed to pull st"
rc = f77_zmq_recv( zmq_socket_pull, task_id, 4, 0)
if(rc /= 4) stop "davidson_pull_results failed to pull task_id"
@ -312,8 +312,8 @@ subroutine davidson_collector(zmq_to_qp_run_socket, zmq_socket_pull , v0, s0)
integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket
integer(ZMQ_PTR), intent(in) :: zmq_socket_pull
double precision ,intent(inout) :: v0(dav_size, N_states)
double precision ,intent(inout) :: s0(dav_size, N_states)
double precision ,intent(inout) :: v0(dav_size, N_states_diag)
double precision ,intent(inout) :: s0(dav_size, N_states_diag)
integer :: more, task_id
@ -330,8 +330,8 @@ subroutine davidson_collector(zmq_to_qp_run_socket, zmq_socket_pull , v0, s0)
done = .false.
allocate(idx(dav_size))
allocate(vt(N_states, dav_size))
allocate(st(N_states, dav_size))
allocate(vt(N_states_diag, dav_size))
allocate(st(N_states_diag, dav_size))
more = 1
@ -360,8 +360,8 @@ subroutine davidson_run(zmq_to_qp_run_socket , v0, s0)
integer :: i
integer, external :: omp_get_thread_num
double precision , intent(inout) :: v0(dav_size, N_states)
double precision , intent(inout) :: s0(dav_size, N_states)
double precision , intent(inout) :: v0(dav_size, N_states_diag)
double precision , intent(inout) :: s0(dav_size, N_states_diag)
call zmq_set_running(zmq_to_qp_run_socket)
@ -411,7 +411,7 @@ subroutine davidson_miniserver_run()
if (buffer(1:rc) /= 'end') then
rc = f77_zmq_send (responder, dav_size, 4, ZMQ_SNDMORE)
rc = f77_zmq_send (responder, dav_det, 16*N_int*dav_size, ZMQ_SNDMORE)
rc = f77_zmq_send (responder, dav_ut, 8*dav_size*N_states, 0)
rc = f77_zmq_send (responder, dav_ut, 8*dav_size*N_states_diag, 0)
else
rc = f77_zmq_send (responder, "end", 3, 0)
exit
@ -465,7 +465,7 @@ subroutine davidson_miniserver_get()
rc = f77_zmq_recv(requester, dav_size, 4, 0)
TOUCH dav_size
rc = f77_zmq_recv(requester, dav_det, 16*N_int*dav_size, 0)
rc = f77_zmq_recv(requester, dav_ut, 8*dav_size*N_states, 0)
rc = f77_zmq_recv(requester, dav_ut, 8*dav_size*N_states_diag, 0)
TOUCH dav_det dav_ut
rc = f77_zmq_close(requester)
@ -480,7 +480,7 @@ BEGIN_PROVIDER [ integer(bit_kind), dav_det, (N_int, 2, dav_size) ]
END_PROVIDER
BEGIN_PROVIDER [ double precision, dav_ut, (N_states, dav_size) ]
BEGIN_PROVIDER [ double precision, dav_ut, (N_states_diag, dav_size) ]
END_PROVIDER

View File

@ -36,5 +36,5 @@ program davidson_slave
end
! subroutine provide_everything
! PROVIDE mo_bielec_integrals_in_map psi_det_sorted_bit N_states zmq_context
! PROVIDE mo_bielec_integrals_in_map psi_det_sorted_bit N_states_diag zmq_context
! end subroutine

View File

@ -58,8 +58,8 @@ subroutine H_u_0_nstates(v_0,u_0,H_jj,n,keys_tmp,Nint,N_st,sze_8)
integer, external :: align_double
!!!DIR$ ATTRIBUTES ALIGN : $IRP_ALIGN :: vt, ut
if(N_st /= N_states) stop "H_u_0_nstates N_st /= N_states"
N_st_8 = N_states ! align_double(N_st)
if(N_st /= N_states_diag) stop "H_u_0_nstates N_st /= N_states_diag"
N_st_8 = N_states_diag ! align_double(N_st)
ASSERT (Nint > 0)
ASSERT (Nint == N_int)
@ -214,7 +214,7 @@ subroutine H_S2_u_0_nstates(v_0,s_0,u_0,H_jj,S2_jj,n,keys_tmp,Nint,N_st,sze_8)
integer(ZMQ_PTR) :: handler
if(N_st /= N_states .or. sze_8 < N_det) stop "assert fail in H_S2_u_0_nstates"
if(N_st /= N_states_diag .or. sze_8 < N_det) stop "assert fail in H_S2_u_0_nstates"
N_st_8 = N_st !! align_double(N_st)
ASSERT (Nint > 0)

View File

@ -109,6 +109,42 @@ subroutine bielec_integrals_index_reverse(i,j,k,l,i1)
end
BEGIN_PROVIDER [ integer, ao_integrals_cache_min ]
&BEGIN_PROVIDER [ integer, ao_integrals_cache_max ]
implicit none
BEGIN_DOC
! Min and max values of the AOs for which the integrals are in the cache
END_DOC
ao_integrals_cache_min = max(1,ao_num - 63)
ao_integrals_cache_max = ao_num
END_PROVIDER
BEGIN_PROVIDER [ double precision, ao_integrals_cache, (ao_integrals_cache_min:ao_integrals_cache_max,ao_integrals_cache_min:ao_integrals_cache_max,ao_integrals_cache_min:ao_integrals_cache_max,ao_integrals_cache_min:ao_integrals_cache_max) ]
implicit none
BEGIN_DOC
! Cache of AO integrals for fast access
END_DOC
PROVIDE ao_bielec_integrals_in_map
integer :: i,j,k,l
integer(key_kind) :: idx
!$OMP PARALLEL DO PRIVATE (i,j,k,l,idx)
do l=ao_integrals_cache_min,ao_integrals_cache_max
do k=ao_integrals_cache_min,ao_integrals_cache_max
do j=ao_integrals_cache_min,ao_integrals_cache_max
do i=ao_integrals_cache_min,ao_integrals_cache_max
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
!DIR$ FORCEINLINE
call map_get(ao_integrals_map,idx,ao_integrals_cache(i,j,k,l))
enddo
enddo
enddo
enddo
!$OMP END PARALLEL DO
END_PROVIDER
double precision function get_ao_bielec_integral(i,j,k,l,map)
use map_module
@ -127,8 +163,20 @@ double precision function get_ao_bielec_integral(i,j,k,l,map)
else if (ao_bielec_integral_schwartz(i,k)*ao_bielec_integral_schwartz(j,l) < ao_integrals_threshold) then
tmp = 0.d0
else
call bielec_integrals_index(i,j,k,l,idx)
call map_get(map,idx,tmp)
if ( (i >= ao_integrals_cache_min) .and. &
(j >= ao_integrals_cache_min) .and. &
(k >= ao_integrals_cache_min) .and. &
(l >= ao_integrals_cache_min) .and. &
(i <= ao_integrals_cache_max) .and. &
(j <= ao_integrals_cache_max) .and. &
(k <= ao_integrals_cache_max) .and. &
(l <= ao_integrals_cache_max) ) then
tmp = ao_integrals_cache(i,j,k,l)
else
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
call map_get(map,idx,tmp)
endif
endif
get_ao_bielec_integral = tmp
end
@ -155,16 +203,9 @@ subroutine get_ao_bielec_integrals(j,k,l,sze,out_val)
return
endif
double precision :: get_ao_bielec_integral
do i=1,sze
if (ao_overlap_abs(i,k)*ao_overlap_abs(j,l) < thresh ) then
out_val(i) = 0.d0
else if (ao_bielec_integral_schwartz(i,k)*ao_bielec_integral_schwartz(j,l) < thresh) then
out_val(i)=0.d0
else
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,hash)
call map_get(ao_integrals_map, hash, out_val(i))
endif
out_val(i) = get_ao_bielec_integral(i,j,k,l,ao_integrals_map)
enddo
end
@ -276,6 +317,43 @@ subroutine insert_into_mo_integrals_map(n_integrals, &
call map_update(mo_integrals_map, buffer_i, buffer_values, n_integrals, thr)
end
BEGIN_PROVIDER [ integer, mo_integrals_cache_min ]
&BEGIN_PROVIDER [ integer, mo_integrals_cache_max ]
implicit none
BEGIN_DOC
! Min and max values of the MOs for which the integrals are in the cache
END_DOC
mo_integrals_cache_min = max(1,elec_alpha_num - 31)
mo_integrals_cache_max = min(mo_tot_num,elec_alpha_num + 32)
END_PROVIDER
BEGIN_PROVIDER [ double precision, mo_integrals_cache, (mo_integrals_cache_min:mo_integrals_cache_max,mo_integrals_cache_min:mo_integrals_cache_max,mo_integrals_cache_min:mo_integrals_cache_max,mo_integrals_cache_min:mo_integrals_cache_max) ]
implicit none
BEGIN_DOC
! Cache of MO integrals for fast access
END_DOC
PROVIDE mo_bielec_integrals_in_map
integer :: i,j,k,l
integer(key_kind) :: idx
FREE ao_integrals_cache
!$OMP PARALLEL DO PRIVATE (i,j,k,l,idx)
do l=mo_integrals_cache_min,mo_integrals_cache_max
do k=mo_integrals_cache_min,mo_integrals_cache_max
do j=mo_integrals_cache_min,mo_integrals_cache_max
do i=mo_integrals_cache_min,mo_integrals_cache_max
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
!DIR$ FORCEINLINE
call map_get(mo_integrals_map,idx,mo_integrals_cache(i,j,k,l))
enddo
enddo
enddo
enddo
!$OMP END PARALLEL DO
END_PROVIDER
double precision function get_mo_bielec_integral(i,j,k,l,map)
use map_module
implicit none
@ -287,11 +365,22 @@ double precision function get_mo_bielec_integral(i,j,k,l,map)
type(map_type), intent(inout) :: map
real(integral_kind) :: tmp
PROVIDE mo_bielec_integrals_in_map
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
!DIR$ FORCEINLINE
call map_get(map,idx,tmp)
get_mo_bielec_integral = dble(tmp)
if ( (i >= mo_integrals_cache_min) .and. &
(j >= mo_integrals_cache_min) .and. &
(k >= mo_integrals_cache_min) .and. &
(l >= mo_integrals_cache_min) .and. &
(i <= mo_integrals_cache_max) .and. &
(j <= mo_integrals_cache_max) .and. &
(k <= mo_integrals_cache_max) .and. &
(l <= mo_integrals_cache_max) ) then
get_mo_bielec_integral = mo_integrals_cache(i,j,k,l)
else
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
!DIR$ FORCEINLINE
call map_get(map,idx,tmp)
get_mo_bielec_integral = dble(tmp)
endif
end
double precision function get_mo_bielec_integral_schwartz(i,j,k,l,map)
@ -306,14 +395,12 @@ double precision function get_mo_bielec_integral_schwartz(i,j,k,l,map)
real(integral_kind) :: tmp
PROVIDE mo_bielec_integrals_in_map
if (mo_bielec_integral_schwartz(i,k)*mo_bielec_integral_schwartz(j,l) > mo_integrals_threshold) then
double precision, external :: get_mo_bielec_integral
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,idx)
!DIR$ FORCEINLINE
call map_get(map,idx,tmp)
get_mo_bielec_integral_schwartz = get_mo_bielec_integral(i,j,k,l,map)
else
tmp = 0.d0
endif
get_mo_bielec_integral_schwartz = dble(tmp)
end
@ -325,6 +412,7 @@ double precision function mo_bielec_integral(i,j,k,l)
integer, intent(in) :: i,j,k,l
double precision :: get_mo_bielec_integral
PROVIDE mo_bielec_integrals_in_map
!DIR$ FORCEINLINE
mo_bielec_integral = get_mo_bielec_integral(i,j,k,l,mo_integrals_map)
return
end