10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-06-19 19:52:20 +02:00

Don't recompute 1st Davidson iteration

This commit is contained in:
Anthony Scemama 2019-11-18 13:21:51 +01:00
parent 328672f6be
commit 0f8ea82d68
7 changed files with 104 additions and 36 deletions

View File

@ -1132,7 +1132,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
call get_mo_two_e_integrals(hfix,pfix,p1,mo_num,hij_cache(1,1),mo_integrals_map)
call get_mo_two_e_integrals(hfix,pfix,p2,mo_num,hij_cache(1,2),mo_integrals_map)
putj = p1
do puti=1,mo_num
do puti=1,mo_num !HOT
if(lbanned(puti,mi)) cycle
!p1 fixed
putj = p1

View File

@ -6,7 +6,7 @@ default: 1.e-10
[n_states_diag]
type: States_number
doc: Number of states to consider during the Davidson diagonalization
doc: Controls the number of states to consider during the Davidson diagonalization. The number of states is n_states * n_states_diag
default: 4
interface: ezfio,ocaml

View File

@ -428,7 +428,7 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
integer :: istep, imin, imax, ishift, ipos
integer, external :: add_task_to_taskserver
integer, parameter :: tasksize=40000
integer, parameter :: tasksize=10000
character*(100000) :: task
istep=1
ishift=0

View File

@ -161,7 +161,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
stop -1
endif
itermax = max(3,min(davidson_sze_max, sze/N_st_diag))
itermax = max(2,min(davidson_sze_max, sze/N_st_diag))+1
itertot = 0
if (state_following) then
@ -219,7 +219,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit
endif
if (itermax > 4) then
if (itermax > 3) then
itermax = itermax - 1
else if (m==1.and.disk_based_davidson) then
m=0
@ -322,6 +322,12 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
call normalize(u_in(1,k),sze)
enddo
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
do while (.not.converged)
itertot = itertot+1
@ -329,30 +335,33 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit
endif
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
do iter=1,itermax-1
shift = N_st_diag*(iter-1)
shift2 = N_st_diag*iter
call ortho_qr(U,size(U,1),sze,shift2)
call ortho_qr(U,size(U,1),sze,shift2)
if ((iter > 1).or.(itertot == 1)) then
! Compute |W_k> = \sum_i |i><i|H|u_k>
! -----------------------------------
! Compute |W_k> = \sum_i |i><i|H|u_k>
! -----------------------------------------
if (disk_based) then
call ortho_qr_unblocked(U,size(U,1),sze,shift2)
call ortho_qr_unblocked(U,size(U,1),sze,shift2)
else
call ortho_qr(U,size(U,1),sze,shift2)
call ortho_qr(U,size(U,1),sze,shift2)
endif
if ((sze > 100000).and.distributed_davidson) then
call H_S2_u_0_nstates_zmq (W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
if ((sze > 100000).and.distributed_davidson) then
call H_S2_u_0_nstates_zmq (W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
else
call H_S2_u_0_nstates_openmp(W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
endif
S(1:sze,shift+1:shift+N_st_diag) = real(S_d(1:sze,1:N_st_diag))
else
call H_S2_u_0_nstates_openmp(W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
! Already computed in update below
continue
endif
S(1:sze,shift+1:shift+N_st_diag) = real(S_d(1:sze,1:N_st_diag))
if (dressing_state > 0) then
@ -579,7 +588,12 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo
write(*,'(1X,I3,1X,100(1X,F16.10,1X,F11.6,1X,E11.3))') iter, to_print(1:3,1:N_st)
if ((itertot>1).and.(iter == 1)) then
!don't print
continue
else
write(*,'(1X,I3,1X,100(1X,F16.10,1X,F11.6,1X,E11.3))') iter-1, to_print(1:3,1:N_st)
endif
call davidson_converged(lambda,residual_norm,wall,iter,cpu,N_st,converged)
do k=1,N_st
if (residual_norm(k) > 1.e8) then
@ -600,11 +614,56 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo
! Re-contract to u_in
! -----------
! Re-contract U and update S and W
! --------------------------------
call sgemm('N','N', sze, N_st_diag, shift2, 1., &
S, size(S,1), y_s, size(y_s,1), 0., S(1,shift2+1), size(S,1))
do k=1,N_st_diag
do i=1,sze
S(i,k) = S(i,shift2+k)
enddo
enddo
call dgemm('N','N', sze, N_st_diag, shift2, 1.d0, &
W, size(W,1), y, size(y,1), 0.d0, u_in, size(u_in,1))
do k=1,N_st_diag
do i=1,sze
W(i,k) = u_in(i,k)
enddo
enddo
call dgemm('N','N', sze, N_st_diag, shift2, 1.d0, &
U, size(U,1), y, size(y,1), 0.d0, u_in, size(u_in,1))
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
if (disk_based) then
call ortho_qr_unblocked(U,size(U,1),sze,N_st_diag)
call ortho_qr_unblocked(U,size(U,1),sze,N_st_diag)
else
call ortho_qr(U,size(U,1),sze,N_st_diag)
call ortho_qr(U,size(U,1),sze,N_st_diag)
endif
do j=1,N_st_diag
k=1
do while ((k<sze).and.(U(k,j) == 0.d0))
k = k+1
enddo
if (U(k,j) * u_in(k,j) < 0.d0) then
do i=1,sze
W(i,j) = -W(i,j)
S(i,j) = -S(i,j)
enddo
endif
enddo
do j=1,N_st_diag
do i=1,sze
S_d(i,j) = dble(S(i,j))
enddo
enddo
enddo
@ -626,7 +685,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 8, fd_w, ptr_w )
fd_w = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_w','r')
close(fd_w,status='delete')
call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 8, fd_s, ptr_s )
call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 4, fd_s, ptr_s )
fd_s = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_s','r')
close(fd_s,status='delete')
else

View File

@ -15,7 +15,7 @@ BEGIN_PROVIDER [ integer, n_states_diag ]
print *, 'davidson/n_states_diag not found in EZFIO file'
stop 1
endif
n_states_diag = max(N_states, N_states_diag)
n_states_diag = max(2,N_states * N_states_diag)
endif
IRP_IF MPI_DEBUG
print *, irp_here, mpi_rank

View File

@ -138,18 +138,27 @@ subroutine ortho_qr(A,LDA,m,n)
double precision, intent(inout) :: A(LDA,n)
integer :: lwork, info
integer, allocatable :: jpvt(:)
double precision, allocatable :: tau(:), work(:)
double precision, allocatable :: TAU(:), WORK(:)
allocate (TAU(n), WORK(1))
allocate (jpvt(n), tau(n), work(1))
LWORK=-1
call dgeqrf( m, n, A, LDA, TAU, WORK, LWORK, INFO )
LWORK=2*int(WORK(1))
LWORK=int(WORK(1))
deallocate(WORK)
allocate(WORK(LWORK))
call dgeqrf(m, n, A, LDA, TAU, WORK, LWORK, INFO )
call dorgqr(m, n, n, A, LDA, tau, WORK, LWORK, INFO)
deallocate(WORK,jpvt,tau)
LWORK=-1
call dorgqr(m, n, n, A, LDA, TAU, WORK, LWORK, INFO)
LWORK=int(WORK(1))
deallocate(WORK)
allocate(WORK(LWORK))
call dorgqr(m, n, n, A, LDA, TAU, WORK, LWORK, INFO)
deallocate(WORK,TAU)
end
subroutine ortho_qr_unblocked(A,LDA,m,n)
@ -170,13 +179,12 @@ subroutine ortho_qr_unblocked(A,LDA,m,n)
double precision, intent(inout) :: A(LDA,n)
integer :: info
integer, allocatable :: jpvt(:)
double precision, allocatable :: tau(:), work(:)
double precision, allocatable :: TAU(:), WORK(:)
allocate (jpvt(n), tau(n), work(n))
allocate (TAU(n), WORK(n))
call dgeqr2( m, n, A, LDA, TAU, WORK, INFO )
call dorg2r(m, n, n, A, LDA, tau, WORK, INFO)
deallocate(WORK,jpvt,tau)
call dorg2r(m, n, n, A, LDA, TAU, WORK, INFO)
deallocate(WORK,TAU)
end
subroutine ortho_lowdin(overlap,LDA,N,C,LDC,m)

View File

@ -99,6 +99,7 @@ subroutine check_mem(rss_in,routine)
rss += rss_in
if (int(rss)+1 > qp_max_mem) then
print *, 'Not enough memory: aborting in ', routine
print *, int(rss)+1, ' GB required'
stop -1
endif
!$OMP END CRITICAL