program svdwf
  implicit none
  BEGIN_DOC
! Singular value decomposition of the alpha-beta coefficient matrix of the
! wave function. Prints the singular values and the leading singular vectors.
  END_DOC
  ! Set read_wf and TOUCH it so IRPF90 invalidates every provider that
  ! depends on it. Presumably this makes the stored wave function be read
  ! in -- confirm against the read_wf provider.
  read_wf = .true.
  TOUCH read_wf
  call run
end

subroutine run
  implicit none
  ! Randomized truncated SVD of the alpha-beta coefficient matrix A(m,n).
  ! A has m = n_det_alpha_unique rows and n = n_det_beta_unique columns.
  ! It is given as sparse (psi_bilinear_matrix_rows, _columns, _values)
  ! triplets; only state 1 is used.
  !
  !   1. Random start Z(m,r): each nonzero A(i,j) is weighted by an
  !      independent Gaussian number and accumulated into row i.
  !   2. n_power_iterations power iterations: Z <- A.(At.Z), with the
  !      columns re-orthonormalised by ortho_qr after each step.
  !   3. Yt(n,r) = At.Z. A small dense SVD gives Yt = V.D.UYt
  !      (QP svd convention A = U.D.Vt -- assumed from argument order).
  !   4. U(m,r) = Z.UYt^T, so that A ~ U.D.Vt.
  !
  ! Prints the singular values, a truncation threshold, and up to four
  ! left (U) and right (V) singular vectors.
  include 'constants.include.F'
  integer, parameter            :: svd_max_rank = 1000
  integer, parameter            :: n_power_iterations = 20
  double precision, allocatable :: U(:,:), V(:,:), D(:)
  double precision, allocatable :: Z(:,:), P(:,:), Yt(:,:), UYt(:,:), r1(:,:)
  integer                       :: i, j, k, l, r, m, n, iter, n_print_cols

  m = n_det_alpha_unique
  n = n_det_beta_unique

  ! The rank of A cannot exceed min(m,n). ortho_qr on Z(m,r) also needs
  ! r <= m (LAPACK dorgqr requires M >= N).
  r = min(svd_max_rank, m, n)
  if (r < 1) then
    print *, 'Empty wave function: nothing to decompose.'
    return
  endif

  allocate(Z(m,r), P(n,r))
  Z = 0.d0

  ! Random starting block. Each thread owns its private Gaussian buffer r1
  ! and writes only column l of Z, so there is no write conflict.
  !$OMP PARALLEL DEFAULT(SHARED) PRIVATE(i,k,l,r1)
  allocate(r1(N_det,2))
  !$OMP DO
  do l=1,r
    call random_number(r1)
    ! Box-Muller. random_number returns values in [0,1); 1-u lies in
    ! (0,1], so the logarithm stays finite.
    r1(:,1) = dsqrt(-2.d0*dlog(1.d0 - r1(:,1)))
    r1(:,1) = r1(:,1) * dcos(dtwo_pi*r1(:,2))
    do k=1,N_det
      i = psi_bilinear_matrix_rows(k)
      Z(i,l) = Z(i,l) + psi_bilinear_matrix_values(k,1) * r1(k,1)
    enddo
  enddo
  !$OMP END DO
  deallocate(r1)
  !$OMP END PARALLEL

  ! Power iterations
  do iter=1,n_power_iterations
    print *, 'Power iteration ', iter, '/', n_power_iterations

    !$OMP PARALLEL DEFAULT(SHARED) PRIVATE(i,j,k,l)
    ! P(n,r) = At(n,m).Z(m,r)
    !$OMP DO
    do l=1,r
      P(:,l) = 0.d0
      do k=1,N_det
        i = psi_bilinear_matrix_rows(k)
        j = psi_bilinear_matrix_columns(k)
        P(j,l) = P(j,l) + psi_bilinear_matrix_values(k,1) * Z(i,l)
      enddo
    enddo
    !$OMP END DO
    ! The implicit barrier at END DO guarantees P is complete before use.

    ! Z(m,r) = A(m,n).P(n,r)
    !$OMP DO
    do l=1,r
      Z(:,l) = 0.d0
      do k=1,N_det
        i = psi_bilinear_matrix_rows(k)
        j = psi_bilinear_matrix_columns(k)
        Z(i,l) = Z(i,l) + psi_bilinear_matrix_values(k,1) * P(j,l)
      enddo
    enddo
    !$OMP END DO
    !$OMP END PARALLEL

    ! Orthonormalise the columns of Z (QR) so they do not collapse onto
    ! the dominant singular vector.
    call ortho_qr(Z,size(Z,1),m,r)
  enddo
  deallocate(P)

  ! Y(r,n) = Zt(r,m).A(m,n), stored transposed as Yt(n,r)
  allocate(Yt(n,r))
  !$OMP PARALLEL DO DEFAULT(SHARED) PRIVATE(i,j,k,l)
  do l=1,r
    Yt(:,l) = 0.d0
    do k=1,N_det
      i = psi_bilinear_matrix_rows(k)
      j = psi_bilinear_matrix_columns(k)
      Yt(j,l) = Yt(j,l) + Z(i,l) * psi_bilinear_matrix_values(k,1)
    enddo
  enddo
  !$OMP END PARALLEL DO

  allocate(D(r), V(n,r), UYt(r,r))
  call svd(Yt,size(Yt,1),V,size(V,1),D,UYt,size(UYt,1),n,r)
  deallocate(Yt)

  ! U(m,r) = Z(m,r).UY(r,r), with UY = transpose(UYt)
  allocate(U(m,r))
  call dgemm('N','T',m,r,r,1.d0,Z,size(Z,1),UYt,size(UYt,1),0.d0,U,size(U,1))
  deallocate(UYt,Z)

  ! Each line holds: index, singular value, its square, and the cumulative
  ! sum of squares. k is the first index with a negligible singular value,
  ! or r if there is none.
  k = r
  do i=1,r
    print *, i, real(D(i)), real(D(i)**2), real(sum(D(1:i)**2))
    if (D(i) < 1.d-15) then
      k = i
      exit
    endif
  enddo
  ! D(k/2) approximates the median non-negligible singular value. The
  ! factor 2.858 presumably is the Gavish-Donoho hard-threshold constant
  ! for square matrices -- confirm.
  print *, 'threshold: ', 2.858d0 * D(max(1,k/2))

  ! Print at most four singular vectors (fewer when r < 4).
  n_print_cols = min(4, r)
  do i=1,m
    print '(I6,4(1X,F12.8))', i, U(i,1:n_print_cols)
  enddo
  print *, ''
  do i=1,n
    print '(I6,4(1X,F12.8))', i, V(i,1:n_print_cols)
  enddo

  deallocate(U,D,V)
end