From ce10c5052c5dd653f49ae116d51a84e4bb8b5718 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Wed, 3 May 2017 21:15:54 +0200 Subject: [PATCH] Travis bug --- ocaml/Message.ml | 139 ++++++++++++++++++++++++++- ocaml/Message_lexer.mll | 21 +++- ocaml/TaskServer.ml | 51 ++++++++++ ocaml/TaskServer.mli | 1 + plugins/CAS_SD_ZMQ/cassd_zmq.irp.f | 4 +- src/Davidson/davidson_parallel.irp.f | 9 +- 6 files changed, 214 insertions(+), 11 deletions(-) diff --git a/ocaml/Message.ml b/ocaml/Message.ml index 2ed38864..7a1d1712 100644 --- a/ocaml/Message.ml +++ b/ocaml/Message.ml @@ -455,6 +455,122 @@ end = struct end +(** GetVector : get the current vector (Davidson) *) +module GetVector_msg : sig + type t = + { client_id: Id.Client.t ; + } + val create : client_id:int -> t + val to_string : t -> string +end = struct + type t = + { client_id: Id.Client.t ; + } + let create ~client_id = + { client_id = Id.Client.of_int client_id } + let to_string x = + Printf.sprintf "get_vector %d" + (Id.Client.to_int x.client_id) +end + +module Vector : sig + type t = + { + size : Strictly_positive_int.t; + data : string; + } + val create : size:Strictly_positive_int.t -> data:string -> t +end = struct + type t = + { + size : Strictly_positive_int.t; + data : string; + } + let create ~size ~data = + { size ; data } +end + +(** GetVectorReply_msg : Reply to the GetVector message *) +module GetVectorReply_msg : sig + type t = + { client_id : Id.Client.t ; + vector : Vector.t } + val create : client_id:Id.Client.t -> vector:Vector.t -> t + val to_string : t -> string + val to_string_list : t -> string list +end = struct + type t = + { client_id : Id.Client.t ; + vector : Vector.t } + let create ~client_id ~vector = + { client_id ; vector } + let to_string x = + Printf.sprintf "get_vector_reply %d %d" + (Id.Client.to_int x.client_id) + (Strictly_positive_int.to_int x.vector.Vector.size) + let to_string_list x = + [ to_string x ; x.vector.Vector.data ] +end + +(** PutVector : put the current variational wave function *) +module PutVector_msg : sig + type t = + { client_id : Id.Client.t ; + size : Strictly_positive_int.t ; + vector : Vector.t option; + } + val create : + client_id:int -> size:int -> data:string option -> t + val to_string_list : t -> string list + val to_string : t -> string +end = struct + type t = + { client_id : Id.Client.t ; + size : Strictly_positive_int.t ; + vector : Vector.t option; + } + let create ~client_id ~size ~data = + let size = + Strictly_positive_int.of_int size + in + let vector = + match data with + | None -> None + | Some s -> Some (Vector.create ~size ~data:s) + in + { client_id = Id.Client.of_int client_id ; + vector ; size + } + + let to_string x = + Printf.sprintf "put_vector %d %d" + (Id.Client.to_int x.client_id) + (Strictly_positive_int.to_int x.size) + + let to_string_list x = + match x.vector with + | Some v -> [ to_string x ; v.Vector.data ] + | None -> failwith "Empty vector" +end + +(** PutVectorReply_msg : Reply to the PutVector message *) +module PutVectorReply_msg : sig + type t + val create : client_id:Id.Client.t -> t + val to_string : t -> string +end = struct + type t = + { client_id : Id.Client.t ; + } + let create ~client_id = + { client_id; } + let to_string x = + Printf.sprintf "put_vector_reply %d" + (Id.Client.to_int x.client_id) +end + + + (** TaskDone : Inform the server that a task is finished *) module TaskDone_msg : sig type t = @@ -526,6 +642,10 @@ type t = | PutPsi of PutPsi_msg.t | GetPsiReply of GetPsiReply_msg.t | PutPsiReply of PutPsiReply_msg.t +| GetVector of GetVector_msg.t +| PutVector of PutVector_msg.t +| GetVectorReply of GetVectorReply_msg.t +| PutVectorReply of PutVectorReply_msg.t | Newjob of Newjob_msg.t | Endjob of Endjob_msg.t | Connect of Connect_msg.t @@ -580,6 +700,10 @@ let of_string s = ~n_det_generators:None ~n_det_selectors:None ~psi_det:None ~psi_coef:None ~energy:None ) end + | GetVector_ client_id -> + GetVector (GetVector_msg.create ~client_id) + | PutVector_ { client_id ; size } -> + PutVector (PutVector_msg.create ~client_id ~size ~data:None ) | Terminate_ -> Terminate (Terminate_msg.create ) | SetWaiting_ -> SetWaiting | SetStopped_ -> SetStopped @@ -592,6 +716,8 @@ let of_string s = let to_string = function | GetPsi x -> GetPsi_msg.to_string x | PutPsiReply x -> PutPsiReply_msg.to_string x +| GetVector x -> GetVector_msg.to_string x +| PutVectorReply x -> PutVectorReply_msg.to_string x | Newjob x -> Newjob_msg.to_string x | Endjob x -> Endjob_msg.to_string x | Connect x -> Connect_msg.to_string x @@ -600,8 +726,8 @@ let to_string = function | DisconnectReply x -> DisconnectReply_msg.to_string x | GetTask x -> GetTask_msg.to_string x | GetTaskReply x -> GetTaskReply_msg.to_string x -| DelTask x -> DelTask_msg.to_string x -| DelTaskReply x -> DelTaskReply_msg.to_string x +| DelTask x -> DelTask_msg.to_string x +| DelTaskReply x -> DelTaskReply_msg.to_string x | AddTask x -> AddTask_msg.to_string x | AddTaskReply x -> AddTaskReply_msg.to_string x | TaskDone x -> TaskDone_msg.to_string x @@ -610,12 +736,17 @@ let to_string = function | Error x -> Error_msg.to_string x | PutPsi x -> PutPsi_msg.to_string x | GetPsiReply x -> GetPsiReply_msg.to_string x +| PutVector x -> PutVector_msg.to_string x +| GetVectorReply x -> GetVectorReply_msg.to_string x | SetStopped -> "set_stopped" | SetRunning -> "set_running" | SetWaiting -> "set_waiting" let to_string_list = function -| PutPsi x -> PutPsi_msg.to_string_list x -| GetPsiReply x -> GetPsiReply_msg.to_string_list x +| PutPsi x -> PutPsi_msg.to_string_list x +| GetPsiReply x -> GetPsiReply_msg.to_string_list x +| PutVector x -> PutVector_msg.to_string_list x +| GetVectorReply x -> GetVectorReply_msg.to_string_list x | _ -> assert false + diff --git a/ocaml/Message_lexer.mll b/ocaml/Message_lexer.mll index c67f4528..b85baecf 100644 --- a/ocaml/Message_lexer.mll +++ b/ocaml/Message_lexer.mll @@ -17,6 +17,8 @@ type kw_type = | TERMINATE | GET_PSI | PUT_PSI + | GET_VECTOR + | PUT_VECTOR | OK | ERROR | SET_STOPPED @@ -29,7 +31,8 @@ type state_taskids_clientid = { state : string ; task_ids : int list ; type state_clientid = { state : string ; client_id : int ; } type state_tcp_inproc = { state : string ; push_address_tcp : string ; push_address_inproc : string ; } type psi = { client_id: int ; n_state: int ; n_det: int ; psi_det_size: int ; - n_det_generators: int option ; n_det_selectors: int option } + n_det_generators: int option ; n_det_selectors: int option ; } +type vector = { client_id: int ; size: int } type msg = | AddTask_ of state_tasks @@ -43,6 +46,8 @@ type msg = | Terminate_ | GetPsi_ of int | PutPsi_ of psi + | GetVector_ of int + | PutVector_ of vector | Ok_ | Error_ of string | SetStopped_ @@ -85,6 +90,8 @@ and kw = parse | "terminate" { TERMINATE } | "get_psi" { GET_PSI } | "put_psi" { PUT_PSI } + | "get_vector" { GET_PSI } + | "put_vector" { PUT_PSI } | "ok" { OK } | "error" { ERROR } | "set_stopped" { SET_STOPPED } @@ -179,6 +186,15 @@ and kw = parse in PutPsi_ { client_id ; n_state ; n_det ; psi_det_size ; n_det_generators ; n_det_selectors } + | GET_VECTOR -> + let client_id = read_int lexbuf in + GetVector_ client_id + + | PUT_VECTOR -> + let client_id = read_int lexbuf in + let size = read_int lexbuf in + PutVector_ { client_id ; size } + | CONNECT -> let socket = read_word lexbuf in Connect_ socket @@ -253,6 +269,9 @@ and kw = parse | Some s, Some g -> Printf.sprintf "PUT_PSI client_id:%d n_state:%d n_det:%d psi_det_size:%d n_det_generators:%d n_det_selectors:%d" client_id n_state n_det psi_det_size g s | _ -> Printf.sprintf "PUT_PSI client_id:%d n_state:%d n_det:%d psi_det_size:%d" client_id n_state n_det psi_det_size end + | GetVector_ client_id -> Printf.sprintf "GET_VECTOR client_id:%d" client_id + | PutVector_ { client_id ; size } -> + Printf.sprintf "PUT_VECTOR client_id:%d size:%d" client_id size | Terminate_ -> "TERMINATE" | SetWaiting_ -> "SET_WAITING" | SetStopped_ -> "SET_STOPPED" diff --git a/ocaml/TaskServer.ml b/ocaml/TaskServer.ml index 6537f579..887c7482 100644 --- a/ocaml/TaskServer.ml +++ b/ocaml/TaskServer.ml @@ -26,6 +26,7 @@ type t = address_tcp : Address.Tcp.t option ; address_inproc : Address.Inproc.t option ; psi : Message.Psi.t option; + vector : Message.Vector.t option; progress_bar : Progress_bar.t option ; running : bool; } @@ -523,10 +524,57 @@ let get_psi msg program_state rep_socket = +let put_vector msg rest_of_msg program_state rep_socket = + + let vector_local = + match msg.Message.PutVector_msg.vector with + | Some x -> x + | None -> + begin + let data = + match rest_of_msg with + | [ x ] -> x + | _ -> failwith "Badly formed put_vector message" + in + Message.Vector.create + ~size:msg.Message.PutVector_msg.size + ~data + end + in + let new_program_state = + { program_state with + vector = Some vector_local + } + and client_id = + msg.Message.PutVector_msg.client_id + in + Message.PutVectorReply (Message.PutVectorReply_msg.create ~client_id) + |> Message.to_string + |> ZMQ.Socket.send rep_socket; + + new_program_state + + +let get_vector msg program_state rep_socket = + + let client_id = + msg.Message.GetVector_msg.client_id + in + match program_state.vector with + | None -> failwith "No wave function saved in TaskServer" + | Some vector -> + Message.GetVectorReply (Message.GetVectorReply_msg.create ~client_id ~vector) + |> Message.to_string_list + |> ZMQ.Socket.send_all rep_socket; + program_state + + + let terminate program_state rep_socket = reply_ok rep_socket; { program_state with psi = None; + vector = None; address_tcp = None; address_inproc = None; running = false @@ -610,6 +658,7 @@ let run ~port = { queue = Queuing_system.create () ; running = true ; psi = None; + vector = None; state = None; address_tcp = None; address_inproc = None; @@ -679,6 +728,8 @@ let run ~port = try match program_state.state, message with | _ , Message.Terminate _ -> terminate program_state rep_socket + | _ , Message.PutVector x -> put_vector x rest program_state rep_socket + | _ , Message.GetVector x -> get_vector x program_state rep_socket | _ , Message.PutPsi x -> put_psi x rest program_state rep_socket | _ , Message.GetPsi x -> get_psi x program_state rep_socket | None , Message.Newjob x -> new_job x program_state rep_socket pair_socket diff --git a/ocaml/TaskServer.mli b/ocaml/TaskServer.mli index e1baab12..7098b55a 100644 --- a/ocaml/TaskServer.mli +++ b/ocaml/TaskServer.mli @@ -5,6 +5,7 @@ type t = address_tcp : Address.Tcp.t option ; address_inproc : Address.Inproc.t option ; psi : Message.Psi.t option; + vector : Message.Vector.t option ; progress_bar : Progress_bar.t option ; running : bool; } diff --git a/plugins/CAS_SD_ZMQ/cassd_zmq.irp.f b/plugins/CAS_SD_ZMQ/cassd_zmq.irp.f index ffacdd8a..f8ee7ba2 100644 --- a/plugins/CAS_SD_ZMQ/cassd_zmq.irp.f +++ b/plugins/CAS_SD_ZMQ/cassd_zmq.irp.f @@ -91,7 +91,7 @@ program cassd_zmq enddo endif E_CI_before(1:N_states) = CI_energy(1:N_states) - call ezfio_set_full_ci_zmq_energy(CI_energy(1)) + call ezfio_set_cas_sd_zmq_energy(CI_energy(1)) n_det_before = N_det to_select = N_det @@ -116,7 +116,7 @@ program cassd_zmq threshold_davidson = threshold_davidson_in call diagonalize_CI call save_wavefunction - call ezfio_set_full_ci_zmq_energy(CI_energy(1)) + call ezfio_set_cas_sd_zmq_energy(CI_energy(1)) endif integer :: exc_max, degree_min diff --git a/src/Davidson/davidson_parallel.irp.f b/src/Davidson/davidson_parallel.irp.f index 68db35da..f4114adb 100644 --- a/src/Davidson/davidson_parallel.irp.f +++ b/src/Davidson/davidson_parallel.irp.f @@ -70,6 +70,11 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze, ! ----------------------- integer :: rc + integer :: N_states_read, N_det_read, psi_det_size_read + integer :: N_det_selectors_read, N_det_generators_read + double precision :: energy(N_st) + + write(msg, *) 'get_psi ', worker_id rc = f77_zmq_send(zmq_to_qp_run_socket,trim(msg),len(trim(msg)),0) if (rc /= len(trim(msg))) then @@ -84,10 +89,6 @@ subroutine davidson_slave_work(zmq_to_qp_run_socket, zmq_socket_push, N_st, sze, stop 'error' endif - integer :: N_states_read, N_det_read, psi_det_size_read - integer :: N_det_selectors_read, N_det_generators_read - double precision :: energy(N_st) - read(msg(14:rc),*) rc, N_states_read, N_det_read, psi_det_size_read, & N_det_generators_read, N_det_selectors_read