10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-12-25 05:43:47 +01:00

Check for duplicates in parallel

This commit is contained in:
Anthony Scemama 2018-10-29 16:12:37 +01:00
parent 9a70059a11
commit 8274044a7c
5 changed files with 46 additions and 27 deletions

View File

@ -192,8 +192,8 @@ subroutine copy_H_apply_buffer_to_wf
call normalize(psi_coef,N_det)
SOFT_TOUCH N_det psi_det psi_coef
! logical :: found_duplicates
! call remove_duplicates_in_psi_det(found_duplicates)
logical :: found_duplicates
call remove_duplicates_in_psi_det(found_duplicates)
end
subroutine remove_duplicates_in_psi_det(found_duplicates)
@ -205,16 +205,24 @@ subroutine remove_duplicates_in_psi_det(found_duplicates)
integer :: i,j,k
integer(bit_kind), allocatable :: bit_tmp(:)
logical,allocatable :: duplicate(:)
logical :: dup
allocate (duplicate(N_det), bit_tmp(N_det))
found_duplicates = .False.
!$OMP PARALLEL DEFAULT(SHARED) PRIVATE(i,j,k,dup)
!$OMP DO
do i=1,N_det
integer, external :: det_search_key
!$DIR FORCEINLINE
bit_tmp(i) = det_search_key(psi_det_sorted_bit(1,1,i),N_int)
duplicate(i) = .False.
enddo
!$OMP END DO
!$OMP DO
do i=1,N_det-1
if (duplicate(i)) then
cycle
@ -229,28 +237,26 @@ subroutine remove_duplicates_in_psi_det(found_duplicates)
cycle
endif
endif
duplicate(j) = .True.
dup = .True.
do k=1,N_int
if ( (psi_det_sorted_bit(k,1,i) /= psi_det_sorted_bit(k,1,j) ) &
.or. (psi_det_sorted_bit(k,2,i) /= psi_det_sorted_bit(k,2,j) ) ) then
duplicate(j) = .False.
dup = .False.
exit
endif
enddo
if (dup) then
duplicate(j) = .True.
found_duplicates = .True.
endif
j += 1
if (j > N_det) then
exit
endif
enddo
enddo
found_duplicates = .False.
do i=1,N_det
if (duplicate(i)) then
found_duplicates = .True.
exit
endif
enddo
!$OMP END DO
!$OMP END PARALLEL
if (found_duplicates) then
k=0
@ -259,14 +265,16 @@ subroutine remove_duplicates_in_psi_det(found_duplicates)
k += 1
psi_det(:,:,k) = psi_det_sorted_bit (:,:,i)
psi_coef(k,:) = psi_coef_sorted_bit(i,:)
else
call debug_det(psi_det_sorted_bit(1,1,i),N_int)
stop 'duplicates in psi_det'
! else
! call debug_det(psi_det_sorted_bit(1,1,i),N_int)
! stop 'duplicates in psi_det'
endif
enddo
N_det = k
call write_bool(6,found_duplicates,'Found duplicate determinants')
SOFT_TOUCH N_det psi_det psi_coef
psi_det_sorted_bit(:,:,1:N_det) = psi_det(:,:,1:N_det)
psi_coef_sorted_bit(1:N_det,:) = psi_coef(1:N_det,:)
SOFT_TOUCH N_det psi_det psi_coef psi_det_sorted_bit psi_coef_sorted_bit
endif
deallocate (duplicate,bit_tmp)
end

View File

@ -167,17 +167,22 @@ end
integer*8, external :: occ_pattern_search_key
integer(bit_kind), allocatable :: tmp_array(:,:,:)
logical,allocatable :: duplicate(:)
logical :: dup
allocate ( iorder(N_det), duplicate(N_det), bit_tmp(N_det), tmp_array(N_int,2,N_det) )
do i=1,N_det
iorder(i) = i
!$DIR FORCEINLINE
bit_tmp(i) = occ_pattern_search_key(psi_occ_pattern(1,1,i),N_int)
enddo
call i8sort(bit_tmp,iorder,N_det)
!DIR$ IVDEP
!$OMP PARALLEL DEFAULT(shared) PRIVATE(i,j,k,dup)
!$OMP DO
do i=1,N_det
do k=1,N_int
tmp_array(k,1,i) = psi_occ_pattern(k,1,iorder(i))
@ -185,8 +190,10 @@ end
enddo
duplicate(i) = .False.
enddo
!$OMP END DO
! Find duplicates
!$OMP DO
do i=1,N_det-1
if (duplicate(i)) then
cycle
@ -200,20 +207,25 @@ end
endif
cycle
endif
duplicate(j) = .True.
dup = .True.
do k=1,N_int
if ( (tmp_array(k,1,i) /= tmp_array(k,1,j)) &
.or. (tmp_array(k,2,i) /= tmp_array(k,2,j)) ) then
duplicate(j) = .False.
dup = .False.
exit
endif
enddo
if (dup) then
duplicate(j) = .True.
endif
j+=1
if (j>N_det) then
exit
endif
enddo
enddo
!$OMP END DO
!$OMP END PARALLEL
! Copy filtered result
N_occ_pattern=0
@ -229,6 +241,7 @@ end
enddo
!- Check
! print *, 'Checking for duplicates in occ pattern'
! do i=1,N_occ_pattern
! do j=i+1,N_occ_pattern
! duplicate(1) = .True.
@ -249,6 +262,7 @@ end
! endif
! enddo
! enddo
! print *, 'No duplicates'
!-
deallocate(iorder,duplicate,bit_tmp,tmp_array)
@ -354,7 +368,7 @@ subroutine make_s2_eigenfunction
!$OMP END PARALLEL
call copy_H_apply_buffer_to_wf
SOFT_TOUCH N_det psi_coef psi_det
SOFT_TOUCH N_det psi_coef psi_det psi_occ_pattern N_occ_pattern
print *, 'Added determinants for S^2'
call write_time(6)

View File

@ -696,12 +696,10 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
e_pert = 0.5d0 * (tmp - delta_E)
coef = alpha_h_psi / delta_E
pt2(istate) = pt2(istate) + e_pert
variance(istate) = variance(istate) + alpha_h_psi * alpha_h_psi
norm(istate) = norm(istate) + coef * coef
sum_e_pert = sum_e_pert + e_pert * state_average_weight(istate)
variance(istate) = variance(istate) + alpha_h_psi * alpha_h_psi * state_average_weight(istate)
norm(istate) = norm(istate) + coef * coef * state_average_weight(istate)
end do
if(sum_e_pert <= buf%mini) then

View File

@ -133,7 +133,6 @@ subroutine ZMQ_selection(N_in, pt2, variance, norm)
variance(k) = variance(k) * f(k)
norm(k) = norm(k) * f(k)
enddo
! variance = variance - pt2*pt2
end subroutine

View File

@ -56,7 +56,7 @@ subroutine print_summary(e_,pt2_,error_,variance_,norm_)
do k=1, N_states_p
print*,'State ',k
print *, 'Variance = ', variance_(k)
print *, 'PT norm = ', norm_(k)
print *, 'PT norm = ', dsqrt(norm_(k))
print *, 'PT2 = ', pt2_(k)
print *, 'E = ', e_(k)
print *, 'E+PT2'//pt2_string//' = ', e_(k)+pt2_(k), ' +/- ', error_(k)