! Reconstructed from the patch "more efficient restore symmetry"
! (commit 2b57c0728228adae01ae341b7068356f47d488dc, ydamour, 25 Feb 2023),
! which rewrites restore_symmetry in src/utils/linear_algebra.irp.f.

subroutine restore_symmetry(m,n,A,LDA,thresh)

  implicit none

  BEGIN_DOC
  ! Tries to find the matrix elements that are the same in absolute value,
  ! and sets them to the average absolute value (each element keeps its
  ! own sign). Elements whose absolute value is <= thresh are set to zero.
  !
  ! m, n   : number of rows / columns of A that are processed
  ! A      : matrix, modified in place
  ! LDA    : leading dimension of A
  ! thresh : absolute threshold for zeroing; two elements x and y are
  !          considered equal when | |y|/|x| - 1 | < sqrt(thresh)
  !
  ! Algorithm: copy -|A| to a 1D array, sort it, average each run of
  ! nearly equal consecutive values, then scatter the values back.
  END_DOC

  integer, intent(in)             :: m,n,LDA
  double precision, intent(inout) :: A(LDA,n)
  double precision, intent(in)    :: thresh

  double precision, allocatable   :: copy(:), copy_sign(:)
  integer, allocatable            :: key(:), ii(:), jj(:)
  integer                         :: sze, pi, pf, i,j,k
  double precision                :: average, val, thresh2

  sze = m * n

  ! Nothing to do for an empty matrix
  if (sze <= 0) then
    return
  endif

  ! Relative tolerance used to decide that two magnitudes are equal
  thresh2 = dsqrt(thresh)

  allocate(copy(sze),copy_sign(sze),key(sze),ii(sze),jj(sze))

  ! Copy to 1D: store -|A(i,j)| (so that an increasing sort yields
  ! decreasing magnitudes), the sign, and the (i,j) position.
  !$OMP PARALLEL DO if (m>100)                     &
  !$OMP DEFAULT(NONE)                              &
  !$OMP SHARED(A,m,n,copy,copy_sign,key,ii,jj)     &
  !$OMP PRIVATE(i,j,k)
  do j = 1, n
    do i = 1, m
      k = i+(j-1)*m
      copy_sign(k) = sign(1.d0,A(i,j))
      copy(k)      = -dabs(A(i,j))
      key(k)       = k
      ii(k)        = i
      jj(k)        = j
    enddo
  enddo
  !$OMP END PARALLEL DO

  ! Sort.
  ! NOTE(review): dsort / iset_order / dset_order are defined elsewhere in
  ! the project; presumably dsort sorts copy in increasing order and
  ! returns the permutation in key, which the *set_order routines then
  ! apply to the companion arrays -- confirm against their definitions.
  call dsort(copy,key,sze)
  call iset_order(ii,key,sze)
  call iset_order(jj,key,sze)
  call dset_order(copy_sign,key,sze)

  ! Symmetrize (serial)
  ! TODO: parallelization with OMP
  i = 1
  do while (i <= sze)

    ! Magnitudes are decreasing: once one element is <= thresh, all the
    ! remaining ones are too.
    if (-copy(i) <= thresh) then
      exit
    endif

    ! Grow the run [pi,pf] of elements equal to copy(pi) within thresh2.
    ! The bound pf < sze is tested before copy(pf+1) is read.
    pi  = i
    pf  = i
    val = 1.d0/copy(pi)
    do while (pf < sze)
      if (.not.(dabs(val * copy(pf+1) - 1.d0) < thresh2)) then
        exit
      endif
      pf = pf + 1
    enddo

    ! If the run holds more than one element, replace them by their average
    if (pf > pi) then
      average = sum(copy(pi:pf)) / dble(pf-pi+1)
      copy(pi:pf) = average
    endif

    ! Continue right after the run
    i = pf + 1
  enddo

  ! i is the first index whose magnitude is <= thresh, or sze+1 when there
  ! is none (the section is then empty and nothing is zeroed).
  copy(i:sze) = 0.d0

  ! copy -> A : restore the original sign of every element
  !$OMP PARALLEL DO if (sze>10000)                 &
  !$OMP DEFAULT(NONE)                              &
  !$OMP SHARED(A,sze,copy,copy_sign,ii,jj)         &
  !$OMP PRIVATE(k)
  do k = 1, sze
    A(ii(k),jj(k)) = sign(copy(k),copy_sign(k))
  enddo
  !$OMP END PARALLEL DO

  deallocate(copy,copy_sign,key,ii,jj)

end