9
1
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-11-07 05:53:37 +01:00

Merge pull request #251 from Ydrnan/restore_sym
Some checks failed
continuous-integration/drone/push Build is failing

more efficient restore symmetry
This commit is contained in:
Anthony Scemama 2023-02-25 01:50:22 +01:00 committed by GitHub
commit 60113ee34b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1574,79 +1574,110 @@ subroutine nullify_small_elements(m,n,A,LDA,thresh)
end end
subroutine restore_symmetry(m,n,A,LDA,thresh) subroutine restore_symmetry(m,n,A,LDA,thresh)
implicit none implicit none
BEGIN_DOC BEGIN_DOC
! Tries to find the matrix elements that are the same, and sets them ! Tries to find the matrix elements that are the same, and sets them
! to the average value. ! to the average value.
! If restore_symm is False, only nullify small elements ! If restore_symm is False, only nullify small elements
END_DOC END_DOC
integer, intent(in) :: m,n,LDA integer, intent(in) :: m,n,LDA
double precision, intent(inout) :: A(LDA,n) double precision, intent(inout) :: A(LDA,n)
double precision, intent(in) :: thresh double precision, intent(in) :: thresh
integer :: i,j,k,l
logical, allocatable :: done(:,:) double precision, allocatable :: copy(:), copy_sign(:)
double precision :: f, g, count, thresh2 integer, allocatable :: key(:), ii(:), jj(:)
integer :: sze, pi, pf, idx, i,j,k
double precision :: average, val, thresh2
thresh2 = dsqrt(thresh) thresh2 = dsqrt(thresh)
call nullify_small_elements(m,n,A,LDA,thresh)
! if (.not.restore_symm) then sze = m * n
! return
! endif
! TODO: Costs O(n^4), but can be improved to (2 n^2 * log(n)): allocate(copy(sze),copy_sign(sze),key(sze),ii(sze),jj(sze))
! - copy all values in a 1D array
! - sort 1D array
! - average nearby elements
! - for all elements, find matching value in the sorted 1D array
allocate(done(m,n)) ! Copy to 1D
!$OMP PARALLEL if (m>100) &
do j=1,n !$OMP SHARED(A,m,n,sze,copy_sign,copy,key,ii,jj) &
do i=1,m !$OMP PRIVATE(i,j,k) &
done(i,j) = A(i,j) == 0.d0 !$OMP DEFAULT(NONE)
!$OMP DO
do j = 1, n
do i = 1, m
k = i+(j-1)*m
copy(k) = A(i,j)
copy_sign(k) = sign(1.d0,copy(k))
copy(k) = -dabs(copy(k))
key(k) = k
ii(k) = i
jj(k) = j
enddo enddo
enddo enddo
!$OMP END DO
!$OMP END PARALLEL
do j=1,n ! Sort
do i=1,m call dsort(copy,key,sze)
if ( done(i,j) ) cycle call iset_order(ii,key,sze)
done(i,j) = .True. call iset_order(jj,key,sze)
count = 1.d0 call dset_order(copy_sign,key,sze)
f = 1.d0/A(i,j)
do l=1,n !TODO
do k=1,m ! Parallelization with OMP
if ( done(k,l) ) cycle
g = f * A(k,l) ! Symmetrize
if ( dabs(dabs(g) - 1.d0) < thresh2 ) then i = 1
count = count + 1.d0 do while (i < sze)
if (g>0.d0) then pi = i
A(i,j) = A(i,j) + A(k,l) pf = i
else
A(i,j) = A(i,j) - A(k,l) ! Exit if the remaining elements are below thresh
end if if (-copy(i) <= thresh) exit
endif
enddo val = 1d0/copy(i)
enddo do while (dabs(val * copy(pf+1) - 1d0) < thresh2)
if (count > 1.d0) then pf = pf + 1
A(i,j) = A(i,j) / count ! if pf == sze, copy(pf+1) will not be valid
do l=1,n if (pf == sze) then
do k=1,m exit
if ( done(k,l) ) cycle
g = f * A(k,l)
if ( dabs(dabs(g) - 1.d0) < thresh2 ) then
done(k,l) = .True.
if (g>0.d0) then
A(k,l) = A(i,j)
else
A(k,l) = -A(i,j)
end if
endif
enddo
enddo
endif endif
enddo enddo
! if pi and pf are different do the average from pi to pf
if (pf - pi > 0) then
average = 0d0
do j = pi, pf
average = average + copy(j)
enddo
average = average / (pf-pi+1.d0)
do j = pi, pf
copy(j) = average
enddo
! Update i
i = pf
endif
! Update i
i = i + 1
enddo enddo
copy(i:) = 0.d0
!$OMP PARALLEL if (sze>10000) &
!$OMP SHARED(m,sze,copy_sign,copy,key,A,ii,jj) &
!$OMP PRIVATE(i,j,k,idx) &
!$OMP DEFAULT(NONE)
! copy -> A
!$OMP DO
do k = 1, sze
i = ii(k)
j = jj(k)
A(i,j) = sign(copy(k),copy_sign(k))
enddo
!$OMP END DO
!$OMP END PARALLEL
deallocate(copy,copy_sign,key,ii,jj)
end end