diff --git a/ocaml/Message.ml b/ocaml/Message.ml index 72fb41b5..faf5ed69 100644 --- a/ocaml/Message.ml +++ b/ocaml/Message.ml @@ -324,33 +324,28 @@ end (** GetPsiReply_msg : Reply to the GetPsi message *) module GetPsiReply_msg : sig - type t = - { client_id : Id.Client.t ; - psi : Psi.t } - val create : client_id:Id.Client.t -> psi:Psi.t -> t - val to_string_list : t -> string list + type t = string list + val create : psi:Psi.t -> t val to_string : t -> string end = struct - type t = - { client_id : Id.Client.t ; - psi : Psi.t } - let create ~client_id ~psi = - { client_id ; psi } - let to_string x = + type t = string list + let create ~psi = let g, s = - match x.psi.Psi.n_det_generators, x.psi.Psi.n_det_selectors with + match psi.Psi.n_det_generators, psi.Psi.n_det_selectors with | Some g, Some s -> Strictly_positive_int.to_int g, Strictly_positive_int.to_int s | _ -> -1, -1 in - Printf.sprintf "get_psi_reply %d %d %d %d %d %d" - (Id.Client.to_int x.client_id) - (Strictly_positive_int.to_int x.psi.Psi.n_state) - (Strictly_positive_int.to_int x.psi.Psi.n_det) - (Strictly_positive_int.to_int x.psi.Psi.psi_det_size) - g s - let to_string_list x = - [ to_string x ; - x.psi.Psi.psi_det ; x.psi.Psi.psi_coef ; x.psi.Psi.energy ] + let head = + Printf.sprintf "get_psi_reply %d %d %d %d %d" + (Strictly_positive_int.to_int psi.Psi.n_state) + (Strictly_positive_int.to_int psi.Psi.n_det) + (Strictly_positive_int.to_int psi.Psi.psi_det_size) + g s + in + [ head ; psi.Psi.psi_det ; psi.Psi.psi_coef ; psi.Psi.energy ] + let to_string = function + | head :: _ :: _ :: _ :: [] -> head + | _ -> raise (Invalid_argument "Bad wave function message") end @@ -759,7 +754,6 @@ let to_string = function let to_string_list = function | PutPsi x -> PutPsi_msg.to_string_list x -| GetPsiReply x -> GetPsiReply_msg.to_string_list x | PutVector x -> PutVector_msg.to_string_list x | GetVectorReply x -> GetVectorReply_msg.to_string_list x | _ -> assert false diff --git a/ocaml/TaskServer.ml b/ocaml/TaskServer.ml index 91fbd231..1ed403f7 100644 --- a/ocaml/TaskServer.ml +++ b/ocaml/TaskServer.ml @@ -25,7 +25,7 @@ type t = state : Message.State.t option ; address_tcp : Address.Tcp.t option ; address_inproc : Address.Inproc.t option ; - psi : Message.Psi.t option; + psi : Message.GetPsiReply_msg.t option; vector : Message.Vector.t option; progress_bar : Progress_bar.t option ; running : bool; @@ -483,7 +483,7 @@ let put_psi msg rest_of_msg program_state rep_socket = in let new_program_state = { program_state with - psi = Some psi_local + psi = Some (Message.GetPsiReply_msg.create ~psi:psi_local) } and client_id = msg.Message.PutPsi_msg.client_id @@ -496,17 +496,12 @@ let put_psi msg rest_of_msg program_state rep_socket = let get_psi msg program_state rep_socket = - - let client_id = - msg.Message.GetPsi_msg.client_id - in - match program_state.psi with - | None -> failwith "No wave function saved in TaskServer" - | Some psi -> - Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi) - |> Message.to_string_list - |> ZMQ.Socket.send_all rep_socket; - program_state + begin + match program_state.psi with + | None -> failwith "No wave function saved in TaskServer" + | Some psi_message -> ZMQ.Socket.send_all rep_socket psi_message + end; + program_state diff --git a/ocaml/TaskServer.mli b/ocaml/TaskServer.mli index 7098b55a..4f93dc77 100644 --- a/ocaml/TaskServer.mli +++ b/ocaml/TaskServer.mli @@ -4,7 +4,7 @@ type t = state : Message.State.t option ; address_tcp : Address.Tcp.t option ; address_inproc : Address.Inproc.t option ; - psi : Message.Psi.t option; + psi : Message.GetPsiReply_msg.t option; vector : Message.Vector.t option ; progress_bar : Progress_bar.t option ; running : bool; diff --git a/plugins/Selectors_CASSD/zmq.irp.f b/plugins/Selectors_CASSD/zmq.irp.f index 4359a876..37dd29de 100644 --- a/plugins/Selectors_CASSD/zmq.irp.f +++ b/plugins/Selectors_CASSD/zmq.irp.f @@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy) integer :: N_states_read, N_det_read, psi_det_size_read integer :: N_det_selectors_read, N_det_generators_read - read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & + read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, & N_det_generators_read, N_det_selectors_read - if (rc /= worker_id) then - print *, 'Wrong worker ID' - stop 'error' - endif N_states = N_states_read N_det = N_det_read diff --git a/plugins/Selectors_full/zmq.irp.f b/plugins/Selectors_full/zmq.irp.f index 59f40daf..f7f0c4b0 100644 --- a/plugins/Selectors_full/zmq.irp.f +++ b/plugins/Selectors_full/zmq.irp.f @@ -78,12 +78,8 @@ subroutine zmq_get_psi(zmq_to_qp_run_socket, worker_id, energy, size_energy) integer :: N_states_read, N_det_read, psi_det_size_read integer :: N_det_selectors_read, N_det_generators_read - read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & + read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, & N_det_generators_read, N_det_selectors_read - if (rc /= worker_id) then - print *, 'Wrong worker ID' - stop 'error' - endif N_states = N_states_read N_det = N_det_read diff --git a/src/Davidson/davidson_parallel.irp.f b/src/Davidson/davidson_parallel.irp.f index 76386c7b..9af78b4f 100644 --- a/src/Davidson/davidson_parallel.irp.f +++ b/src/Davidson/davidson_parallel.irp.f @@ -90,14 +90,9 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze, stop 'error' endif - read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & + read(msg(14:rc),*) N_states_read, N_det_read, psi_det_size_read, & N_det_generators_read, N_det_selectors_read - if (rc /= worker_id) then - print *, 'Wrong worker ID' - stop 'error' - endif - if (N_states_read /= N_st) then print *, N_st stop 'error : N_st'