From 8018440410fac858f9a5ed2fb9f2c4ec4963c4b3 Mon Sep 17 00:00:00 2001 From: AbdAmmar Date: Thu, 25 Jan 2024 22:13:13 +0100 Subject: [PATCH] OPENMP & DGEMM in pseudo_inv --- src/utils/linear_algebra.irp.f | 57 +++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/src/utils/linear_algebra.irp.f b/src/utils/linear_algebra.irp.f index 314ad4f6..a67a219c 100644 --- a/src/utils/linear_algebra.irp.f +++ b/src/utils/linear_algebra.irp.f @@ -1321,19 +1321,22 @@ subroutine get_inverse(A,LDA,m,C,LDC) deallocate(ipiv,work) end -subroutine get_pseudo_inverse(A,LDA,m,n,C,LDC,cutoff) - implicit none +subroutine get_pseudo_inverse(A, LDA, m, n, C, LDC, cutoff) + BEGIN_DOC ! Find C = A^-1 END_DOC - integer, intent(in) :: m,n, LDA, LDC - double precision, intent(in) :: A(LDA,n) - double precision, intent(in) :: cutoff - double precision, intent(out) :: C(LDC,m) - double precision, allocatable :: U(:,:), D(:), Vt(:,:), work(:), A_tmp(:,:) - integer :: info, lwork - integer :: i,j,k + implicit none + integer, intent(in) :: m, n, LDA, LDC + double precision, intent(in) :: A(LDA,n) + double precision, intent(in) :: cutoff + double precision, intent(out) :: C(LDC,m) + + integer :: info, lwork + integer :: i, j, k, n_svd + double precision, allocatable :: U(:,:), D(:), Vt(:,:), work(:), A_tmp(:,:) + allocate (D(n),U(m,n),Vt(n,n),work(1),A_tmp(m,n)) do j=1,n do i=1,m @@ -1355,22 +1358,40 @@ subroutine get_pseudo_inverse(A,LDA,m,n,C,LDC,cutoff) stop 1 endif - do i=1,n - if (D(i)/D(1) > cutoff) then - D(i) = 1.d0/D(i) + n_svd = 0 + do i = 1, n + if(D(i)/D(1) > cutoff) then + D(i) = 1.d0 / D(i) + n_svd = n_svd + 1 else D(i) = 0.d0 endif enddo + print*, ' n_svd = ', n_svd - C = 0.d0 - do i=1,m - do j=1,n - do k=1,n - C(j,i) = C(j,i) + U(i,k) * D(k) * Vt(k,j) - enddo + !$OMP PARALLEL & + !$OMP DEFAULT (NONE) & + !$OMP PRIVATE (i, j) & + !$OMP SHARED (n, n_svd, D, Vt) + !$OMP DO + do j = 1, n + do i = 1, n_svd + Vt(i,j) = D(i) * Vt(i,j) enddo enddo + !$OMP END DO + !$OMP END PARALLEL + + call dgemm("N", "N", m, n, n_svd, 1.d0, U, m, Vt, n, 0.d0, C, LDC) + + !C = 0.d0 + !do i=1,m + ! do j=1,n + ! do k=1,n + ! C(j,i) = C(j,i) + U(i,k) * D(k) * Vt(k,j) + ! enddo + ! enddo + !enddo deallocate(U,D,Vt,work,A_tmp)