9
1
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-12-21 19:13:29 +01:00

Added randomized SVD routines

This commit is contained in:
Anthony Scemama 2020-09-30 20:48:27 +02:00
parent e838868181
commit 88b798fe22

View File

@ -11,7 +11,7 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
integer, intent(in) :: LDA, LDU, LDVt, m, n integer, intent(in) :: LDA, LDU, LDVt, m, n
double precision, intent(in) :: A(LDA,n) double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,m) double precision, intent(out) :: U(LDU,min(m,n))
double precision,intent(out) :: Vt(LDVt,n) double precision,intent(out) :: Vt(LDVt,n)
double precision,intent(out) :: D(min(m,n)) double precision,intent(out) :: D(min(m,n))
double precision,allocatable :: work(:) double precision,allocatable :: work(:)
@ -24,14 +24,14 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
! Find optimal size for temp arrays ! Find optimal size for temp arrays
allocate(work(1)) allocate(work(1))
lwork = -1 lwork = -1
call dgesvd('A','A', m, n, A_tmp, LDA, & call dgesvd('S','S', m, n, A_tmp, LDA, &
D, U, LDU, Vt, LDVt, work, lwork, info) D, U, LDU, Vt, LDVt, work, lwork, info)
! /!\ int(WORK(1)) becomes negative when WORK(1) > 2147483648 ! /!\ int(WORK(1)) becomes negative when WORK(1) > 2147483648
lwork = max(int(work(1)), 5*MIN(M,N)) lwork = max(int(work(1)), 5*MIN(M,N))
deallocate(work) deallocate(work)
allocate(work(lwork)) allocate(work(lwork))
call dgesvd('A','A', m, n, A_tmp, LDA, & call dgesvd('S','S', m, n, A_tmp, LDA, &
D, U, LDU, Vt, LDVt, work, lwork, info) D, U, LDU, Vt, LDVt, work, lwork, info)
deallocate(work,A_tmp) deallocate(work,A_tmp)
@ -42,6 +42,128 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
end end
subroutine eigSVD(A,LDA,U,LDU,D,Vt,LDVt,m,n)
implicit none
BEGIN_DOC
! Algorithm 3 of https://arxiv.org/pdf/1810.06860.pdf
!
! A(m,n) = U(m,n) D(n) Vt(n,n) with m>n
END_DOC
integer, intent(in) :: LDA, LDU, LDVt, m, n
double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,n)
double precision,intent(out) :: Vt(LDVt,n)
double precision,intent(out) :: D(n)
integer :: i,j,k
if (m<n) then
stop -1
call svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
return
endif
double precision, allocatable :: B(:,:), V(:,:)
allocate(B(n,n))
! B = - At . A
call dgemm('T','N',n,n,m,-1.d0,A,size(A,1),A,size(A,1),0.d0,B,size(B,1))
! V, D = eig(B)
allocate(V(n,n))
call lapack_diagd(D,V,B,n,n)
deallocate(B)
do j=1,n
do i=1,n
Vt(i,j) = V(j,i)
enddo
enddo
! S = sqrt(-D)
! U = A.V.S^-1
! U = A.(S^-1.vt)t
do k=1,n
if (D(k) >= 0.d0) then
exit
endif
D(k) = dsqrt(-D(k))
call dscal(n, 1.d0/D(k), V(1,k), 1)
enddo
D(k:n) = 0.d0
k=k-1
call dgemm('N','N',m,n,k,1.d0,A,size(A,1),V,size(V,1),0.d0,U,size(U,1))
end
subroutine randomized_svd(A,LDA,U,LDU,D,Vt,LDVt,m,n,q,r)
implicit none
include 'constants.include.F'
BEGIN_DOC
! Randomized SVD: rank r, q power iterations
!
! 1. Sample column space of A with P: Z = A.P where P is random with r+p columns.
!
! 2. Power iterations : Z <- X . (Xt.Z)
!
! 3. Z = Q.R
!
! 4. Compute SVD on projected Qt.X = U' . S. Vt
!
! 5. U = Q U'
END_DOC
integer, intent(in) :: LDA, LDU, LDVt, m, n, q, r
double precision, intent(in) :: A(LDA,n)
double precision, intent(out) :: U(LDU,r)
double precision,intent(out) :: Vt(LDVt,r)
double precision,intent(out) :: D(r)
integer :: i, j, k
double precision,allocatable :: Z(:,:), P(:,:), Y(:,:), UY(:,:)
double precision :: r1,r2
allocate(P(n,r), Z(m,r))
! P is a normal random matrix (n,r)
do k=1,r
do i=1,n
call random_number(r1)
call random_number(r2)
r1 = dsqrt(-2.d0*dlog(r1))
r2 = dtwo_pi*r2
P(i,k) = r1*dcos(r2)
enddo
enddo
! Z(m,r) = A(m,n).P(n,r)
call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
! Power iterations
do k=1,q
! P(n,r) = At(n,m).Z(m,r)
call dgemm('T','N',n,r,m,1.d0,A,size(A,1),Z,size(Z,1),0.d0,P,size(P,1))
! Z(m,r) = A(m,n).P(n,r)
call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
enddo
deallocate(P)
! QR factorization of Z
call ortho_svd(Z,size(Z,1),m,r)
allocate(Y(r,n), UY(r,r))
! Y(r,n) = Zt(r,m).A(m,n)
call dgemm('T','N',r,n,m,1.d0,Z,size(Z,1),A,size(A,1),0.d0,Y,size(Y,1))
! SVD of Y
call svd(Y,size(Y,1),UY,size(UY,1),D,Vt,size(Vt,1),r,n)
deallocate(Y)
! U(m,r) = Z(m,r).UY(r,r)
call dgemm('N','N',m,r,r,1.d0,Z,size(Z,1),UY,size(UY,1),0.d0,U,size(U,1))
deallocate(UY,Z)
end
subroutine svd_complex(A,LDA,U,LDU,D,Vt,LDVt,m,n) subroutine svd_complex(A,LDA,U,LDU,D,Vt,LDVt,m,n)
implicit none implicit none
@ -807,6 +929,33 @@ subroutine ortho_canonical(overlap,LDA,N,C,LDC,m,cutoff)
end end
subroutine ortho_svd(A,LDA,m,n)
implicit none
BEGIN_DOC
! Orthogonalization via fast SVD
!
! A : matrix to orthogonalize
!
! LDA : leftmost dimension of A
!
! m : Number of rows of A
!
! n : Number of columns of A
!
END_DOC
integer, intent(in) :: m,n, LDA
double precision, intent(inout) :: A(LDA,n)
if (m < n) then
call ortho_qr(A,LDA,m,n)
endif
double precision, allocatable :: U(:,:), D(:), Vt(:,:)
allocate(U(m,n), D(n), Vt(n,n))
call SVD(A,LDA,U,size(U,1),D,Vt,size(Vt,1),m,n)
A(1:m,1:n) = U(1:m,1:n)
deallocate(U,D, Vt)
end
subroutine ortho_qr(A,LDA,m,n) subroutine ortho_qr(A,LDA,m,n)
implicit none implicit none
BEGIN_DOC BEGIN_DOC