10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-12-22 20:35:19 +01:00

Accelerated OCaml Psi messages

This commit is contained in:
Anthony Scemama 2017-05-18 18:53:55 +02:00
parent e8f35b59d4
commit 0ae7dfc224
6 changed files with 28 additions and 52 deletions

View File

@ -324,33 +324,28 @@ end
(** GetPsiReply_msg : Reply to the GetPsi message *) (** GetPsiReply_msg : Reply to the GetPsi message *)
module GetPsiReply_msg : sig module GetPsiReply_msg : sig
type t = type t = string list
{ client_id : Id.Client.t ; val create : psi:Psi.t -> t
psi : Psi.t }
val create : client_id:Id.Client.t -> psi:Psi.t -> t
val to_string_list : t -> string list
val to_string : t -> string val to_string : t -> string
end = struct end = struct
type t = type t = string list
{ client_id : Id.Client.t ; let create ~psi =
psi : Psi.t }
let create ~client_id ~psi =
{ client_id ; psi }
let to_string x =
let g, s = let g, s =
match x.psi.Psi.n_det_generators, x.psi.Psi.n_det_selectors with match psi.Psi.n_det_generators, psi.Psi.n_det_selectors with
| Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s | Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s
| _ -> -1, -1 | _ -> -1, -1
in in
Printf.sprintf "get_psi_reply %d %d %d %d %d %d" let head =
(Id.Client.to_int x.client_id) Printf.sprintf "get_psi_reply %d %d %d %d %d"
(Strictly_positive_int.to_int x.psi.Psi.n_state) (Strictly_positive_int.to_int psi.Psi.n_state)
(Strictly_positive_int.to_int x.psi.Psi.n_det) (Strictly_positive_int.to_int psi.Psi.n_det)
(Strictly_positive_int.to_int x.psi.Psi.psi_det_size) (Strictly_positive_int.to_int psi.Psi.psi_det_size)
g s g s
let to_string_list x = in
[ to_string x ; [ head ; psi.Psi.psi_det ; psi.Psi.psi_coef ; psi.Psi.energy ]
x.psi.Psi.psi_det ; x.psi.Psi.psi_coef ; x.psi.Psi.energy ] let to_string = function
| head :: _ :: _ :: _ :: [] -> head
| _ -> raise (Invalid_argument "Bad wave function message")
end end
@ -759,7 +754,6 @@ let to_string = function
let to_string_list = function let to_string_list = function
| PutPsi x -> PutPsi_msg.to_string_list x | PutPsi x -> PutPsi_msg.to_string_list x
| GetPsiReply x -> GetPsiReply_msg.to_string_list x
| PutVector x -> PutVector_msg.to_string_list x | PutVector x -> PutVector_msg.to_string_list x
| GetVectorReply x -> GetVectorReply_msg.to_string_list x | GetVectorReply x -> GetVectorReply_msg.to_string_list x
| _ -> assert false | _ -> assert false

View File

@ -25,7 +25,7 @@ type t =
state : Message.State.t option ; state : Message.State.t option ;
address_tcp : Address.Tcp.t option ; address_tcp : Address.Tcp.t option ;
address_inproc : Address.Inproc.t option ; address_inproc : Address.Inproc.t option ;
psi : Message.Psi.t option; psi : Message.GetPsiReply_msg.t option;
vector : Message.Vector.t option; vector : Message.Vector.t option;
progress_bar : Progress_bar.t option ; progress_bar : Progress_bar.t option ;
running : bool; running : bool;
@ -483,7 +483,7 @@ let put_psi msg rest_of_msg program_state rep_socket =
in in
let new_program_state = let new_program_state =
{ program_state with { program_state with
psi = Some psi_local psi = Some (Message.GetPsiReply_msg.create ~psi:psi_local)
} }
and client_id = and client_id =
msg.Message.PutPsi_msg.client_id msg.Message.PutPsi_msg.client_id
@ -496,16 +496,11 @@ let put_psi msg rest_of_msg program_state rep_socket =
let get_psi msg program_state rep_socket = let get_psi msg program_state rep_socket =
begin
let client_id =
msg.Message.GetPsi_msg.client_id
in
match program_state.psi with match program_state.psi with
| None -> failwith "No wave function saved in TaskServer" | None -> failwith "No wave function saved in TaskServer"
| Some psi -> | Some psi_message -> ZMQ.Socket.send_all rep_socket psi_message
Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi) end;
|> Message.to_string_list
|> ZMQ.Socket.send_all rep_socket;
program_state program_state

View File

@ -4,7 +4,7 @@ type t =
state : Message.State.t option ; state : Message.State.t option ;
address_tcp : Address.Tcp.t option ; address_tcp : Address.Tcp.t option ;
address_inproc : Address.Inproc.t option ; address_inproc : Address.Inproc.t option ;
psi : Message.Psi.t option; psi : Message.GetPsiReply_msg.t option;
vector : Message.Vector.t option ; vector : Message.Vector.t option ;
progress_bar : Progress_bar.t option ; progress_bar : Progress_bar.t option ;
running : bool; running : bool;

View File

@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
integer :: N_states_read, N_det_read, psi_det_size_read integer :: N_states_read, N_det_read, psi_det_size_read
integer :: N_det_selectors_read, N_det_generators_read integer :: N_det_selectors_read, N_det_generators_read
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
N_states = N_states_read N_states = N_states_read
N_det = N_det_read N_det = N_det_read

View File

@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy)
integer :: N_states_read, N_det_read, psi_det_size_read integer :: N_states_read, N_det_read, psi_det_size_read
integer :: N_det_selectors_read, N_det_generators_read integer :: N_det_selectors_read, N_det_generators_read
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
N_states = N_states_read N_states = N_states_read
N_det = N_det_read N_det = N_det_read

View File

@ -90,14 +90,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze,
stop 'error' stop 'error'
endif endif
read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, &
N_det_generators_read, N_det_selectors_read N_det_generators_read, N_det_selectors_read
if (rc /= worker_id) then
print *, 'Wrong worker ID'
stop 'error'
endif
if (N_states_read /= N_st) then if (N_states_read /= N_st) then
print *, N_st print *, N_st
stop 'error : N_st' stop 'error : N_st'