10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-06-19 19:52:20 +02:00

Don't recompute 1st Davidson iteration

This commit is contained in:
Anthony Scemama 2019-11-18 13:21:51 +01:00
parent 328672f6be
commit 0f8ea82d68
7 changed files with 104 additions and 36 deletions

View File

@ -1132,7 +1132,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
call get_mo_two_e_integrals(hfix,pfix,p1,mo_num,hij_cache(1,1),mo_integrals_map) call get_mo_two_e_integrals(hfix,pfix,p1,mo_num,hij_cache(1,1),mo_integrals_map)
call get_mo_two_e_integrals(hfix,pfix,p2,mo_num,hij_cache(1,2),mo_integrals_map) call get_mo_two_e_integrals(hfix,pfix,p2,mo_num,hij_cache(1,2),mo_integrals_map)
putj = p1 putj = p1
do puti=1,mo_num do puti=1,mo_num !HOT
if(lbanned(puti,mi)) cycle if(lbanned(puti,mi)) cycle
!p1 fixed !p1 fixed
putj = p1 putj = p1

View File

@ -6,7 +6,7 @@ default: 1.e-10
[n_states_diag] [n_states_diag]
type: States_number type: States_number
doc: Number of states to consider during the Davdison diagonalization doc: Controls the number of states to consider during the Davdison diagonalization. The number of states is n_states * n_states_diag
default: 4 default: 4
interface: ezfio,ocaml interface: ezfio,ocaml

View File

@ -428,7 +428,7 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
integer :: istep, imin, imax, ishift, ipos integer :: istep, imin, imax, ishift, ipos
integer, external :: add_task_to_taskserver integer, external :: add_task_to_taskserver
integer, parameter :: tasksize=40000 integer, parameter :: tasksize=10000
character*(100000) :: task character*(100000) :: task
istep=1 istep=1
ishift=0 ishift=0

View File

@ -161,7 +161,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
stop -1 stop -1
endif endif
itermax = max(3,min(davidson_sze_max, sze/N_st_diag)) itermax = max(2,min(davidson_sze_max, sze/N_st_diag))+1
itertot = 0 itertot = 0
if (state_following) then if (state_following) then
@ -219,7 +219,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit exit
endif endif
if (itermax > 4) then if (itermax > 3) then
itermax = itermax - 1 itermax = itermax - 1
else if (m==1.and.disk_based_davidson) then else if (m==1.and.disk_based_davidson) then
m=0 m=0
@ -322,6 +322,12 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
call normalize(u_in(1,k),sze) call normalize(u_in(1,k),sze)
enddo enddo
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
do while (.not.converged) do while (.not.converged)
itertot = itertot+1 itertot = itertot+1
@ -329,30 +335,33 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit exit
endif endif
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
do iter=1,itermax-1 do iter=1,itermax-1
shift = N_st_diag*(iter-1) shift = N_st_diag*(iter-1)
shift2 = N_st_diag*iter shift2 = N_st_diag*iter
call ortho_qr(U,size(U,1),sze,shift2) if ((iter > 1).or.(itertot == 1)) then
call ortho_qr(U,size(U,1),sze,shift2) ! Compute |W_k> = \sum_i |i><i|H|u_k>
! -----------------------------------
! Compute |W_k> = \sum_i |i><i|H|u_k> if (disk_based) then
! ----------------------------------------- call ortho_qr_unblocked(U,size(U,1),sze,shift2)
call ortho_qr_unblocked(U,size(U,1),sze,shift2)
else
call ortho_qr(U,size(U,1),sze,shift2)
call ortho_qr(U,size(U,1),sze,shift2)
endif
if ((sze > 100000).and.distributed_davidson) then
if ((sze > 100000).and.distributed_davidson) then call H_S2_u_0_nstates_zmq (W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
call H_S2_u_0_nstates_zmq (W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze) else
call H_S2_u_0_nstates_openmp(W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze)
endif
S(1:sze,shift+1:shift+N_st_diag) = real(S_d(1:sze,1:N_st_diag))
else else
call H_S2_u_0_nstates_openmp(W(1,shift+1),S_d,U(1,shift+1),N_st_diag,sze) ! Already computed in update below
continue
endif endif
S(1:sze,shift+1:shift+N_st_diag) = real(S_d(1:sze,1:N_st_diag))
if (dressing_state > 0) then if (dressing_state > 0) then
@ -579,7 +588,12 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo enddo
write(*,'(1X,I3,1X,100(1X,F16.10,1X,F11.6,1X,E11.3))') iter, to_print(1:3,1:N_st) if ((itertot>1).and.(iter == 1)) then
!don't print
continue
else
write(*,'(1X,I3,1X,100(1X,F16.10,1X,F11.6,1X,E11.3))') iter-1, to_print(1:3,1:N_st)
endif
call davidson_converged(lambda,residual_norm,wall,iter,cpu,N_st,converged) call davidson_converged(lambda,residual_norm,wall,iter,cpu,N_st,converged)
do k=1,N_st do k=1,N_st
if (residual_norm(k) > 1.e8) then if (residual_norm(k) > 1.e8) then
@ -600,11 +614,56 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo enddo
! Re-contract to u_in ! Re-contract U and update S and W
! ----------- ! --------------------------------
call sgemm('N','N', sze, N_st_diag, shift2, 1., &
S, size(S,1), y_s, size(y_s,1), 0., S(1,shift2+1), size(S,1))
do k=1,N_st_diag
do i=1,sze
S(i,k) = S(i,shift2+k)
enddo
enddo
call dgemm('N','N', sze, N_st_diag, shift2, 1.d0, &
W, size(W,1), y, size(y,1), 0.d0, u_in, size(u_in,1))
do k=1,N_st_diag
do i=1,sze
W(i,k) = u_in(i,k)
enddo
enddo
call dgemm('N','N', sze, N_st_diag, shift2, 1.d0, & call dgemm('N','N', sze, N_st_diag, shift2, 1.d0, &
U, size(U,1), y, size(y,1), 0.d0, u_in, size(u_in,1)) U, size(U,1), y, size(y,1), 0.d0, u_in, size(u_in,1))
do k=1,N_st_diag
do i=1,sze
U(i,k) = u_in(i,k)
enddo
enddo
if (disk_based) then
call ortho_qr_unblocked(U,size(U,1),sze,N_st_diag)
call ortho_qr_unblocked(U,size(U,1),sze,N_st_diag)
else
call ortho_qr(U,size(U,1),sze,N_st_diag)
call ortho_qr(U,size(U,1),sze,N_st_diag)
endif
do j=1,N_st_diag
k=1
do while ((k<sze).and.(U(k,j) == 0.d0))
k = k+1
enddo
if (U(k,j) * u_in(k,j) < 0.d0) then
do i=1,sze
W(i,j) = -W(i,j)
S(i,j) = -S(i,j)
enddo
endif
enddo
do j=1,N_st_diag
do i=1,sze
S_d(i,j) = dble(S(i,j))
enddo
enddo
enddo enddo
@ -626,7 +685,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 8, fd_w, ptr_w ) call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 8, fd_w, ptr_w )
fd_w = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_w','r') fd_w = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_w','r')
close(fd_w,status='delete') close(fd_w,status='delete')
call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 8, fd_s, ptr_s ) call munmap( (/int(sze,8),int(N_st_diag*itermax,8)/), 4, fd_s, ptr_s )
fd_s = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_s','r') fd_s = getUnitAndOpen(trim(ezfio_work_dir)//'davidson_s','r')
close(fd_s,status='delete') close(fd_s,status='delete')
else else

View File

@ -15,7 +15,7 @@ BEGIN_PROVIDER [ integer, n_states_diag ]
print *, 'davidson/n_states_diag not found in EZFIO file' print *, 'davidson/n_states_diag not found in EZFIO file'
stop 1 stop 1
endif endif
n_states_diag = max(N_states, N_states_diag) n_states_diag = max(2,N_states * N_states_diag)
endif endif
IRP_IF MPI_DEBUG IRP_IF MPI_DEBUG
print *, irp_here, mpi_rank print *, irp_here, mpi_rank

View File

@ -138,18 +138,27 @@ subroutine ortho_qr(A,LDA,m,n)
double precision, intent(inout) :: A(LDA,n) double precision, intent(inout) :: A(LDA,n)
integer :: lwork, info integer :: lwork, info
integer, allocatable :: jpvt(:) double precision, allocatable :: TAU(:), WORK(:)
double precision, allocatable :: tau(:), work(:)
allocate (TAU(n), WORK(1))
allocate (jpvt(n), tau(n), work(1))
LWORK=-1 LWORK=-1
call dgeqrf( m, n, A, LDA, TAU, WORK, LWORK, INFO ) call dgeqrf( m, n, A, LDA, TAU, WORK, LWORK, INFO )
LWORK=2*int(WORK(1)) LWORK=int(WORK(1))
deallocate(WORK) deallocate(WORK)
allocate(WORK(LWORK)) allocate(WORK(LWORK))
call dgeqrf(m, n, A, LDA, TAU, WORK, LWORK, INFO ) call dgeqrf(m, n, A, LDA, TAU, WORK, LWORK, INFO )
call dorgqr(m, n, n, A, LDA, tau, WORK, LWORK, INFO)
deallocate(WORK,jpvt,tau) LWORK=-1
call dorgqr(m, n, n, A, LDA, TAU, WORK, LWORK, INFO)
LWORK=int(WORK(1))
deallocate(WORK)
allocate(WORK(LWORK))
call dorgqr(m, n, n, A, LDA, TAU, WORK, LWORK, INFO)
deallocate(WORK,TAU)
end end
subroutine ortho_qr_unblocked(A,LDA,m,n) subroutine ortho_qr_unblocked(A,LDA,m,n)
@ -170,13 +179,12 @@ subroutine ortho_qr_unblocked(A,LDA,m,n)
double precision, intent(inout) :: A(LDA,n) double precision, intent(inout) :: A(LDA,n)
integer :: info integer :: info
integer, allocatable :: jpvt(:) double precision, allocatable :: TAU(:), WORK(:)
double precision, allocatable :: tau(:), work(:)
allocate (jpvt(n), tau(n), work(n)) allocate (TAU(n), WORK(n))
call dgeqr2( m, n, A, LDA, TAU, WORK, INFO ) call dgeqr2( m, n, A, LDA, TAU, WORK, INFO )
call dorg2r(m, n, n, A, LDA, tau, WORK, INFO) call dorg2r(m, n, n, A, LDA, TAU, WORK, INFO)
deallocate(WORK,jpvt,tau) deallocate(WORK,TAU)
end end
subroutine ortho_lowdin(overlap,LDA,N,C,LDC,m) subroutine ortho_lowdin(overlap,LDA,N,C,LDC,m)

View File

@ -99,6 +99,7 @@ subroutine check_mem(rss_in,routine)
rss += rss_in rss += rss_in
if (int(rss)+1 > qp_max_mem) then if (int(rss)+1 > qp_max_mem) then
print *, 'Not enough memory: aborting in ', routine print *, 'Not enough memory: aborting in ', routine
print *, int(rss)+1, ' GB required'
stop -1 stop -1
endif endif
!$OMP END CRITICAL !$OMP END CRITICAL