From 88b798fe22d0c3d5c718686d11772a281c31308c Mon Sep 17 00:00:00 2001
From: Anthony Scemama
Date: Wed, 30 Sep 2020 20:48:27 +0200
Subject: [PATCH] Added randomized SVD routines

---
 Review notes (not part of the commit message):

 * eigSVD: the text between "if (m" and ">= 0.d0) then" was lost when this
   patch was extracted.  The surviving fragment is kept verbatim below and
   must be restored from the original commit before the patch can be
   applied; the line counts of that hunk still describe the complete routine.
   NOTE(review): the final dgemm of eigSVD is called with dimensions
   (m,n,k); confirm against the intact source that (m,k,n) is not intended.
 * randomized_svd: Vt is declared with n columns (what svd writes), the
   Box-Muller draw uses log(1-r1) so that a zero from random_number cannot
   produce an infinite entry, the unused local j is gone and the
   documentation matches the code (r columns, matrix named A).
 * ortho_svd: returns right after the ortho_qr fallback when m < n instead
   of overwriting its result with the SVD path; declarations now precede
   the first executable statement.

 src/utils/linear_algebra.irp.f | 157 ++++++++++++++++++++++++++++++++-
 1 file changed, 154 insertions(+), 3 deletions(-)

diff --git a/src/utils/linear_algebra.irp.f b/src/utils/linear_algebra.irp.f
index cf43edd5..deb4f4ca 100644
--- a/src/utils/linear_algebra.irp.f
+++ b/src/utils/linear_algebra.irp.f
@@ -11,7 +11,7 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
   integer, intent(in)             :: LDA, LDU, LDVt, m, n
 
   double precision, intent(in)    :: A(LDA,n)
-  double precision, intent(out)   :: U(LDU,m)
+  double precision, intent(out)   :: U(LDU,min(m,n))
   double precision,intent(out)    :: Vt(LDVt,n)
   double precision,intent(out)    :: D(min(m,n))
   double precision,allocatable    :: work(:)
@@ -24,14 +24,14 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
   ! Find optimal size for temp arrays
   allocate(work(1))
   lwork = -1
-  call dgesvd('A','A', m, n, A_tmp, LDA,                             &
+  call dgesvd('S','S', m, n, A_tmp, LDA,                             &
       D, U, LDU, Vt, LDVt, work, lwork, info)
   ! /!\ int(WORK(1)) becomes negative when WORK(1) > 2147483648
   lwork = max(int(work(1)), 5*MIN(M,N))
   deallocate(work)
 
   allocate(work(lwork))
-  call dgesvd('A','A', m, n, A_tmp, LDA,                             &
+  call dgesvd('S','S', m, n, A_tmp, LDA,                             &
       D, U, LDU, Vt, LDVt, work, lwork, info)
 
   deallocate(work,A_tmp)
@@ -42,6 +42,128 @@ subroutine svd(A,LDA,U,LDU,D,Vt,LDVt,m,n)
 
 end
 
+subroutine eigSVD(A,LDA,U,LDU,D,Vt,LDVt,m,n)
+  implicit none
+  BEGIN_DOC
+! Algorithm 3 of https://arxiv.org/pdf/1810.06860.pdf
+!
+! A(m,n) = U(m,n) D(n) Vt(n,n) with m>n
+  END_DOC
+  integer, intent(in)            :: LDA, LDU, LDVt, m, n
+  double precision, intent(in)   :: A(LDA,n)
+  double precision, intent(out)  :: U(LDU,n)
+  double precision,intent(out)   :: Vt(LDVt,n)
+  double precision,intent(out)   :: D(n)
+
+  integer :: i,j,k
+
+  if (m= 0.d0) then
+      exit
+    endif
+    D(k) = dsqrt(-D(k))
+    call dscal(n, 1.d0/D(k), V(1,k), 1)
+  enddo
+  D(k:n) = 0.d0
+  k=k-1
+  call dgemm('N','N',m,n,k,1.d0,A,size(A,1),V,size(V,1),0.d0,U,size(U,1))
+
+end
+
+
+subroutine randomized_svd(A,LDA,U,LDU,D,Vt,LDVt,m,n,q,r)
+  implicit none
+  include 'constants.include.F'
+  BEGIN_DOC
+! Randomized SVD of A(m,n): rank r, q power iterations. U(m,r), D(r), Vt(r,n)
+!
+! 1. Sample column space of A with P: Z = A.P where P is random with r columns.
+!
+! 2. Power iterations : Z <- A . (At.Z)
+!
+! 3. Orthonormalize the columns of Z (ortho_svd): Z <- Q
+!
+! 4. Compute SVD on projected Qt.A = U' . S. Vt
+!
+! 5. U = Q U'
+  END_DOC
+
+  integer, intent(in)            :: LDA, LDU, LDVt, m, n, q, r
+  double precision, intent(in)   :: A(LDA,n)
+  double precision, intent(out)  :: U(LDU,r)
+  double precision,intent(out)   :: Vt(LDVt,n)
+  double precision,intent(out)   :: D(r)
+  integer                        :: i, k
+
+  double precision,allocatable   :: Z(:,:), P(:,:), Y(:,:), UY(:,:)
+  double precision               :: r1,r2
+  allocate(P(n,r), Z(m,r))
+
+  ! P(n,r): normal random entries (Box-Muller); 1-r1 is in (0,1], so the log is finite
+  do k=1,r
+    do i=1,n
+      call random_number(r1)
+      call random_number(r2)
+      r1 = dsqrt(-2.d0*dlog(1.d0-r1))
+      r2 = dtwo_pi*r2
+      P(i,k) = r1*dcos(r2)
+    enddo
+  enddo
+
+  ! Z(m,r) = A(m,n).P(n,r)
+  call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
+
+  ! Power iterations
+  do k=1,q
+    ! P(n,r) = At(n,m).Z(m,r)
+    call dgemm('T','N',n,r,m,1.d0,A,size(A,1),Z,size(Z,1),0.d0,P,size(P,1))
+    ! Z(m,r) = A(m,n).P(n,r)
+    call dgemm('N','N',m,r,n,1.d0,A,size(A,1),P,size(P,1),0.d0,Z,size(Z,1))
+  enddo
+
+  deallocate(P)
+
+  ! Orthonormalize the columns of Z in place
+  call ortho_svd(Z,size(Z,1),m,r)
+
+  allocate(Y(r,n), UY(r,r))
+  ! Y(r,n) = Zt(r,m).A(m,n)
+  call dgemm('T','N',r,n,m,1.d0,Z,size(Z,1),A,size(A,1),0.d0,Y,size(Y,1))
+
+  ! SVD of Y
+  call svd(Y,size(Y,1),UY,size(UY,1),D,Vt,size(Vt,1),r,n)
+  deallocate(Y)
+
+  ! U(m,r) = Z(m,r).UY(r,r)
+  call dgemm('N','N',m,r,r,1.d0,Z,size(Z,1),UY,size(UY,1),0.d0,U,size(U,1))
+  deallocate(UY,Z)
+
+end
 
 subroutine svd_complex(A,LDA,U,LDU,D,Vt,LDVt,m,n)
   implicit none
@@ -807,6 +929,35 @@ subroutine ortho_canonical(overlap,LDA,N,C,LDC,m,cutoff)
 
 end
 
+subroutine ortho_svd(A,LDA,m,n)
+  implicit none
+  BEGIN_DOC
+  ! Orthogonalization via fast SVD
+  !
+  ! A : matrix to orthogonalize
+  !
+  ! LDA : leftmost dimension of A
+  !
+  ! m : Number of rows of A
+  !
+  ! n : Number of columns of A
+  !
+  ! When m < n, A is handed to ortho_qr and the SVD step is skipped.
+  END_DOC
+  integer, intent(in)            :: m,n, LDA
+  double precision, intent(inout) :: A(LDA,n)
+  double precision, allocatable  :: U(:,:), D(:), Vt(:,:)
+  if (m < n) then
+    call ortho_qr(A,LDA,m,n)
+    return
+  endif
+  allocate(U(m,n), D(n), Vt(n,n))
+  call SVD(A,LDA,U,size(U,1),D,Vt,size(Vt,1),m,n)
+  A(1:m,1:n) = U(1:m,1:n)
+  deallocate(U,D, Vt)
+
+end
+
 subroutine ortho_qr(A,LDA,m,n)
   implicit none
   BEGIN_DOC