mirror of
https://github.com/LCPQ/quantum_package
synced 2024-12-22 20:35:19 +01:00
Accelerated OCaml Psi messages
This commit is contained in:
parent
e8f35b59d4
commit
0ae7dfc224
@ -324,33 +324,28 @@ end
|
|||||||
|
|
||||||
(** GetPsiReply_msg : Reply to the GetPsi message *)
|
(** GetPsiReply_msg : Reply to the GetPsi message *)
|
||||||
module GetPsiReply_msg : sig
|
module GetPsiReply_msg : sig
|
||||||
type t =
|
type t = string list
|
||||||
{ client_id : Id.Client.t ;
|
val create : psi:Psi.t -> t
|
||||||
psi : Psi.t }
|
|
||||||
val create : client_id:Id.Client.t -> psi:Psi.t -> t
|
|
||||||
val to_string_list : t -> string list
|
|
||||||
val to_string : t -> string
|
val to_string : t -> string
|
||||||
end = struct
|
end = struct
|
||||||
type t =
|
type t = string list
|
||||||
{ client_id : Id.Client.t ;
|
let create ~psi =
|
||||||
psi : Psi.t }
|
|
||||||
let create ~client_id ~psi =
|
|
||||||
{ client_id ; psi }
|
|
||||||
let to_string x =
|
|
||||||
let g, s =
|
let g, s =
|
||||||
match x.psi.Psi.n_det_generators, x.psi.Psi.n_det_selectors with
|
match psi.Psi.n_det_generators, psi.Psi.n_det_selectors with
|
||||||
| Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s
|
| Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s
|
||||||
| _ -> -1, -1
|
| _ -> -1, -1
|
||||||
in
|
in
|
||||||
Printf.sprintf "get_psi_reply %d %d %d %d %d %d"
|
let head =
|
||||||
(Id.Client.to_int x.client_id)
|
Printf.sprintf "get_psi_reply %d %d %d %d %d"
|
||||||
(Strictly_positive_int.to_int x.psi.Psi.n_state)
|
(Strictly_positive_int.to_int psi.Psi.n_state)
|
||||||
(Strictly_positive_int.to_int x.psi.Psi.n_det)
|
(Strictly_positive_int.to_int psi.Psi.n_det)
|
||||||
(Strictly_positive_int.to_int x.psi.Psi.psi_det_size)
|
(Strictly_positive_int.to_int psi.Psi.psi_det_size)
|
||||||
g s
|
g s
|
||||||
let to_string_list x =
|
in
|
||||||
[ to_string x ;
|
[ head ; psi.Psi.psi_det ; psi.Psi.psi_coef ; psi.Psi.energy ]
|
||||||
x.psi.Psi.psi_det ; x.psi.Psi.psi_coef ; x.psi.Psi.energy ]
|
let to_string = function
|
||||||
|
| head :: _ :: _ :: _ :: [] -> head
|
||||||
|
| _ -> raise (Invalid_argument "Bad wave function message")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@ -759,7 +754,6 @@ let to_string = function
|
|||||||
|
|
||||||
let to_string_list = function
|
let to_string_list = function
|
||||||
| PutPsi x -> PutPsi_msg.to_string_list x
|
| PutPsi x -> PutPsi_msg.to_string_list x
|
||||||
| GetPsiReply x -> GetPsiReply_msg.to_string_list x
|
|
||||||
| PutVector x -> PutVector_msg.to_string_list x
|
| PutVector x -> PutVector_msg.to_string_list x
|
||||||
| GetVectorReply x -> GetVectorReply_msg.to_string_list x
|
| GetVectorReply x -> GetVectorReply_msg.to_string_list x
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
|
@ -25,7 +25,7 @@ type t =
|
|||||||
state : Message.State.t option ;
|
state : Message.State.t option ;
|
||||||
address_tcp : Address.Tcp.t option ;
|
address_tcp : Address.Tcp.t option ;
|
||||||
address_inproc : Address.Inproc.t option ;
|
address_inproc : Address.Inproc.t option ;
|
||||||
psi : Message.Psi.t option;
|
psi : Message.GetPsiReply_msg.t option;
|
||||||
vector : Message.Vector.t option;
|
vector : Message.Vector.t option;
|
||||||
progress_bar : Progress_bar.t option ;
|
progress_bar : Progress_bar.t option ;
|
||||||
running : bool;
|
running : bool;
|
||||||
@ -483,7 +483,7 @@ let put_psi msg rest_of_msg program_state rep_socket =
|
|||||||
in
|
in
|
||||||
let new_program_state =
|
let new_program_state =
|
||||||
{ program_state with
|
{ program_state with
|
||||||
psi = Some psi_local
|
psi = Some (Message.GetPsiReply_msg.create ~psi:psi_local)
|
||||||
}
|
}
|
||||||
and client_id =
|
and client_id =
|
||||||
msg.Message.PutPsi_msg.client_id
|
msg.Message.PutPsi_msg.client_id
|
||||||
@ -496,17 +496,12 @@ let put_psi msg rest_of_msg program_state rep_socket =
|
|||||||
|
|
||||||
|
|
||||||
let get_psi msg program_state rep_socket =
|
let get_psi msg program_state rep_socket =
|
||||||
|
begin
|
||||||
let client_id =
|
match program_state.psi with
|
||||||
msg.Message.GetPsi_msg.client_id
|
| None -> failwith "No wave function saved in TaskServer"
|
||||||
in
|
| Some psi_message -> ZMQ.Socket.send_all rep_socket psi_message
|
||||||
match program_state.psi with
|
end;
|
||||||
| None -> failwith "No wave function saved in TaskServer"
|
program_state
|
||||||
| Some psi ->
|
|
||||||
Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi)
|
|
||||||
|> Message.to_string_list
|
|
||||||
|> ZMQ.Socket.send_all rep_socket;
|
|
||||||
program_state
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ type t =
|
|||||||
state : Message.State.t option ;
|
state : Message.State.t option ;
|
||||||
address_tcp : Address.Tcp.t option ;
|
address_tcp : Address.Tcp.t option ;
|
||||||
address_inproc : Address.Inproc.t option ;
|
address_inproc : Address.Inproc.t option ;
|
||||||
psi : Message.Psi.t option;
|
psi : Message.GetPsiReply_msg.t option;
|
||||||
vector : Message.Vector.t option ;
|
vector : Message.Vector.t option ;
|
||||||
progress_bar : Progress_bar.t option ;
|
progress_bar : Progress_bar.t option ;
|
||||||
running : bool;
|
running : bool;
|
||||||
|
@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
|
|||||||
|
|
||||||
integer :: N_states_read, N_det_read, psi_det_size_read
|
integer :: N_states_read, N_det_read, psi_det_size_read
|
||||||
integer :: N_det_selectors_read, N_det_generators_read
|
integer :: N_det_selectors_read, N_det_generators_read
|
||||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||||
N_det_generators_read, N_det_selectors_read
|
N_det_generators_read, N_det_selectors_read
|
||||||
if (rc /= worker_id) then
|
|
||||||
print *, 'Wrong worker ID'
|
|
||||||
stop 'error'
|
|
||||||
endif
|
|
||||||
|
|
||||||
N_states = N_states_read
|
N_states = N_states_read
|
||||||
N_det = N_det_read
|
N_det = N_det_read
|
||||||
|
@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
|
|||||||
|
|
||||||
integer :: N_states_read, N_det_read, psi_det_size_read
|
integer :: N_states_read, N_det_read, psi_det_size_read
|
||||||
integer :: N_det_selectors_read, N_det_generators_read
|
integer :: N_det_selectors_read, N_det_generators_read
|
||||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||||
N_det_generators_read, N_det_selectors_read
|
N_det_generators_read, N_det_selectors_read
|
||||||
if (rc /= worker_id) then
|
|
||||||
print *, 'Wrong worker ID'
|
|
||||||
stop 'error'
|
|
||||||
endif
|
|
||||||
|
|
||||||
N_states = N_states_read
|
N_states = N_states_read
|
||||||
N_det = N_det_read
|
N_det = N_det_read
|
||||||
|
@ -90,14 +90,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
|
|||||||
stop 'error'
|
stop 'error'
|
||||||
endif
|
endif
|
||||||
|
|
||||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||||
N_det_generators_read, N_det_selectors_read
|
N_det_generators_read, N_det_selectors_read
|
||||||
|
|
||||||
if (rc /= worker_id) then
|
|
||||||
print *, 'Wrong worker ID'
|
|
||||||
stop 'error'
|
|
||||||
endif
|
|
||||||
|
|
||||||
if (N_states_read /= N_st) then
|
if (N_states_read /= N_st) then
|
||||||
print *, N_st
|
print *, N_st
|
||||||
stop 'error : N_st'
|
stop 'error : N_st'
|
||||||
|
Loading…
Reference in New Issue
Block a user