10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-06-26 15:12:14 +02:00

Accelerated OCaml Psi messages

This commit is contained in:
Anthony Scemama 2017-05-18 18:53:55 +02:00
parent e8f35b59d4
commit 0ae7dfc224
6 changed files with 28 additions and 52 deletions

View File

@ -324,33 +324,28 @@ end
(** GetPsiReply_msg : Reply to the GetPsi message *)
module GetPsiReply_msg : sig
type t =
{ client_id : Id.Client.t ;
psi : Psi.t }
val create : client_id:Id.Client.t -> psi:Psi.t -> t
val to_string_list : t -> string list
type t = string list
val create : psi:Psi.t -> t
val to_string : t -> string
end = struct
type t =
{ client_id : Id.Client.t ;
psi : Psi.t }
let create ~client_id ~psi =
{ client_id ; psi }
let to_string x =
type t = string list
let create ~psi =
let g, s =
match x.psi.Psi.n_det_generators, x.psi.Psi.n_det_selectors with
match psi.Psi.n_det_generators, psi.Psi.n_det_selectors with
| Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s
| _ -> -1, -1
in
Printf.sprintf "get_psi_reply %d %d %d %d %d %d"
(Id.Client.to_int x.client_id)
(Strictly_positive_int.to_int x.psi.Psi.n_state)
(Strictly_positive_int.to_int x.psi.Psi.n_det)
(Strictly_positive_int.to_int x.psi.Psi.psi_det_size)
g s
let to_string_list x =
[ to_string x ;
x.psi.Psi.psi_det ; x.psi.Psi.psi_coef ; x.psi.Psi.energy ]
let head =
Printf.sprintf "get_psi_reply %d %d %d %d %d"
(Strictly_positive_int.to_int psi.Psi.n_state)
(Strictly_positive_int.to_int psi.Psi.n_det)
(Strictly_positive_int.to_int psi.Psi.psi_det_size)
g s
in
[ head ; psi.Psi.psi_det ; psi.Psi.psi_coef ; psi.Psi.energy ]
let to_string = function
| head :: _ :: _ :: _ :: [] -> head
| _ -> raise (Invalid_argument "Bad wave function message")
end
@ -759,7 +754,6 @@ let to_string = function
let to_string_list = function
| PutPsi x -> PutPsi_msg.to_string_list x
| GetPsiReply x -> GetPsiReply_msg.to_string_list x
| PutVector x -> PutVector_msg.to_string_list x
| GetVectorReply x -> GetVectorReply_msg.to_string_list x
| _ -> assert false

View File

@ -25,7 +25,7 @@ type t =
state : Message.State.t option ;
address_tcp : Address.Tcp.t option ;
address_inproc : Address.Inproc.t option ;
psi : Message.Psi.t option;
psi : Message.GetPsiReply_msg.t option;
vector : Message.Vector.t option;
progress_bar : Progress_bar.t option ;
running : bool;
@ -483,7 +483,7 @@ let put_psi msg rest_of_msg program_state rep_socket =
in
let new_program_state =
{ program_state with
psi = Some psi_local
psi = Some (Message.GetPsiReply_msg.create ~psi:psi_local)
}
and client_id =
msg.Message.PutPsi_msg.client_id
@ -496,17 +496,12 @@ let put_psi msg rest_of_msg program_state rep_socket =
let get_psi msg program_state rep_socket =
let client_id =
msg.Message.GetPsi_msg.client_id
in
match program_state.psi with
| None -> failwith "No wave function saved in TaskServer"
| Some psi ->
Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi)
|> Message.to_string_list
|> ZMQ.Socket.send_all rep_socket;
program_state
begin
match program_state.psi with
| None -> failwith "No wave function saved in TaskServer"
| Some psi_message -> ZMQ.Socket.send_all rep_socket psi_message
end;
program_state

View File

@ -4,7 +4,7 @@ type t =
state : Message.State.t option ;
address_tcp : Address.Tcp.t option ;
address_inproc : Address.Inproc.t option ;
psi : Message.Psi.t option;
psi : Message.GetPsiReply_msg.t option;
vector : Message.Vector.t option ;
progress_bar : Progress_bar.t option ;
running : bool;

View File

@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
integer :: N_states_read, N_det_read, psi_det_size_read
integer :: N_det_selectors_read, N_det_generators_read
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
N_states = N_states_read
N_det = N_det_read

View File

@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
integer :: N_states_read, N_det_read, psi_det_size_read
integer :: N_det_selectors_read, N_det_generators_read
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
N_states = N_states_read
N_det = N_det_read

View File

@ -90,14 +90,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
stop 'error'
endif
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, &
read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
if (N_states_read /= N_st) then
print *, N_st
stop 'error : N_st'