mirror of
https://github.com/LCPQ/quantum_package
synced 2024-12-22 20:35:19 +01:00
Accelerated OCaml Psi messages
This commit is contained in:
parent
e8f35b59d4
commit
0ae7dfc224
@ -324,33 +324,28 @@ end
|
||||
|
||||
(** GetPsiReply_msg : Reply to the GetPsi message *)
|
||||
module GetPsiReply_msg : sig
|
||||
type t =
|
||||
{ client_id : Id.Client.t ;
|
||||
psi : Psi.t }
|
||||
val create : client_id:Id.Client.t -> psi:Psi.t -> t
|
||||
val to_string_list : t -> string list
|
||||
type t = string list
|
||||
val create : psi:Psi.t -> t
|
||||
val to_string : t -> string
|
||||
end = struct
|
||||
type t =
|
||||
{ client_id : Id.Client.t ;
|
||||
psi : Psi.t }
|
||||
let create ~client_id ~psi =
|
||||
{ client_id ; psi }
|
||||
let to_string x =
|
||||
type t = string list
|
||||
let create ~psi =
|
||||
let g, s =
|
||||
match x.psi.Psi.n_det_generators, x.psi.Psi.n_det_selectors with
|
||||
match psi.Psi.n_det_generators, psi.Psi.n_det_selectors with
|
||||
| Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s
|
||||
| _ -> -1, -1
|
||||
in
|
||||
Printf.sprintf "get_psi_reply %d %d %d %d %d %d"
|
||||
(Id.Client.to_int x.client_id)
|
||||
(Strictly_positive_int.to_int x.psi.Psi.n_state)
|
||||
(Strictly_positive_int.to_int x.psi.Psi.n_det)
|
||||
(Strictly_positive_int.to_int x.psi.Psi.psi_det_size)
|
||||
g s
|
||||
let to_string_list x =
|
||||
[ to_string x ;
|
||||
x.psi.Psi.psi_det ; x.psi.Psi.psi_coef ; x.psi.Psi.energy ]
|
||||
let head =
|
||||
Printf.sprintf "get_psi_reply %d %d %d %d %d"
|
||||
(Strictly_positive_int.to_int psi.Psi.n_state)
|
||||
(Strictly_positive_int.to_int psi.Psi.n_det)
|
||||
(Strictly_positive_int.to_int psi.Psi.psi_det_size)
|
||||
g s
|
||||
in
|
||||
[ head ; psi.Psi.psi_det ; psi.Psi.psi_coef ; psi.Psi.energy ]
|
||||
let to_string = function
|
||||
| head :: _ :: _ :: _ :: [] -> head
|
||||
| _ -> raise (Invalid_argument "Bad wave function message")
|
||||
end
|
||||
|
||||
|
||||
@ -759,7 +754,6 @@ let to_string = function
|
||||
|
||||
let to_string_list = function
|
||||
| PutPsi x -> PutPsi_msg.to_string_list x
|
||||
| GetPsiReply x -> GetPsiReply_msg.to_string_list x
|
||||
| PutVector x -> PutVector_msg.to_string_list x
|
||||
| GetVectorReply x -> GetVectorReply_msg.to_string_list x
|
||||
| _ -> assert false
|
||||
|
@ -25,7 +25,7 @@ type t =
|
||||
state : Message.State.t option ;
|
||||
address_tcp : Address.Tcp.t option ;
|
||||
address_inproc : Address.Inproc.t option ;
|
||||
psi : Message.Psi.t option;
|
||||
psi : Message.GetPsiReply_msg.t option;
|
||||
vector : Message.Vector.t option;
|
||||
progress_bar : Progress_bar.t option ;
|
||||
running : bool;
|
||||
@ -483,7 +483,7 @@ let put_psi msg rest_of_msg program_state rep_socket =
|
||||
in
|
||||
let new_program_state =
|
||||
{ program_state with
|
||||
psi = Some psi_local
|
||||
psi = Some (Message.GetPsiReply_msg.create ~psi:psi_local)
|
||||
}
|
||||
and client_id =
|
||||
msg.Message.PutPsi_msg.client_id
|
||||
@ -496,17 +496,12 @@ let put_psi msg rest_of_msg program_state rep_socket =
|
||||
|
||||
|
||||
let get_psi msg program_state rep_socket =
|
||||
|
||||
let client_id =
|
||||
msg.Message.GetPsi_msg.client_id
|
||||
in
|
||||
match program_state.psi with
|
||||
| None -> failwith "No wave function saved in TaskServer"
|
||||
| Some psi ->
|
||||
Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi)
|
||||
|> Message.to_string_list
|
||||
|> ZMQ.Socket.send_all rep_socket;
|
||||
program_state
|
||||
begin
|
||||
match program_state.psi with
|
||||
| None -> failwith "No wave function saved in TaskServer"
|
||||
| Some psi_message -> ZMQ.Socket.send_all rep_socket psi_message
|
||||
end;
|
||||
program_state
|
||||
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ type t =
|
||||
state : Message.State.t option ;
|
||||
address_tcp : Address.Tcp.t option ;
|
||||
address_inproc : Address.Inproc.t option ;
|
||||
psi : Message.Psi.t option;
|
||||
psi : Message.GetPsiReply_msg.t option;
|
||||
vector : Message.Vector.t option ;
|
||||
progress_bar : Progress_bar.t option ;
|
||||
running : bool;
|
||||
|
@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
|
||||
|
||||
integer :: N_states_read, N_det_read, psi_det_size_read
|
||||
integer :: N_det_selectors_read, N_det_generators_read
|
||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
||||
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||
N_det_generators_read, N_det_selectors_read
|
||||
if (rc /= worker_id) then
|
||||
print *, 'Wrong worker ID'
|
||||
stop 'error'
|
||||
endif
|
||||
|
||||
N_states = N_states_read
|
||||
N_det = N_det_read
|
||||
|
@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
|
||||
|
||||
integer :: N_states_read, N_det_read, psi_det_size_read
|
||||
integer :: N_det_selectors_read, N_det_generators_read
|
||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
||||
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||
N_det_generators_read, N_det_selectors_read
|
||||
if (rc /= worker_id) then
|
||||
print *, 'Wrong worker ID'
|
||||
stop 'error'
|
||||
endif
|
||||
|
||||
N_states = N_states_read
|
||||
N_det = N_det_read
|
||||
|
@ -90,14 +90,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
|
||||
stop 'error'
|
||||
endif
|
||||
|
||||
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
|
||||
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
|
||||
N_det_generators_read, N_det_selectors_read
|
||||
|
||||
if (rc /= worker_id) then
|
||||
print *, 'Wrong worker ID'
|
||||
stop 'error'
|
||||
endif
|
||||
|
||||
if (N_states_read /= N_st) then
|
||||
print *, N_st
|
||||
stop 'error : N_st'
|
||||
|
Loading…
Reference in New Issue
Block a user