10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-07-03 18:05:59 +02:00

Accelerated integrals

This commit is contained in:
Anthony Scemama 2018-06-07 16:15:26 +02:00
parent c46635c5ed
commit 261a8f6dfd
7 changed files with 123 additions and 56 deletions

View File

@ -52,6 +52,11 @@ program fci_zmq
double precision :: error(N_states)
correlation_energy_ratio = 0.d0
if (do_pt2) then
pt2_string = ' '
else
pt2_string = '(approx)'
endif
if (.True.) then ! Avoid pre-calculation of CI_energy
do while ( &
@ -63,7 +68,6 @@ program fci_zmq
if (do_pt2) then
pt2_string = ' '
pt2 = 0.d0
threshold_selectors = 1.d0
threshold_generators = 1d0
@ -72,8 +76,6 @@ program fci_zmq
threshold_selectors = threshold_selectors_save
threshold_generators = threshold_generators_save
SOFT_TOUCH threshold_selectors threshold_generators
else
pt2_string = '(approx)'
endif

View File

@ -139,6 +139,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, absolute_error, error)
endif
call omp_set_nested(.true.)
!$OMP PARALLEL DEFAULT(shared) NUM_THREADS(nproc+1) &
!$OMP PRIVATE(i)
i = omp_get_thread_num()
@ -149,6 +150,7 @@ subroutine ZMQ_pt2(E, pt2,relative_error, absolute_error, error)
call pt2_slave_inproc(i)
endif
!$OMP END PARALLEL
call omp_set_nested(.false.)
call end_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'pt2')
call delete_selection_buffer(b)

View File

@ -82,8 +82,8 @@ subroutine run_pt2_slave(thread,iproc,energy)
endif
call push_pt2_results(zmq_socket_push, i_generator, pt2, task_id, n_tasks)
! Try to adjust n_tasks around 1 second per job
n_tasks = min(n_tasks,int( dble(n_tasks) / (time1 - time0 + 1.d-9)))+1
! Try to adjust n_tasks around 5 second per job
n_tasks = min(n_tasks,int( 5.d0*dble(n_tasks) / (time1 - time0 + 1.d-9)))+1
end do
integer, external :: disconnect_from_taskserver

View File

@ -97,8 +97,8 @@ subroutine run_selection_slave_new(thread,iproc,energy)
pt2(:,:) = 0d0
buf%cur = 0
! Try to adjust n_tasks around 1 second per job
n_tasks = min(n_tasks,int( dble(n_tasks) / (time1 - time0 + 1.d-9)))+1
! Try to adjust n_tasks around 5 second per job
n_tasks = min(n_tasks,int( 5.d0 * dble(n_tasks) / (time1 - time0 + 1.d-9)))+1
end do
integer, external :: disconnect_from_taskserver

View File

@ -191,17 +191,26 @@ subroutine get_m1(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs)
p2 = p(2, sp)
lbanned(p2) = .true.
double precision :: hij_cache(mo_tot_num,2)
call get_mo_bielec_integrals(hole,p1,p2,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(hole,p2,p1,mo_tot_num,hij_cache(1,2),mo_integrals_map)
do i=1,hole-1
if(lbanned(i)) cycle
hij = (mo_bielec_integral(i, hole, p1, p2) - mo_bielec_integral(i, hole, p2, p1))
hij *= get_phase_bi(phasemask, sp, sp, i, p1, hole, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
hij = hij_cache(i,1)-hij_cache(i,2)
if (hij /= 0.d0) then
hij *= get_phase_bi(phasemask, sp, sp, i, p1, hole, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
endif
end do
do i=hole+1,mo_tot_num
if(lbanned(i)) cycle
hij = (mo_bielec_integral(hole, i, p1, p2) - mo_bielec_integral(hole, i, p2, p1))
hij *= get_phase_bi(phasemask, sp, sp, hole, p1, i, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
hij = hij_cache(i,2)-hij_cache(i,1)
if (hij /= 0.d0) then
hij *= get_phase_bi(phasemask, sp, sp, hole, p1, i, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
endif
end do
call apply_particle(mask, sp, p2, det, ok, N_int)
@ -209,11 +218,14 @@ subroutine get_m1(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs)
vect(1:N_states, p2) += hij * coefs(1:N_states)
else
p2 = p(1, sh)
call get_mo_bielec_integrals(hole,p1,p2,mo_tot_num,hij_cache(1,1),mo_integrals_map)
do i=1,mo_tot_num
if(lbanned(i)) cycle
hij = mo_bielec_integral(i, hole, p1, p2)
hij *= get_phase_bi(phasemask, sp, sh, i, p1, hole, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
hij = hij_cache(i,1)
if (hij /= 0.d0) then
hij *= get_phase_bi(phasemask, sp, sh, i, p1, hole, p2)
vect(1:N_states,i) += hij * coefs(1:N_states)
endif
end do
end if
deallocate(lbanned)
@ -707,17 +719,17 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, bannedOrb, banned, mat, intere
call bitstring_to_list_in_selection(mobMask(1,1), p(1,1), p(0,1), N_int)
call bitstring_to_list_in_selection(mobMask(1,2), p(1,2), p(0,2), N_int)
perMask(1,1) = iand(mask(1,1), not(det(1,1,i)))
perMask(1,2) = iand(mask(1,2), not(det(1,2,i)))
do j=2,N_int
perMask(j,1) = iand(mask(j,1), not(det(j,1,i)))
perMask(j,2) = iand(mask(j,2), not(det(j,2,i)))
end do
call bitstring_to_list_in_selection(perMask(1,1), h(1,1), h(0,1), N_int)
call bitstring_to_list_in_selection(perMask(1,2), h(1,2), h(0,2), N_int)
if (interesting(i) >= i_gen) then
perMask(1,1) = iand(mask(1,1), not(det(1,1,i)))
perMask(1,2) = iand(mask(1,2), not(det(1,2,i)))
do j=2,N_int
perMask(j,1) = iand(mask(j,1), not(det(j,1,i)))
perMask(j,2) = iand(mask(j,2), not(det(j,2,i)))
end do
call bitstring_to_list_in_selection(perMask(1,1), h(1,1), h(0,1), N_int)
call bitstring_to_list_in_selection(perMask(1,2), h(1,2), h(0,2), N_int)
call get_mask_phase(psi_det_sorted(1,1,interesting(i)), phasemask)
if(nt == 4) then
call get_d2(det(1,1,i), phasemask, bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i)))
@ -731,6 +743,7 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, bannedOrb, banned, mat, intere
if(nt == 3) call past_d1(bannedOrb, p)
end if
end do
end
@ -881,9 +894,11 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
integer, parameter :: turn3(2,3) = reshape((/2,3, 1,3, 1,2/), (/2,3/))
integer :: bant
double precision, allocatable :: hij_cache(:,:)
PROVIDE mo_integrals_map
allocate (lbanned(mo_tot_num, 2))
allocate (hij_cache(mo_tot_num,2))
lbanned = bannedOrb
do i=1, p(0,1)
@ -907,11 +922,13 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p1 = p(1,ma)
p2 = p(2,ma)
if(.not. bannedOrb(puti, mi)) then
call get_mo_bielec_integrals(hfix,p1,p2,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(hfix,p2,p1,mo_tot_num,hij_cache(1,2),mo_integrals_map)
tmp_row = 0d0
do putj=1, hfix-1
if(lbanned(putj, ma)) cycle
if(banned(putj, puti,bant)) cycle
hij = mo_bielec_integral(putj, hfix, p1, p2)-mo_bielec_integral(putj,hfix,p2,p1)
hij = hij_cache(putj,1) - hij_cache(putj,2)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, putj, p1, hfix, p2)
tmp_row(1:N_states,putj) += hij * coefs(1:N_states)
@ -920,7 +937,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
do putj=hfix+1, mo_tot_num
if(lbanned(putj, ma)) cycle
if(banned(putj, puti,bant)) cycle
hij = mo_bielec_integral(putj,hfix,p2, p1)-mo_bielec_integral(putj,hfix,p1,p2)
hij = hij_cache(putj,2) - hij_cache(putj,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, hfix, p1, putj, p2)
tmp_row(1:N_states,putj) += hij * coefs(1:N_states)
@ -938,12 +955,14 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
pfix = p(1,mi)
tmp_row = 0d0
tmp_row2 = 0d0
call get_mo_bielec_integrals(hfix,pfix,p1,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(hfix,pfix,p2,mo_tot_num,hij_cache(1,2),mo_integrals_map)
do puti=1,mo_tot_num
if(lbanned(puti,mi)) cycle
!p1 fixed
putj = p1
if(.not. banned(putj,puti,bant)) then
hij = mo_bielec_integral(puti,hfix,pfix,p2)
hij = hij_cache(puti,2)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, mi, hfix, p2, puti, pfix)
tmp_row(:,puti) += hij * coefs(:)
@ -952,7 +971,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
putj = p2
if(.not. banned(putj,puti,bant)) then
hij = mo_bielec_integral(puti,hfix,pfix,p1)
hij = hij_cache(puti,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, mi, hfix, p1, puti, pfix)
tmp_row2(:,puti) += hij * coefs(:)
@ -976,12 +995,13 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
puti = p(i, ma)
p1 = p(turn3(1,i), ma)
p2 = p(turn3(2,i), ma)
call get_mo_bielec_integrals(hfix,p1,p2,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(hfix,p2,p1,mo_tot_num,hij_cache(1,2),mo_integrals_map)
tmp_row = 0d0
do putj=1,hfix-1
if(lbanned(putj,ma)) cycle
! if(banned(puti,putj,1)) cycle
if(banned(putj,puti,1)) cycle
hij = mo_bielec_integral(putj, hfix, p1, p2)-mo_bielec_integral(putj,hfix,p2,p1)
hij = hij_cache(putj,1) - hij_cache(putj,2)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, putj, p1, hfix, p2)
tmp_row(:,putj) += hij * coefs(:)
@ -989,9 +1009,8 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
end do
do putj=hfix+1,mo_tot_num
if(lbanned(putj,ma)) cycle
! if(banned(puti,putj,1)) cycle
if(banned(putj,puti,1)) cycle
hij = mo_bielec_integral(putj, hfix, p2, p1)-mo_bielec_integral(putj,hfix,p1,p2)
hij = hij_cache(putj,2) - hij_cache(putj,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, hfix, p1, putj, p2)
tmp_row(:,putj) += hij * coefs(:)
@ -1008,11 +1027,13 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p2 = p(2,ma)
tmp_row = 0d0
tmp_row2 = 0d0
call get_mo_bielec_integrals(hfix,p1,pfix,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(hfix,p2,pfix,mo_tot_num,hij_cache(1,2),mo_integrals_map)
do puti=1,mo_tot_num
if(lbanned(puti,ma)) cycle
putj = p2
if(.not. banned(puti,putj,1)) then
hij = mo_bielec_integral(puti,hfix, p1,pfix)
hij = hij_cache(puti,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, mi, ma, hfix, pfix, puti, p1)
tmp_row(:,puti) += hij * coefs(:)
@ -1021,7 +1042,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
putj = p1
if(.not. banned(puti,putj,1)) then
hij = mo_bielec_integral(puti, hfix, p2, pfix)
hij = hij_cache(puti,2)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, mi, ma, hfix, pfix, puti, p2)
tmp_row2(:,puti) += hij * coefs(:)
@ -1034,7 +1055,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
mat(:,p1,p1:) += tmp_row2(:,p1:)
end if
end if
deallocate(lbanned)
deallocate(lbanned,hij_cache)
!! MONO
if(sp == 3) then
@ -1081,6 +1102,8 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
logical :: ok
integer :: bant
double precision, allocatable :: hij_cache(:,:)
allocate (hij_cache(mo_tot_num,2))
bant = 1
@ -1089,6 +1112,7 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
h2 = p(1,2)
do p1=1, mo_tot_num
if(bannedOrb(p1, 1)) cycle
call get_mo_bielec_integrals(p1,h2,h1,mo_tot_num,hij_cache(1,1),mo_integrals_map)
do p2=1, mo_tot_num
if(bannedOrb(p2,2)) cycle
if(banned(p1, p2, bant)) cycle ! rentable?
@ -1097,7 +1121,7 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
call i_h_j(gen, det, N_int, hij)
else
phase = get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2)
hij = mo_bielec_integral(p2, p1, h2, h1) * phase
hij = hij_cache(p2,1) * phase
end if
mat(:, p1, p2) += coefs(:) * hij
end do
@ -1107,19 +1131,28 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p2 = p(2,sp)
do puti=1, mo_tot_num
if(bannedOrb(puti, sp)) cycle
call get_mo_bielec_integrals(puti,p2,p1,mo_tot_num,hij_cache(1,1),mo_integrals_map)
call get_mo_bielec_integrals(puti,p1,p2,mo_tot_num,hij_cache(1,2),mo_integrals_map)
do putj=puti+1, mo_tot_num
if(bannedOrb(putj, sp)) cycle
if(banned(puti, putj, bant)) cycle ! rentable?
if(puti == p1 .or. putj == p2 .or. puti == p2 .or. putj == p1) then
call apply_particles(mask, sp,puti,sp,putj, det, ok, N_int)
call i_h_j(gen, det, N_int, hij)
if (hij /= 0.d0) then
mat(:, puti, putj) += coefs(:) * hij
endif
else
hij = (mo_bielec_integral(putj, puti, p2, p1) - mo_bielec_integral(putj, puti, p1, p2))* get_phase_bi(phasemask, sp, sp, puti, p1 , putj, p2)
hij = hij_cache(putj,1) - hij_cache(putj,2)
if (hij /= 0.d0) then
hij *= get_phase_bi(phasemask, sp, sp, puti, p1 , putj, p2)
mat(:, puti, putj) += coefs(:) * hij
endif
end if
mat(:, puti, putj) += coefs(:) * hij
end do
end do
end if
deallocate(hij_cache)
end

View File

@ -30,6 +30,7 @@ subroutine run_wf
integer(ZMQ_PTR) :: zmq_to_qp_run_socket
double precision :: energy(N_states)
character*(64) :: states(3)
character*(64) :: old_state
integer :: rc, i, ierr
double precision :: t0, t1
@ -44,14 +45,21 @@ subroutine run_wf
states(1) = 'selection'
states(2) = 'davidson'
states(3) = 'pt2'
old_state = 'Waiting'
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
do
call wait_for_states(states,zmq_state,size(states))
if (zmq_state == old_state) then
cycle
else
old_state = zmq_state
endif
print *, trim(zmq_state)
if(zmq_state(1:7) == 'Stopped') then
exit

View File

@ -438,24 +438,46 @@ subroutine get_mo_bielec_integrals(j,k,l,sze,out_val,map)
double precision, intent(out) :: out_val(sze)
type(map_type), intent(inout) :: map
integer :: i
integer(key_kind) :: hash(sze)
real(integral_kind) :: tmp_val(sze)
PROVIDE mo_bielec_integrals_in_map
double precision, external :: get_mo_bielec_integral
PROVIDE mo_bielec_integrals_in_map mo_integrals_cache
integer :: ii, ii0
integer*8 :: ii_8, ii0_8
real(integral_kind) :: tmp
integer(key_kind) :: i1, idx
integer(key_kind) :: p,q,r,s,i2
PROVIDE mo_bielec_integrals_in_map mo_integrals_cache
ii0 = l-mo_integrals_cache_min
ii0 = ior(ii0, k-mo_integrals_cache_min)
ii0 = ior(ii0, j-mo_integrals_cache_min)
ii0_8 = int(l,8)-mo_integrals_cache_min_8
ii0_8 = ior( ishft(ii0_8,7), int(k,8)-mo_integrals_cache_min_8)
ii0_8 = ior( ishft(ii0_8,7), int(j,8)-mo_integrals_cache_min_8)
q = min(j,l)
s = max(j,l)
q = q+ishft(s*s-s,-1)
do i=1,sze
!DIR$ FORCEINLINE
call bielec_integrals_index(i,j,k,l,hash(i))
ii = ior(ii0, i-mo_integrals_cache_min)
if (iand(ii, -128) == 0) then
ii_8 = ior( ishft(ii0_8,7), int(i,8)-mo_integrals_cache_min_8)
out_val(i) = mo_integrals_cache(ii_8)
else
p = min(i,k)
r = max(i,k)
p = p+ishft(r*r-r,-1)
i1 = min(p,q)
i2 = max(p,q)
idx = i1+ishft(i2*i2-i2,-1)
!DIR$ FORCEINLINE
call map_get(map,idx,tmp)
out_val(i) = dble(tmp)
endif
enddo
if (key_kind == 8) then
call map_get_many(map, hash, out_val, sze)
else
call map_get_many(map, hash, tmp_val, sze)
! Conversion to double precision
do i=1,sze
out_val(i) = dble(tmp_val(i))
enddo
endif
end
subroutine get_mo_bielec_integrals_ij(k,l,sze,out_array,map)