10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2025-01-09 12:44:05 +01:00

Merge branch 'dev' into features_pt2

This commit is contained in:
Anthony Scemama 2020-11-07 14:35:10 +01:00
commit b0f85476fe
10 changed files with 292 additions and 37 deletions

View File

@ -4,8 +4,8 @@ type units =
| Angstrom | Angstrom
;; ;;
let angstrom_to_bohr = 1. /. 0.52917721092 let angstrom_to_bohr = 1. /. 0.52917721067121
let bohr_to_angstrom = 0.52917721092 let bohr_to_angstrom = 0.52917721067121
;; ;;

View File

@ -436,7 +436,7 @@ BEGIN_PROVIDER [ double precision, ao_two_e_integral_schwartz,(ao_num,ao_num) ]
!$OMP SCHEDULE(dynamic) !$OMP SCHEDULE(dynamic)
do i=1,ao_num do i=1,ao_num
do k=1,i do k=1,i
ao_two_e_integral_schwartz(i,k) = dsqrt(ao_two_e_integral(i,k,i,k)) ao_two_e_integral_schwartz(i,k) = dsqrt(ao_two_e_integral(i,i,k,k))
ao_two_e_integral_schwartz(k,i) = ao_two_e_integral_schwartz(i,k) ao_two_e_integral_schwartz(k,i) = ao_two_e_integral_schwartz(i,k)
enddo enddo
enddo enddo

View File

@ -287,9 +287,9 @@ subroutine ZMQ_pt2(E, pt2_data, pt2_data_err, relative_error, N_in)
call omp_set_nested(.false.) call omp_set_nested(.false.)
print '(A)', '========== ================= =========== =============== =============== =================' print '(A)', '========== ======================= ===================== ===================== ==========='
print '(A)', ' Samples Energy Stat. Err Variance Norm^2 Seconds ' print '(A)', ' Samples Energy Variance Norm^2 Seconds'
print '(A)', '========== ================= =========== =============== =============== =================' print '(A)', '========== ======================= ===================== ===================== ==========='
PROVIDE global_selection_buffer PROVIDE global_selection_buffer
@ -312,26 +312,30 @@ subroutine ZMQ_pt2(E, pt2_data, pt2_data_err, relative_error, N_in)
!$OMP END PARALLEL !$OMP END PARALLEL
call end_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'pt2') call end_parallel_job(zmq_to_qp_run_socket, zmq_socket_pull, 'pt2')
print '(A)', '========== ================= =========== =============== =============== =================' print '(A)', '========== ======================= ===================== ===================== ==========='
do k=1,N_states do k=1,N_states
pt2_overlap(pt2_stoch_istate,k) = pt2_data % overlap(k,pt2_stoch_istate) pt2_overlap(pt2_stoch_istate,k) = pt2_data % overlap(k,pt2_stoch_istate)
enddo enddo
! ! The overlap is not exactly zero because of the guiding function.
! ! Remove the bias
! do k=1,pt2_stoch_istate-1
! pt2_overlap(k,pt2_stoch_istate) -= pt2_data % overlap(k,pt2_stoch_istate)
! enddo
print *, 'Overlap before orthogonalization'
print *, pt2_overlap(:,pt2_stoch_istate)
print *, 'Overlap after orthogonalization'
print *, pt2_overlap(pt2_stoch_istate,:)
print *, '-------'
SOFT_TOUCH pt2_overlap SOFT_TOUCH pt2_overlap
enddo enddo
FREE pt2_stoch_istate FREE pt2_stoch_istate
! Symmetrize overlap
do j=2,N_states
do i=1,j-1
pt2_overlap(i,j) = 0.5d0 * (pt2_overlap(i,j) + pt2_overlap(j,i))
pt2_overlap(j,i) = pt2_overlap(i,j)
enddo
enddo
print *, 'Overlap of perturbed states:'
do k=1,N_states
print *, pt2_overlap(k,:)
enddo
print *, '-------'
if (N_in > 0) then if (N_in > 0) then
b%cur = min(N_in,b%cur) b%cur = min(N_in,b%cur)
if (s2_eig) then if (s2_eig) then
@ -529,7 +533,14 @@ subroutine pt2_collector(zmq_socket_pull, E, relative_error, pt2_data, pt2_data_
if ((time - time1 > 1.d0) .or. (n==N_det_generators)) then if ((time - time1 > 1.d0) .or. (n==N_det_generators)) then
time1 = time time1 = time
print '(G10.3, 2X, F16.10, 2X, G10.3, 2X, G14.6, 2X, G14.6, 2X, F10.4)', c, avg+E, eqt, avg2, avg3(pt2_stoch_istate), time-time0 print '(I10, X, F12.6, X, G10.3, X, F10.6, X, G10.3, X, F10.6, X, G10.3, X, F10.4)', c, &
pt2_data % pt2(pt2_stoch_istate) +E, &
pt2_data_err % pt2(pt2_stoch_istate), &
pt2_data % variance(pt2_stoch_istate), &
pt2_data_err % variance(pt2_stoch_istate), &
pt2_data % overlap(pt2_stoch_istate,pt2_stoch_istate), &
pt2_data_err % overlap(pt2_stoch_istate,pt2_stoch_istate), &
time-time0
if (stop_now .or. ( & if (stop_now .or. ( &
(do_exit .and. (dabs(pt2_data_err % pt2(pt2_stoch_istate)) / & (do_exit .and. (dabs(pt2_data_err % pt2(pt2_stoch_istate)) / &
(1.d-20 + dabs(pt2_data % pt2(pt2_stoch_istate)) ) <= relative_error))) ) then (1.d-20 + dabs(pt2_data % pt2(pt2_stoch_istate)) ) <= relative_error))) ) then

View File

@ -1,4 +1,3 @@
use bitmasks use bitmasks
BEGIN_PROVIDER [ double precision, pt2_match_weight, (N_states) ] BEGIN_PROVIDER [ double precision, pt2_match_weight, (N_states) ]
@ -144,7 +143,6 @@ BEGIN_PROVIDER [ double precision, selection_weight, (N_states) ]
END_PROVIDER END_PROVIDER
subroutine get_mask_phase(det1, pm, Nint) subroutine get_mask_phase(det1, pm, Nint)
use bitmasks use bitmasks
implicit none implicit none
@ -288,6 +286,7 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order PROVIDE psi_bilinear_matrix_rows psi_det_sorted_order psi_bilinear_matrix_order
PROVIDE psi_bilinear_matrix_transp_rows_loc psi_bilinear_matrix_transp_columns PROVIDE psi_bilinear_matrix_transp_rows_loc psi_bilinear_matrix_transp_columns
PROVIDE psi_bilinear_matrix_transp_order psi_selectors_coef_transp PROVIDE psi_bilinear_matrix_transp_order psi_selectors_coef_transp
PROVIDE banned_excitation
monoAdo = .true. monoAdo = .true.
monoBdo = .true. monoBdo = .true.
@ -607,7 +606,8 @@ subroutine select_singles_and_doubles(i_generator,hole_mask,particle_mask,fock_d
h2 = hole_list(i2,s2) h2 = hole_list(i2,s2)
call apply_hole(pmask, s2,h2, mask, ok, N_int) call apply_hole(pmask, s2,h2, mask, ok, N_int)
banned = .false. banned(:,:,1) = banned_excitation(:,:)
banned(:,:,2) = banned_excitation(:,:)
do j=1,mo_num do j=1,mo_num
bannedOrb(j, 1) = .true. bannedOrb(j, 1) = .true.
bannedOrb(j, 2) = .true. bannedOrb(j, 2) = .true.
@ -674,7 +674,8 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
logical :: ok logical :: ok
integer :: s1, s2, p1, p2, ib, j, istate, jstate integer :: s1, s2, p1, p2, ib, j, istate, jstate
integer(bit_kind) :: mask(N_int, 2), det(N_int, 2) integer(bit_kind) :: mask(N_int, 2), det(N_int, 2)
double precision :: e_pert, delta_E, val, Hii, w, tmp, alpha_h_psi, coef(N_states) double precision :: e_pert(N_states), coef(N_states)
double precision :: delta_E, val, Hii, w, tmp, alpha_h_psi
double precision, external :: diag_H_mat_elem_fock double precision, external :: diag_H_mat_elem_fock
double precision :: E_shift double precision :: E_shift
@ -682,7 +683,12 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
double precision, allocatable :: values(:) double precision, allocatable :: values(:)
integer, allocatable :: keys(:,:) integer, allocatable :: keys(:,:)
integer :: nkeys integer :: nkeys
double precision :: s_weight(N_states,N_states)
do jstate=1,N_states
do istate=1,N_states
s_weight(istate,jstate) = dsqrt(selection_weight(istate)*selection_weight(jstate))
enddo
enddo
if(sp == 3) then if(sp == 3) then
s1 = 1 s1 = 1
@ -763,23 +769,59 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
! call occ_pattern_of_det(det,occ,N_int) ! call occ_pattern_of_det(det,occ,N_int)
! call occ_pattern_to_dets_size(occ,n,elec_alpha_num,N_int) ! call occ_pattern_to_dets_size(occ,n,elec_alpha_num,N_int)
e_pert = 0.d0
coef = 0.d0
logical :: do_diag
do_diag = .False.
do istate=1,N_states do istate=1,N_states
delta_E = E0(istate) - Hii + E_shift delta_E = E0(istate) - Hii + E_shift
alpha_h_psi = mat(istate, p1, p2) alpha_h_psi = mat(istate, p1, p2)
if (alpha_h_psi == 0.d0) cycle
val = alpha_h_psi + alpha_h_psi val = alpha_h_psi + alpha_h_psi
tmp = dsqrt(delta_E * delta_E + val * val) tmp = dsqrt(delta_E * delta_E + val * val)
if (delta_E < 0.d0) then if (delta_E < 0.d0) then
tmp = -tmp tmp = -tmp
endif endif
e_pert = 0.5d0 * (tmp - delta_E) e_pert(istate) = 0.5d0 * (tmp - delta_E)
if (dabs(alpha_h_psi) > 1.d-4) then if (dabs(alpha_h_psi) > 1.d-4) then
coef(istate) = e_pert / alpha_h_psi coef(istate) = e_pert(istate) / alpha_h_psi
else else
coef(istate) = alpha_h_psi / delta_E coef(istate) = alpha_h_psi / delta_E
endif endif
enddo enddo
do_diag = sum(dabs(coef)) > 0.001d0
double precision :: eigvalues(N_states+1)
double precision :: work(1+6*(N_states+1)+2*(N_states+1)**2)
integer :: iwork(3+5*(N_states+1)), info, k ,n
if (do_diag) then
double precision :: pt2_matrix(N_states+1,N_states+1)
pt2_matrix(N_states+1,N_states+1) = Hii+E_shift
do istate=1,N_states
pt2_matrix(:,istate) = 0.d0
pt2_matrix(istate,istate) = E0(istate)
pt2_matrix(istate,N_states+1) = mat(istate,p1,p2)
pt2_matrix(N_states+1,istate) = mat(istate,p1,p2)
enddo
call DSYEVD( 'V', 'U', N_states+1, pt2_matrix, N_states+1, eigvalues, &
work, size(work), iwork, size(iwork), info )
if (info /= 0) then
print *, 'error in '//irp_here
stop -1
endif
pt2_matrix = dabs(pt2_matrix)
iwork(1:N_states+1) = maxloc(pt2_matrix,DIM=1)
do k=1,N_states
e_pert(iwork(k)) = eigvalues(k) - E0(iwork(k))
enddo
endif
! ! Gram-Schmidt using input overlap matrix ! ! Gram-Schmidt using input overlap matrix
! do istate=1,N_states ! do istate=1,N_states
! do jstate=1,istate-1 ! do jstate=1,istate-1
@ -791,14 +833,13 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
do istate=1, N_states do istate=1, N_states
alpha_h_psi = mat(istate, p1, p2) alpha_h_psi = mat(istate, p1, p2)
e_pert = coef(istate) * alpha_h_psi
do jstate=1,N_states do jstate=1,N_states
pt2_data % overlap(jstate,istate) += coef(jstate) * alpha_h_psi pt2_data % overlap(jstate,istate) += coef(jstate) * alpha_h_psi
enddo enddo
pt2_data % variance(istate) += alpha_h_psi * alpha_h_psi pt2_data % variance(istate) += alpha_h_psi * alpha_h_psi
pt2_data % pt2(istate) += e_pert pt2_data % pt2(istate) += e_pert(istate)
!!!DEBUG !!!DEBUG
! delta_E = E0(istate) - Hii + E_shift ! delta_E = E0(istate) - Hii + E_shift
@ -823,14 +864,29 @@ subroutine fill_buffer_double(i_generator, sp, h1, h2, bannedOrb, banned, fock_d
case(5) case(5)
! Variance selection ! Variance selection
w = w - alpha_h_psi * alpha_h_psi * selection_weight(istate) ! w = w - alpha_h_psi * alpha_h_psi * s_weight(istate,istate)
w = min(w, - alpha_h_psi * alpha_h_psi * s_weight(istate,istate))
! do jstate=1,N_states
! if (istate == jstate) cycle
! w = w + dabs(alpha_h_psi*mat(jstate,p1,p2)) * s_weight(istate,jstate)
! enddo
case(6) case(6)
w = w - coef(istate) * coef(istate) * selection_weight(istate) ! w = w - coef(istate) * coef(istate) * s_weight(istate,istate)
w = min(w,- coef(istate) * coef(istate) * s_weight(istate,istate))
! do jstate=1,N_states
! if (istate == jstate) cycle
! w = w + dabs(coef(istate)*coef(jstate)) * s_weight(istate,jstate)
! enddo
case default case default
! Energy selection ! Energy selection
w = w + e_pert * selection_weight(istate) ! w = w + e_pert(istate) * s_weight(istate,istate)
w = min(w, e_pert(istate) * s_weight(istate,istate))
! do jstate=1,N_states
! if (istate == jstate) cycle
! w = w + dabs(X(istate)*X(jstate)) * s_weight(istate,jstate)
! enddo
end select end select
end do end do

View File

@ -141,6 +141,12 @@ subroutine ZMQ_selection(N_in, pt2_data)
enddo enddo
pt2_overlap(:,:) = pt2_data % overlap(:,:) pt2_overlap(:,:) = pt2_data % overlap(:,:)
print *, 'Overlap of perturbed states:'
do l=1,N_states
print *, pt2_overlap(l,:)
enddo
print *, '-------'
SOFT_TOUCH pt2_overlap SOFT_TOUCH pt2_overlap
call update_pt2_and_variance_weights(pt2_data, N_states) call update_pt2_and_variance_weights(pt2_data, N_states)

View File

@ -8,7 +8,7 @@ default: 1.e-10
type: logical type: logical
doc: Thresholds of Davidson's algorithm is set to E(rPT2)*threshold_davidson_from_pt2 doc: Thresholds of Davidson's algorithm is set to E(rPT2)*threshold_davidson_from_pt2
interface: ezfio,provider,ocaml interface: ezfio,provider,ocaml
default: true default: false
[n_states_diag] [n_states_diag]
type: States_number type: States_number

View File

@ -44,7 +44,7 @@ default: 2
type: integer type: integer
doc: Weight used in the selection. 0: input state-average weight, 1: 1./(c_0^2), 2: rPT2 matching, 3: variance matching, 4: variance and rPT2 matching, 5: variance minimization and matching, 6: CI coefficients 7: input state-average multiplied by variance and rPT2 matching 8: input state-average multiplied by rPT2 matching 9: input state-average multiplied by variance matching doc: Weight used in the selection. 0: input state-average weight, 1: 1./(c_0^2), 2: rPT2 matching, 3: variance matching, 4: variance and rPT2 matching, 5: variance minimization and matching, 6: CI coefficients 7: input state-average multiplied by variance and rPT2 matching 8: input state-average multiplied by rPT2 matching 9: input state-average multiplied by variance matching
interface: ezfio,provider,ocaml interface: ezfio,provider,ocaml
default: 2 default: 1
[threshold_generators] [threshold_generators]
type: Threshold type: Threshold

View File

@ -99,6 +99,10 @@ double precision function get_two_e_integral(i,j,k,l,map)
type(map_type), intent(inout) :: map type(map_type), intent(inout) :: map
real(integral_kind) :: tmp real(integral_kind) :: tmp
PROVIDE mo_two_e_integrals_in_map mo_integrals_cache PROVIDE mo_two_e_integrals_in_map mo_integrals_cache
if (banned_excitation(i,k) .or. banned_excitation(j,l)) then
get_two_e_integral = 0.d0
return
endif
ii = l-mo_integrals_cache_min ii = l-mo_integrals_cache_min
ii = ior(ii, k-mo_integrals_cache_min) ii = ior(ii, k-mo_integrals_cache_min)
ii = ior(ii, j-mo_integrals_cache_min) ii = ior(ii, j-mo_integrals_cache_min)
@ -159,6 +163,11 @@ subroutine get_mo_two_e_integrals(j,k,l,sze,out_val,map)
! return ! return
!DEBUG !DEBUG
out_val(1:sze) = 0.d0
if (banned_excitation(j,l)) then
return
endif
ii0 = l-mo_integrals_cache_min ii0 = l-mo_integrals_cache_min
ii0 = ior(ii0, k-mo_integrals_cache_min) ii0 = ior(ii0, k-mo_integrals_cache_min)
ii0 = ior(ii0, j-mo_integrals_cache_min) ii0 = ior(ii0, j-mo_integrals_cache_min)
@ -172,6 +181,7 @@ subroutine get_mo_two_e_integrals(j,k,l,sze,out_val,map)
q = q+shiftr(s*s-s,1) q = q+shiftr(s*s-s,1)
do i=1,sze do i=1,sze
if (banned_excitation(i,k)) cycle
ii = ior(ii0, i-mo_integrals_cache_min) ii = ior(ii0, i-mo_integrals_cache_min)
if (iand(ii, -128) == 0) then if (iand(ii, -128) == 0) then
ii_8 = ior( shiftl(ii0_8,7), int(i,8)-mo_integrals_cache_min_8) ii_8 = ior( shiftl(ii0_8,7), int(i,8)-mo_integrals_cache_min_8)
@ -272,6 +282,29 @@ subroutine get_mo_two_e_integrals_exch_ii(k,l,sze,out_val,map)
end end
BEGIN_PROVIDER [ logical, banned_excitation, (mo_num,mo_num) ]
implicit none
use map_module
BEGIN_DOC
! If true, the excitation is banned in the selection. Useful with local MOs.
END_DOC
banned_excitation = .False.
integer :: i,j
integer(key_kind) :: idx
double precision :: tmp
! double precision :: buffer(mo_num)
do j=1,mo_num
do i=1,j-1
call two_e_integrals_index(i,j,j,i,idx)
!DIR$ FORCEINLINE
call map_get(mo_integrals_map,idx,tmp)
banned_excitation(i,j) = dabs(tmp) < 1.d-15
banned_excitation(j,i) = banned_excitation(i,j)
enddo
enddo
END_PROVIDER
integer*8 function get_mo_map_size() integer*8 function get_mo_map_size()
implicit none implicit none

View File

@ -17,7 +17,7 @@ program molden
write(i_unit_output,'(A)') '[Molden Format]' write(i_unit_output,'(A)') '[Molden Format]'
write(i_unit_output,'(A)') '[Atoms] ANGSTROM' write(i_unit_output,'(A)') '[Atoms] Angs'
do i = 1, nucl_num do i = 1, nucl_num
write(i_unit_output,'(A2,2X,I4,2X,I4,3(2X,F15.10))') & write(i_unit_output,'(A2,2X,I4,2X,I4,3(2X,F15.10))') &
trim(element_name(int(nucl_charge(i)))), & trim(element_name(int(nucl_charge(i)))), &

View File

@ -11,7 +11,7 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
integer, intent(in) :: LDA, LDU, LDVt, m, n integer, intent(in) :: LDA, LDU, LDVt, m, n
double precision, intent(in) :: A(LDA,n) double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,m) double precision, intent(out) :: U(LDU,min(m,n))
double precision,intent(out) :: Vt(LDVt,n) double precision,intent(out) :: Vt(LDVt,n)
double precision,intent(out) :: D(min(m,n)) double precision,intent(out) :: D(min(m,n))
double precision,allocatable :: work(:) double precision,allocatable :: work(:)
@ -19,19 +19,19 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
double precision,allocatable :: A_tmp(:,:) double precision,allocatable :: A_tmp(:,:)
allocate (A_tmp(LDA,n)) allocate (A_tmp(LDA,n))
A_tmp = A A_tmp(:,:) = A(:,:)
! Find optimal size for temp arrays ! Find optimal size for temp arrays
allocate(work(1)) allocate(work(1))
lwork = -1 lwork = -1
call dgesvd('A','A', m, n, A_tmp, LDA, & call dgesvd('S','S', m, n, A_tmp, LDA, &
D, U, LDU, Vt, LDVt, work, lwork, info) D, U, LDU, Vt, LDVt, work, lwork, info)
! /!\ int(WORK(1)) becomes negative when WORK(1) > 2147483648 ! /!\ int(WORK(1)) becomes negative when WORK(1) > 2147483648
lwork = max(int(work(1)), 5*MIN(M,N)) lwork = max(int(work(1)), 5*MIN(M,N))
deallocate(work) deallocate(work)
allocate(work(lwork)) allocate(work(lwork))
call dgesvd('A','A', m, n, A_tmp, LDA, & call dgesvd('S','S', m, n, A_tmp, LDA, &
D, U, LDU, Vt, LDVt, work, lwork, info) D, U, LDU, Vt, LDVt, work, lwork, info)
deallocate(work,A_tmp) deallocate(work,A_tmp)
@ -42,6 +42,128 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
end end
subroutine eigSVD(A,LDA,U,LDU,D,Vt,LDVt,m,n)
implicit none
BEGIN_DOC
! Algorithm 3 of https://arxiv.org/pdf/1810.06860.pdf
!
! A(m,n) = U(m,n) D(n) Vt(n,n) with m>n
END_DOC
integer, intent(in) :: LDA, LDU, LDVt, m, n
double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,n)
double precision,intent(out) :: Vt(LDVt,n)
double precision,intent(out) :: D(n)
integer :: i,j,k
if (m<n) then
stop -1
call svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
return
endif
double precision, allocatable :: B(:,:), V(:,:)
allocate(B(n,n))
! B = - At . A
call dgemm('T','N',n,n,m,-1.d0,A,size(A,1),A,size(A,1),0.d0,B,size(B,1))
! V, D = eig(B)
allocate(V(n,n))
call lapack_diagd(D,V,B,n,n)
deallocate(B)
do j=1,n
do i=1,n
Vt(i,j) = V(j,i)
enddo
enddo
! S = sqrt(-D)
! U = A.V.S^-1
! U = A.(S^-1.vt)t
do k=1,n
if (D(k) >= 0.d0) then
exit
endif
D(k) = dsqrt(-D(k))
call dscal(n, 1.d0/D(k), V(1,k), 1)
enddo
D(k:n) = 0.d0
k=k-1
call dgemm('N','N',m,n,k,1.d0,A,size(A,1),V,size(V,1),0.d0,U,size(U,1))
end
subroutine randomized_svd(A,LDA,U,LDU,D,Vt,LDVt,m,n,q,r)
implicit none
include 'constants.include.F'
BEGIN_DOC
! Randomized SVD: rank r, q power iterations
!
! 1. Sample column space of A with P: Z = A.P where P is random with r+p columns.
!
! 2. Power iterations : Z <- X . (Xt.Z)
!
! 3. Z = Q.R
!
! 4. Compute SVD on projected Qt.X = U' . S. Vt
!
! 5. U = Q U'
END_DOC
integer, intent(in) :: LDA, LDU, LDVt, m, n, q, r
double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,r)
double precision,intent(out) :: Vt(LDVt,r)
double precision,intent(out) :: D(r)
integer :: i, j, k
double precision,allocatable :: Z(:,:), P(:,:), Y(:,:), UY(:,:)
double precision :: r1,r2
allocate(P(n,r), Z(m,r))
! P is a normal random matrix (n,r)
do k=1,r
do i=1,n
call random_number(r1)
call random_number(r2)
r1 = dsqrt(-2.d0*dlog(r1))
r2 = dtwo_pi*r2
P(i,k) = r1*dcos(r2)
enddo
enddo
! Z(m,r) = A(m,n).P(n,r)
call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
! Power iterations
do k=1,q
! P(n,r) = At(n,m).Z(m,r)
call dgemm('T','N',n,r,m,1.d0,A,size(A,1),Z,size(Z,1),0.d0,P,size(P,1))
! Z(m,r) = A(m,n).P(n,r)
call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
enddo
deallocate(P)
! QR factorization of Z
call ortho_svd(Z,size(Z,1),m,r)
allocate(Y(r,n), UY(r,r))
! Y(r,n) = Zt(r,m).A(m,n)
call dgemm('T','N',r,n,m,1.d0,Z,size(Z,1),A,size(A,1),0.d0,Y,size(Y,1))
! SVD of Y
call svd(Y,size(Y,1),UY,size(UY,1),D,Vt,size(Vt,1),r,n)
deallocate(Y)
! U(m,r) = Z(m,r).UY(r,r)
call dgemm('N','N',m,r,r,1.d0,Z,size(Z,1),UY,size(UY,1),0.d0,U,size(U,1))
deallocate(UY,Z)
end
subroutine svd_complex(A,LDA,U,LDU,D,Vt,LDVt,m,n) subroutine svd_complex(A,LDA,U,LDU,D,Vt,LDVt,m,n)
implicit none implicit none
@ -807,6 +929,33 @@ subroutine ortho_canonical(overlap,LDA,N,C,LDC,m,cutoff)
end end
subroutine ortho_svd(A,LDA,m,n)
implicit none
BEGIN_DOC
! Orthogonalization via fast SVD
!
! A : matrix to orthogonalize
!
! LDA : leftmost dimension of A
!
! m : Number of rows of A
!
! n : Number of columns of A
!
END_DOC
integer, intent(in) :: m,n, LDA
double precision, intent(inout) :: A(LDA,n)
if (m < n) then
call ortho_qr(A,LDA,m,n)
endif
double precision, allocatable :: U(:,:), D(:), Vt(:,:)
allocate(U(m,n), D(n), Vt(n,n))
call SVD(A,LDA,U,size(U,1),D,Vt,size(Vt,1),m,n)
A(1:m,1:n) = U(1:m,1:n)
deallocate(U,D, Vt)
end
subroutine ortho_qr(A,LDA,m,n) subroutine ortho_qr(A,LDA,m,n)
implicit none implicit none
BEGIN_DOC BEGIN_DOC