Fixing Davidson

This commit is contained in:
Anthony Scemama 2019-01-26 19:10:39 +01:00
parent f1e14f0851
commit 318b1af239
7 changed files with 27 additions and 93 deletions

View File

@ -161,10 +161,6 @@ subroutine run_slave_main
call mpi_print('zmq_get_psi')
IRP_ENDIF
if (zmq_get_psi(zmq_to_qp_run_socket,1) == -1) cycle
IRP_IF MPI_DEBUG
call mpi_print('zmq_get_dvector energy')
IRP_ENDIF
if (zmq_get_dvector(zmq_to_qp_run_socket,1,'energy',energy,N_states_diag) == -1) cycle
call wall_time(t1)
call write_double(6,(t1-t0),'Broadcast time')

View File

@ -13,7 +13,7 @@ interface: ezfio,ocaml
[davidson_sze_max]
type: Strictly_positive_int
doc: Number of micro-iterations before re-contracting
default: 8
default: 15
interface: ezfio,provider,ocaml
[state_following]

View File

@ -36,7 +36,6 @@ subroutine davidson_run_slave(thread,iproc)
integer(ZMQ_PTR) :: zmq_socket_push
integer, external :: connect_to_taskserver
integer, external :: zmq_get_N_states_diag_notouch
PROVIDE mpi_rank
zmq_to_qp_run_socket = new_zmq_to_qp_run_socket()
@ -47,14 +46,6 @@ subroutine davidson_run_slave(thread,iproc)
if (connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread) == -1) then
return
endif
if (zmq_get_N_states_diag_notouch(zmq_to_qp_run_socket,1) == -1) then
if (disconnect_from_taskserver(zmq_to_qp_run_socket,worker_id) == -1) then
continue
endif
return
endif
! SOFT_TOUCH N_states_diag
call davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_states_diag, N_det, worker_id)
@ -118,8 +109,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
endif
do while (zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, 'u_t', u_t, ni, nj, size(u_t,kind=8)) == -1)
call sleep(1)
print *, irp_here, ': waiting for u_t...'
print *, 'mpi_rank, N_states_diag, N_det'
print *, mpi_rank, N_states_diag, N_det
stop 'u_t'
enddo
IRP_IF MPI
@ -324,9 +316,9 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
call new_parallel_job(zmq_to_qp_run_socket,zmq_socket_pull,'davidson')
integer :: N_states_diag_save
N_states_diag_save = N_states_diag
N_states_diag = N_st
! integer :: N_states_diag_save
! N_states_diag_save = N_states_diag
! N_states_diag = N_st
if (zmq_put_N_states_diag(zmq_to_qp_run_socket, 1) == -1) then
stop 'Unable to put N_states_diag on ZMQ server'
endif
@ -445,8 +437,8 @@ subroutine H_S2_u_0_nstates_zmq(v_0,s_0,u_0,N_st,sze)
!$OMP TASKWAIT
!$OMP END PARALLEL
N_states_diag = N_states_diag_save
SOFT_TOUCH N_states_diag
! N_states_diag = N_states_diag_save
! SOFT_TOUCH N_states_diag
end
@ -560,62 +552,3 @@ integer function zmq_get_N_states_diag(zmq_to_qp_run_socket, worker_id)
IRP_ENDIF
end
integer function zmq_get_N_states_diag_notouch(zmq_to_qp_run_socket, worker_id)
use f77_zmq
implicit none
BEGIN_DOC
! Get N_states_diag from the qp_run scheduler
END_DOC
integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket
integer, intent(in) :: worker_id
integer :: rc
character*(256) :: msg
zmq_get_N_states_diag_notouch = 0
if (mpi_master) then
write(msg,'(A,1X,I8,1X,A200)') 'get_data '//trim(zmq_state), worker_id, 'N_states_diag'
rc = f77_zmq_send(zmq_to_qp_run_socket,trim(msg),len(trim(msg)),0)
if (rc /= len(trim(msg))) go to 10
rc = f77_zmq_recv(zmq_to_qp_run_socket,msg,len(msg),0)
if (msg(1:14) /= 'get_data_reply') go to 10
rc = f77_zmq_recv(zmq_to_qp_run_socket,N_states_diag,4,0)
if (rc /= 4) go to 10
endif
IRP_IF MPI_DEBUG
print *, irp_here, mpi_rank
call MPI_BARRIER(MPI_COMM_WORLD, ierr)
IRP_ENDIF
IRP_IF MPI
include 'mpif.h'
integer :: ierr
call MPI_BCAST (zmq_get_N_states_diag_notouch, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
if (zmq_get_N_states_diag_notouch == 0) then
call MPI_BCAST (N_states_diag, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
endif
IRP_ENDIF
return
! Exception
10 continue
zmq_get_N_states_diag_notouch = -1
IRP_IF MPI
call MPI_BCAST (zmq_get_N_states_diag_notouch, 1, MPI_INTEGER, 0, MPI_COMM_WORLD, ierr)
if (ierr /= MPI_SUCCESS) then
print *, irp_here//': Unable to broadcast N_states'
stop -1
endif
IRP_ENDIF
end

View File

@ -202,9 +202,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
exit
endif
if (N_st_diag > 2*N_states) then
N_st_diag = N_st_diag-1
else if (itermax > 4) then
if (itermax > 4) then
itermax = itermax - 1
else if (m==1.and.disk_based_davidson) then
m=0
@ -394,6 +392,8 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
! call dgemm('T','N', shift2, shift2, sze, &
! 1.d0, U, size(U,1), S, size(S,1), &
! 0.d0, s_, size(s_,1))
!$OMP PARALLEL DO DEFAULT(SHARED) PRIVATE(i,j,k)
do j=1,shift2
do i=1,shift2
s_(i,j) = 0.d0
@ -402,6 +402,7 @@ subroutine davidson_diag_hjj_sjj(dets_in,u_in,H_jj,s2_out,energies,dim_in,sze,N_
enddo
enddo
enddo
!$OMP END PARALLEL DO
! Compute h_kl = <u_k | W_l> = <u_k| H |u_l>
! -------------------------------------------

View File

@ -45,17 +45,17 @@ subroutine u_0_H_u_0(e_0,s_0,u_0,n,keys_tmp,Nint,N_st,sze)
if ((n > 100000).and.distributed_davidson) then
allocate (v_0(n,N_states_diag),s_vec(n,N_states_diag), u_1(n,N_states_diag))
u_1(1:n,1:N_states) = u_0(1:n,1:N_states)
u_1(1:n,N_states+1:N_states_diag) = 0.d0
call H_S2_u_0_nstates_zmq(v_0,s_vec,u_1,N_st,n)
deallocate(u_1)
u_1(:,:) = 0.d0
u_1(1:n,1:N_st) = u_0(1:n,1:N_st)
call H_S2_u_0_nstates_zmq(v_0,s_vec,u_1,N_states_diag,n)
else
allocate (v_0(n,N_st),s_vec(n,N_st),u_1(n,N_st))
u_1(1:n,:) = u_0(1:n,:)
u_1(:,:) = 0.d0
u_1(1:n,1:N_st) = u_0(1:n,1:N_st)
call H_S2_u_0_nstates_openmp(v_0,s_vec,u_1,N_st,n)
u_0(1:n,:) = u_1(1:n,:)
deallocate(u_1)
endif
u_0(1:n,1:N_st) = u_1(1:n,1:N_st)
deallocate(u_1)
double precision :: norm
!$OMP PARALLEL DO PRIVATE(i,norm) DEFAULT(SHARED)
do i=1,N_st

View File

@ -397,7 +397,9 @@ integer function zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, name, x, size_
if (rc /= len(trim(msg))) then
print *, trim(msg)
zmq_get_dmatrix = -1
print *, irp_here, 'rc /= len(trim(msg))', rc, len(trim(msg))
print *, irp_here, 'rc /= len(trim(msg))'
print *, irp_here, ' received : ', rc
print *, irp_here, ' expected : ', len(trim(msg))
go to 10
endif
@ -411,7 +413,9 @@ integer function zmq_get_dmatrix(zmq_to_qp_run_socket, worker_id, name, x, size_
rc = f77_zmq_recv8(zmq_to_qp_run_socket,x(1,j),ni*8_8,0)
if (rc /= ni*8_8) then
print *, irp_here, 'rc /= size_x1*8', rc, ni*8_8
print *, irp_here, 'rc /= size_x1*8 : ', trim(name)
print *, irp_here, ' received: ', rc
print *, irp_here, ' expected: ', ni*8_8
zmq_get_dmatrix = -1
go to 10
endif

View File

@ -681,7 +681,7 @@ integer function connect_to_taskserver(zmq_to_qp_run_socket,worker_id,thread)
return
10 continue
print *, irp_here//': '//trim(message)
! print *, irp_here//': '//trim(message)
connect_to_taskserver = -1
end