diff --git a/src/utils/linear_algebra.irp.f b/src/utils/linear_algebra.irp.f
index 1e33c7dc..20599325 100644
--- a/src/utils/linear_algebra.irp.f
+++ b/src/utils/linear_algebra.irp.f
@@ -1690,6 +1690,19 @@ subroutine restore_symmetry(m,n,A,LDA,thresh)
   thresh2 = dsqrt(thresh)
   call nullify_small_elements(m,n,A,LDA,thresh)

+  ! Debug
+  !double precision, allocatable :: B(:,:)
+  !double precision :: max_diff, ti,tf
+  !allocate(B(LDA,n))
+  !B = A
+  !call wall_time(ti)
+  !call restore_symmetry_fast(m,n,B,LDA,thresh)
+  !call wall_time(tf)
+  !print*,''
+  !print*,'Restore_symmetry'
+  !print*,'Fast version:',tf-ti,'s'
+  !call wall_time(ti)
+
 !  if (.not.restore_symm) then
 !    return
 !  endif
@@ -1749,23 +1762,130 @@ subroutine restore_symmetry(m,n,A,LDA,thresh)
     enddo
   enddo

+  ! Debug
+  !call wall_time(tf)
+  !print*,'Old version:',tf-ti,'s'
+
+  !max_diff = 0d0
+  !do j = 1, n
+  !  do i = 1, m
+  !    if (dabs(A(i,j)-B(i,j)) > max_diff) then
+  !      max_diff = dabs(A(i,j)-B(i,j))
+  !    endif
+  !  enddo
+  !enddo
+  !print*,'Max diff:', max_diff
+  !deallocate(B)
+
 end


+subroutine restore_symmetry_fast(m,n,A,LDA,thresh)
+  implicit none
+  BEGIN_DOC
+  ! Sort-based variant of restore_symmetry: tries to find the matrix
+  ! elements that have the same absolute value, and sets them to their
+  ! average value, each element keeping its own sign.
+  ! A is first passed through nullify_small_elements with thresh;
+  ! two values are considered equal when their ratio differs from 1 by
+  ! less than sqrt(thresh).
+  ! restore_symm is not tested here: the symmetrization is always done.
+  END_DOC
+  integer, intent(in) :: m,n,LDA
+  double precision, intent(inout) :: A(LDA,n)
+  double precision, intent(in) :: thresh
+  double precision, allocatable :: copy(:), copy_sign(:)
+  integer, allocatable :: key(:)
+  integer :: sze, pi, pf, idx, i,j,k
+  double precision :: average, val, thresh2

+  thresh2 = dsqrt(thresh)
+  call nullify_small_elements(m,n,A,LDA,thresh)

+  ! Nothing to symmetrize in an empty matrix
+  sze = m * n
+  if (sze <= 0) then
+    return
+  endif

+  allocate(copy(sze),copy_sign(sze),key(sze))

+  ! Flatten |A| to 1D in column-major order, keeping signs and positions
+  !$OMP PARALLEL &
+  !$OMP SHARED(A,m,n,copy_sign,copy,key) &
+  !$OMP PRIVATE(i,j,idx) &
+  !$OMP DEFAULT(NONE)
+  !$OMP DO
+  do j = 1, n
+    do i = 1, m
+      idx = i+(j-1)*m
+      copy_sign(idx) = sign(1d0,A(i,j))
+      copy(idx) = dabs(A(i,j))
+      key(idx) = idx
+    enddo
+  enddo
+  !$OMP END DO
+  !$OMP END PARALLEL

+  ! Sort |A|; dsort is expected to sort in ascending order and to apply
+  ! the same permutation to key, so key(k) is the 1D position of copy(k)
+  call dsort(copy,key,sze)

+  !TODO
+  ! Parallelization with OMP

+  ! Jump all the elements below thresh. The bound is tested before the
+  ! value: when every element is <= thresh, copy(sze+1) must not be read
+  i = 1
+  do while (i <= sze)
+    if (copy(i) > thresh) then
+      exit
+    endif
+    i = i + 1
+  enddo

+  ! Symmetrize
+  do while (i < sze)
+    pi = i
+    pf = i
+    val = 1d0/copy(pi)
+    ! Extend the run [pi,pf] while the next value equals copy(pi) within
+    ! the relative tolerance thresh2, never reading beyond copy(sze)
+    do while (pf < sze)
+      if (.not.(dabs(val * copy(pf+1) - 1d0) < thresh2)) then
+        exit
+      endif
+      pf = pf + 1
+    enddo

+    ! If pi and pf are different do the average from pi to pf
+    if (pf > pi) then
+      average = sum(copy(pi:pf)) / dble(pf-pi+1)
+      copy(pi:pf) = average
+    endif

+    ! The next run starts right after pf (pf == pi if nothing was merged)
+    i = pf + 1
+  enddo

+  ! copy -> A
+  !$OMP PARALLEL &
+  !$OMP SHARED(m,sze,copy_sign,copy,key,A) &
+  !$OMP PRIVATE(i,j,k,idx) &
+  !$OMP DEFAULT(NONE)
+  !$OMP DO
+  do k = 1, sze
+    idx = key(k)
+    i = mod(idx-1,m) + 1
+    j = (idx-1) / m + 1
+    ! New value with the right sign
+    A(i,j) = sign(copy(k),copy_sign(idx))
+  enddo
+  !$OMP END DO
+  !$OMP END PARALLEL

+  deallocate(copy,copy_sign,key)
-
-
+end

 !subroutine svd_s(A, LDA, U, LDU, D, Vt, LDVt, m, n)
 !  implicit none