9
1
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-12-30 15:15:38 +01:00

Fixing Davidson

This commit is contained in:
Anthony Scemama 2019-01-26 19:10:39 +01:00
parent f1e14f0851
commit 318b1af239
7 changed files with 27 additions and 93 deletions

View File

@ -161,10 +161,6 @@ subroutine run_slave_main
call mpi_print('zmq_get_psi') call mpi_print('zmq_get_psi')
IRP_ENDIF IRP_ENDIF
if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle
IRP_IF MPI_DEBUG
call mpi_print('zmq_get_dvector energy')
IRP_ENDIF
if (zmq_get_dvector(zmq_to_qp_run_socket,1,'energy',energy,N_states_diag) == -1) cycle
call wall_time(t1) call wall_time(t1)
call write_double(6,(t1-t0),'Broadcast time') call write_double(6,(t1-t0),'Broadcast time')

View File

@ -13,7 +13,7 @@ interface: ezfio,ocaml
[davidson_sze_max] [davidson_sze_max]
type: Strictly_positive_int type: Strictly_positive_int
doc: Number of micro-iterations before re-contracting doc: Number of micro-iterations before re-contracting
default: 8 default: 15
interface: ezfio,provider,ocaml interface: ezfio,provider,ocaml
[state_following] [state_following]

View File

@ -36,7 +36,6 @@ subroutine davidson_run_slave(thread,iproc)
integer(ZMQ_PTR) :: zmq_socket_push integer(ZMQ_PTR) :: zmq_socket_push
integer, external :: connect_to_taskserver integer, external :: connect_to_taskserver
integer, external :: zmq_get_N_states_diag_notouch
PROVIDE mpi_rank PROVIDE mpi_rank
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket() zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
@ -47,14 +46,6 @@ subroutine davidson_run_slave(thread,iproc)
if (connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) == -1) then if (connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) == -1) then
return return
endif endif
if (zmq_get_N_states_diag_notouch(zmq_to_qp_run_socket,1) == -1) then
if (disconnect_from_taskserver(zmq_to_qp_run_socket,worker_id) == -1) then
continue
endif
return
endif
! SOFT_TOUCH N_states_diag
call davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_states_diag, N_det, worker_id) call davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_states_diag, N_det, worker_id)
@ -118,8 +109,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
endif endif
do while (zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, 'u_t', u_t, ni, nj, size(u_t,kind=8)) == -1) do while (zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, 'u_t', u_t, ni, nj, size(u_t,kind=8)) == -1)
call sleep(1) print *, 'mpi_rank, N_states_diag, N_det'
print *, irp_here, ': waiting for u_t...' print *, mpi_rank, N_states_diag, N_det
stop 'u_t'
enddo enddo
IRP_IF MPI IRP_IF MPI
@ -324,9 +316,9 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
call new_parallel_job(zmq_to_qp_run_socket,zmq_socket_pull,'davidson') call new_parallel_job(zmq_to_qp_run_socket,zmq_socket_pull,'davidson')
integer :: N_states_diag_save ! integer :: N_states_diag_save
N_states_diag_save = N_states_diag ! N_states_diag_save = N_states_diag
N_states_diag = N_st ! N_states_diag = N_st
if (zmq_put_N_states_diag(zmq_to_qp_run_socket, 1) == -1) then if (zmq_put_N_states_diag(zmq_to_qp_run_socket, 1) == -1) then
stop 'Unable to put N_states_diag on ZMQ server' stop 'Unable to put N_states_diag on ZMQ server'
endif endif
@ -445,8 +437,8 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
!$OMP TASKWAIT !$OMP TASKWAIT
!$OMP END PARALLEL !$OMP END PARALLEL
N_states_diag = N_states_diag_save ! N_states_diag = N_states_diag_save
SOFT_TOUCH N_states_diag ! SOFT_TOUCH N_states_diag
end end
@ -560,62 +552,3 @@ integer function zmq_get_N_states_diag(zmq_to_qp_run_socket, worker_id)
IRP_ENDIF IRP_ENDIF
end end
integer function zmq_get_N_states_diag_notouch(zmq_to_qp_run_socket, worker_id)
use f77_zmq
implicit none
BEGIN_DOC
! Get N_states_diag from the qp_run scheduler
END_DOC
integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket
integer, intent(in) :: worker_id
integer :: rc
character*(256) :: msg
zmq_get_N_states_diag_notouch = 0
if (mpi_master) then
write(msg,'(A,1X,I8,1X,A200)') 'get_data '//trim(zmq_state), worker_id, 'N_states_diag'
rc = f77_zmq_send(zmq_to_qp_run_socket,trim(msg),len(trim(msg)),0)
if (rc /= len(trim(msg))) go to 10
rc = f77_zmq_recv(zmq_to_qp_run_socket,msg,len(msg),0)
if (msg(1:14) /= 'get_data_reply') go to 10
rc = f77_zmq_recv(zmq_to_qp_run_socket,N_states_diag,4,0)
if (rc /= 4) go to 10
endif
IRP_IF MPI_DEBUG
print *, irp_here, mpi_rank
call MPI_BARRIER(MPI_COMM_WORLD, ierr)
IRP_ENDIF
IRP_IF MPI
include 'mpif.h'
integer :: ierr
call MPI_BCAST (zmq_get_N_states_diag_notouch, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
if (zmq_get_N_states_diag_notouch == 0) then
call MPI_BCAST (N_states_diag, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
endif
IRP_ENDIF
return
! Exception
10 continue
zmq_get_N_states_diag_notouch = -1
IRP_IF MPI
call MPI_BCAST (zmq_get_N_states_diag_notouch, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
IRP_ENDIF
end

View File

@ -202,9 +202,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit exit
endif endif
if (N_st_diag > 2*N_states) then if (itermax > 4) then
N_st_diag = N_st_diag-1
else if (itermax > 4) then
itermax = itermax - 1 itermax = itermax - 1
else if (m==1.and.disk_based_davidson) then else if (m==1.and.disk_based_davidson) then
m=0 m=0
@ -394,6 +392,8 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
! call dgemm('T','N', shift2, shift2, sze, & ! call dgemm('T','N', shift2, shift2, sze, &
! 1.d0, U, size(U,1), S, size(S,1), & ! 1.d0, U, size(U,1), S, size(S,1), &
! 0.d0, s_, size(s_,1)) ! 0.d0, s_, size(s_,1))
!$OMP PARALLEL DO DEFAULT(SHARED) PRIVATE(i,j,k)
do j=1,shift2 do j=1,shift2
do i=1,shift2 do i=1,shift2
s_(i,j) = 0.d0 s_(i,j) = 0.d0
@ -402,6 +402,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo enddo
enddo enddo
enddo enddo
!$OMP END PARALLEL DO
! Compute h_kl = <u_k | W_l> = <u_k| H |u_l> ! Compute h_kl = <u_k | W_l> = <u_k| H |u_l>
! ------------------------------------------- ! -------------------------------------------

View File

@ -45,17 +45,17 @@ subroutine u_0_H_u_0(e_0,s_0,u_0,n,keys_tmp,Nint,N_st,sze)
if ((n > 100000).and.distributed_davidson) then if ((n > 100000).and.distributed_davidson) then
allocate (v_0(n,N_states_diag),s_vec(n,N_states_diag), u_1(n,N_states_diag)) allocate (v_0(n,N_states_diag),s_vec(n,N_states_diag), u_1(n,N_states_diag))
u_1(1:n,1:N_states) = u_0(1:n,1:N_states) u_1(:,:) = 0.d0
u_1(1:n,N_states+1:N_states_diag) = 0.d0 u_1(1:n,1:N_st) = u_0(1:n,1:N_st)
call H_S2_u_0_nstates_zmq(v_0,s_vec,u_1,N_st,n) call H_S2_u_0_nstates_zmq(v_0,s_vec,u_1,N_states_diag,n)
deallocate(u_1)
else else
allocate (v_0(n,N_st),s_vec(n,N_st),u_1(n,N_st)) allocate (v_0(n,N_st),s_vec(n,N_st),u_1(n,N_st))
u_1(1:n,:) = u_0(1:n,:) u_1(:,:) = 0.d0
u_1(1:n,1:N_st) = u_0(1:n,1:N_st)
call H_S2_u_0_nstates_openmp(v_0,s_vec,u_1,N_st,n) call H_S2_u_0_nstates_openmp(v_0,s_vec,u_1,N_st,n)
u_0(1:n,:) = u_1(1:n,:)
deallocate(u_1)
endif endif
u_0(1:n,1:N_st) = u_1(1:n,1:N_st)
deallocate(u_1)
double precision :: norm double precision :: norm
!$OMP PARALLEL DO PRIVATE(i,norm) DEFAULT(SHARED) !$OMP PARALLEL DO PRIVATE(i,norm) DEFAULT(SHARED)
do i=1,N_st do i=1,N_st

View File

@ -397,7 +397,9 @@ integer function zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, name, x, size_
if (rc /= len(trim(msg))) then if (rc /= len(trim(msg))) then
print *, trim(msg) print *, trim(msg)
zmq_get_dmatrix = -1 zmq_get_dmatrix = -1
print *, irp_here, 'rc /= len(trim(msg))', rc, len(trim(msg)) print *, irp_here, 'rc /= len(trim(msg))'
print *, irp_here, ' received : ', rc
print *, irp_here, ' expected : ', len(trim(msg))
go to 10 go to 10
endif endif
@ -411,7 +413,9 @@ integer function zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, name, x, size_
rc = f77_zmq_recv8(zmq_to_qp_run_socket,x(1,j),ni*8_8,0) rc = f77_zmq_recv8(zmq_to_qp_run_socket,x(1,j),ni*8_8,0)
if (rc /= ni*8_8) then if (rc /= ni*8_8) then
print *, irp_here, 'rc /= size_x1*8', rc, ni*8_8 print *, irp_here, 'rc /= size_x1*8 : ', trim(name)
print *, irp_here, ' received: ', rc
print *, irp_here, ' expected: ', ni*8_8
zmq_get_dmatrix = -1 zmq_get_dmatrix = -1
go to 10 go to 10
endif endif

View File

@ -681,7 +681,7 @@ integer function connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread)
return return
10 continue 10 continue
print *, irp_here//': '//trim(message) ! print *, irp_here//': '//trim(message)
connect_to_taskserver = -1 connect_to_taskserver = -1
end end