From 75c59840dd166632823254bb18783237ffdbcb1a Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Mon, 6 Feb 2023 16:32:23 +0100 Subject: [PATCH] OCaml code modernization --- ocaml/Command_line.ml | 19 +++--- ocaml/Command_line.mli | 2 + ocaml/Makefile | 2 +- ocaml/Message.ml | 8 +-- ocaml/Molecule.ml | 4 +- ocaml/Progress_bar.ml | 2 +- ocaml/Qputils.ml | 4 ++ ocaml/TaskServer.ml | 4 +- ocaml/qp_create_ezfio.ml | 11 +++- ocaml/qp_run.ml | 9 +-- ocaml/qp_tunnel.ml | 139 +++++++++++++++++++++++---------------- 11 files changed, 122 insertions(+), 82 deletions(-) diff --git a/ocaml/Command_line.ml b/ocaml/Command_line.ml index 1dd57892..602315c6 100644 --- a/ocaml/Command_line.ml +++ b/ocaml/Command_line.ml @@ -1,3 +1,5 @@ +exception Error of string + type short_opt = char type long_opt = string type optional = Mandatory | Optional @@ -181,15 +183,16 @@ let set_specs specs_in = Getopt.parse_cmdline cmd_specs (fun x -> anon_args := !anon_args @ [x]); if show_help () then - (help () ; exit 0); + help () + else + (* Check that all mandatory arguments are set *) + List.filter (fun x -> x.short <> ' ' && x.opt = Mandatory) !specs + |> List.iter (fun x -> + match get x.long with + | Some _ -> () + | None -> raise (Error ("--"^x.long^" option is missing.")) + ) - (* Check that all mandatory arguments are set *) - List.filter (fun x -> x.short <> ' ' && x.opt = Mandatory) !specs - |> List.iter (fun x -> - match get x.long with - | Some _ -> () - | None -> failwith ("Error: --"^x.long^" option is missing.") - ) ;; diff --git a/ocaml/Command_line.mli b/ocaml/Command_line.mli index 9f6e7022..5ad4ee08 100644 --- a/ocaml/Command_line.mli +++ b/ocaml/Command_line.mli @@ -59,6 +59,8 @@ let () = *) +exception Error of string + type short_opt = char type long_opt = string diff --git a/ocaml/Makefile b/ocaml/Makefile index 40d292fe..8853a991 100644 --- a/ocaml/Makefile +++ b/ocaml/Makefile @@ -29,7 +29,7 @@ tests: $(ALL_TESTS) .gitignore: $(MLFILES) $(MLIFILES) @for i in .gitignore ezfio.ml element_create_db Qptypes.ml Git.ml *.byte *.native _build $(ALL_EXE) $(ALL_TESTS) \ $(patsubst %.ml,%,$(wildcard test_*.ml)) $(patsubst %.ml,%,$(wildcard qp_*.ml)) \ - $(shell grep Input Input_auto_generated.ml | awk '{print $$2 ".ml"}') \ + Input_*.ml \ qp_edit.ml qp_edit qp_edit.native Input_auto_generated.ml;\ do \ echo $$i ; \ diff --git a/ocaml/Message.ml b/ocaml/Message.ml index b7d77430..049203d7 100644 --- a/ocaml/Message.ml +++ b/ocaml/Message.ml @@ -63,11 +63,11 @@ end module Connect_msg : sig type t = Tcp | Inproc | Ipc - val create : typ:string -> t + val create : string -> t val to_string : t -> string end = struct type t = Tcp | Inproc | Ipc - let create ~typ = + let create typ = match typ with | "tcp" -> Tcp | "inproc" -> Inproc @@ -515,9 +515,9 @@ let of_string s = | Connect_ socket -> Connect (Connect_msg.create socket) | NewJob_ { state ; push_address_tcp ; push_address_inproc } -> - Newjob (Newjob_msg.create push_address_tcp push_address_inproc state) + Newjob (Newjob_msg.create ~address_tcp:push_address_tcp ~address_inproc:push_address_inproc ~state) | EndJob_ state -> - Endjob (Endjob_msg.create state) + Endjob (Endjob_msg.create ~state) | GetData_ { state ; client_id ; key } -> GetData (GetData_msg.create ~client_id ~state ~key) | PutData_ { state ; client_id ; key } -> diff --git a/ocaml/Molecule.ml b/ocaml/Molecule.ml index 9b01ac3a..603244c8 100644 --- a/ocaml/Molecule.ml +++ b/ocaml/Molecule.ml @@ -101,7 +101,7 @@ let to_string_general ~f m = |> String.concat "\n" let to_string = - to_string_general ~f:(fun x -> Atom.to_string Units.Angstrom x) + to_string_general ~f:(fun x -> Atom.to_string ~units:Units.Angstrom x) let to_xyz = to_string_general ~f:Atom.to_xyz @@ -113,7 +113,7 @@ let of_xyz_string s = let l = String_ext.split s ~on:'\n' |> List.filter (fun x -> x <> "") - |> list_map (fun x -> Atom.of_string units x) + |> list_map (fun x -> Atom.of_string ~units x) in let ne = ( get_charge { nuclei=l ; diff --git a/ocaml/Progress_bar.ml b/ocaml/Progress_bar.ml index bc720b95..2cd3d19e 100644 --- a/ocaml/Progress_bar.ml +++ b/ocaml/Progress_bar.ml @@ -10,7 +10,7 @@ type t = next : float; } -let init ?(bar_length=20) ?(start_value=0.) ?(end_value=1.) ~title = +let init ?(bar_length=20) ?(start_value=0.) ?(end_value=1.) title = { title ; start_value ; end_value ; bar_length ; cur_value=start_value ; init_time= Unix.time () ; dirty = false ; next = Unix.time () } diff --git a/ocaml/Qputils.ml b/ocaml/Qputils.ml index 270e069f..752a65a0 100644 --- a/ocaml/Qputils.ml +++ b/ocaml/Qputils.ml @@ -56,3 +56,7 @@ let string_of_string s = s let list_map f l = List.rev_map f l |> List.rev + +let socket_convert socket = + ((Obj.magic (Obj.repr socket)) : [ `Xsub ] Zmq.Socket.t ) + diff --git a/ocaml/TaskServer.ml b/ocaml/TaskServer.ml index 92a6f5ca..5d9d2416 100644 --- a/ocaml/TaskServer.ml +++ b/ocaml/TaskServer.ml @@ -155,7 +155,7 @@ let new_job msg program_state rep_socket pair_socket = ~start_value:0. ~end_value:1. ~bar_length:20 - ~title:(Message.State.to_string state) + (Message.State.to_string state) in let result = @@ -776,7 +776,7 @@ let run ~port = Zmq.Socket.create zmq_context Zmq.Socket.rep in Zmq.Socket.set_linger_period rep_socket 1_000_000; - bind_socket "REP" rep_socket port; + bind_socket ~socket_type:"REP" ~socket:rep_socket ~port; let initial_program_state = { queue = Queuing_system.create () ; diff --git a/ocaml/qp_create_ezfio.ml b/ocaml/qp_create_ezfio.ml index be6c305b..4583b118 100644 --- a/ocaml/qp_create_ezfio.ml +++ b/ocaml/qp_create_ezfio.ml @@ -677,6 +677,7 @@ let run ?o b au c d m p cart xyz_file = let () = + try ( let open Command_line in begin @@ -734,7 +735,7 @@ If a file with the same name as the basis set exists, this file will be read. O let basis = match Command_line.get "basis" with - | None -> assert false + | None -> "" | Some x -> x in @@ -773,10 +774,14 @@ If a file with the same name as the basis set exists, this file will be read. O let xyz_filename = match Command_line.anon_args () with - | [x] -> x - | _ -> (Command_line.help () ; failwith "input file is missing") + | [] -> failwith "input file is missing" + | x::_ -> x in run ?o:output basis au charge dummy multiplicity pseudo cart xyz_filename + ) + with + | Failure txt -> Printf.eprintf "Fatal error: %s\n%!" txt + | Command_line.Error txt -> Printf.eprintf "Command line error: %s\n%!" txt diff --git a/ocaml/qp_run.ml b/ocaml/qp_run.ml index d096b15b..b9d14efe 100644 --- a/ocaml/qp_run.ml +++ b/ocaml/qp_run.ml @@ -6,7 +6,7 @@ open Qputils *) - + let print_list () = Lazy.force Qpackage.executables |> List.iter (fun (x,_) -> Printf.printf " * %s\n" x) @@ -110,7 +110,7 @@ let run slave ?prefix exe ezfio_file = let task_thread = let thread = Thread.create ( fun () -> - TaskServer.run port_number ) + TaskServer.run ~port:port_number ) in thread (); in @@ -151,10 +151,11 @@ let run slave ?prefix exe ezfio_file = let duration = Unix.time () -. time_start |> Unix.gmtime in let open Unix in let d, h, m, s = - duration.tm_yday, duration.tm_hour, duration.tm_min, duration.tm_sec + duration.tm_yday, duration.tm_hour, duration.tm_min, duration.tm_sec in Printf.printf "Wall time: %d:%2.2d:%2.2d" (d*24+h) m s ; Printf.printf "\n\n"; + Unix.sleep 1; if (exit_code <> 0) then exit exit_code @@ -187,7 +188,7 @@ let () = end; (* Handle options *) - let slave = Command_line.get_bool "slave" + let slave = Command_line.get_bool "slave" and prefix = Command_line.get "prefix" in diff --git a/ocaml/qp_tunnel.ml b/ocaml/qp_tunnel.ml index 84e50eb5..95aadabf 100644 --- a/ocaml/qp_tunnel.ml +++ b/ocaml/qp_tunnel.ml @@ -2,7 +2,7 @@ open Qputils open Qptypes type ezfio_or_address = EZFIO of string | ADDRESS of string -type req_or_sub = REQ | SUB +type req_or_sub = REQ | SUB let localport = 42379 @@ -29,7 +29,7 @@ let () = end; let arg = - let x = + let x = match Command_line.anon_args () with | [x] -> x | _ -> begin @@ -44,7 +44,7 @@ let () = in - let localhost = + let localhost = Lazy.force TaskServer.ip_address in @@ -52,28 +52,28 @@ let () = let long_address = match arg with | ADDRESS x -> x - | EZFIO x -> - let ic = + | EZFIO x -> + let ic = Filename.concat (Qpackage.ezfio_work x) "qp_run_address" |> open_in in - let result = + let result = input_line ic |> String.trim in close_in ic; result in - + let protocol, address, port = match String.split_on_char ':' long_address with | t :: a :: p :: [] -> t, a, int_of_string p - | _ -> failwith @@ + | _ -> failwith @@ Printf.sprintf "%s : Malformed address" long_address in - let zmq_context = + let zmq_context = Zmq.Context.create () in @@ -105,10 +105,10 @@ let () = let create_socket sock_type bind_or_connect addr = - let socket = + let socket = Zmq.Socket.create zmq_context sock_type in - let () = + let () = try bind_or_connect socket addr with @@ -131,37 +131,64 @@ let () = Sys.set_signal Sys.sigint handler; - let new_thread req_or_sub addr_in addr_out = + let new_thread_req addr_in addr_out = let socket_in, socket_out = - match req_or_sub with - | REQ -> create_socket Zmq.Socket.router Zmq.Socket.bind addr_in, create_socket Zmq.Socket.dealer Zmq.Socket.connect addr_out - | SUB -> - create_socket Zmq.Socket.sub Zmq.Socket.connect addr_in, - create_socket Zmq.Socket.pub Zmq.Socket.bind addr_out in - if req_or_sub = SUB then - Zmq.Socket.subscribe socket_in ""; - - - let action_in = - match req_or_sub with - | REQ -> (fun () -> Zmq.Socket.recv_all socket_in |> Zmq.Socket.send_all socket_out) - | SUB -> (fun () -> Zmq.Socket.recv_all socket_in |> Zmq.Socket.send_all socket_out) + let action_in = + fun () -> Zmq.Socket.recv_all socket_in |> Zmq.Socket.send_all socket_out in - let action_out = - match req_or_sub with - | REQ -> (fun () -> Zmq.Socket.recv_all socket_out |> Zmq.Socket.send_all socket_in ) - | SUB -> (fun () -> () ) + let action_out = + fun () -> Zmq.Socket.recv_all socket_out |> Zmq.Socket.send_all socket_in in let pollitem = Zmq.Poll.mask_of - [| (socket_in, Zmq.Poll.In) ; (socket_out, Zmq.Poll.In) |] + [| (socket_convert socket_in, Zmq.Poll.In) ; (socket_convert socket_out, Zmq.Poll.In) |] + in + + while !run_status do + + let polling = + Zmq.Poll.poll ~timeout:1000 pollitem + in + + match polling with + | [| Some Zmq.Poll.In ; Some Zmq.Poll.In |] -> ( action_out () ; action_in () ) + | [| _ ; Some Zmq.Poll.In |] -> action_out () + | [| Some Zmq.Poll.In ; _ |] -> action_in () + | _ -> () + done; + + Zmq.Socket.close socket_in; + Zmq.Socket.close socket_out; + in + + let new_thread_sub addr_in addr_out = + let socket_in, socket_out = + create_socket Zmq.Socket.sub Zmq.Socket.connect addr_in, + create_socket Zmq.Socket.pub Zmq.Socket.bind addr_out + in + + Zmq.Socket.subscribe socket_in ""; + + + + let action_in = + fun () -> Zmq.Socket.recv_all socket_in |> Zmq.Socket.send_all socket_out + in + + let action_out = + fun () -> () + in + + let pollitem = + Zmq.Poll.mask_of + [| (socket_convert socket_in, Zmq.Poll.In) ; (socket_convert socket_out, Zmq.Poll.In) |] in @@ -173,8 +200,8 @@ let () = match polling with | [| Some Zmq.Poll.In ; Some Zmq.Poll.In |] -> ( action_out () ; action_in () ) - | [| _ ; Some Zmq.Poll.In |] -> action_out () - | [| Some Zmq.Poll.In ; _ |] -> action_in () + | [| _ ; Some Zmq.Poll.In |] -> action_out () + | [| Some Zmq.Poll.In ; _ |] -> action_in () | _ -> () done; @@ -193,8 +220,8 @@ let () = Printf.sprintf "tcp://*:%d" localport in - let f () = - new_thread REQ addr_in addr_out + let f () = + new_thread_req addr_in addr_out in (Thread.create f) () @@ -211,8 +238,8 @@ let () = Printf.sprintf "tcp://*:%d" (localport+2) in - let f () = - new_thread REQ addr_in addr_out + let f () = + new_thread_req addr_in addr_out in (Thread.create f) () in @@ -227,8 +254,8 @@ let () = Printf.sprintf "tcp://*:%d" (localport+1) in - let f () = - new_thread SUB addr_in addr_out + let f () = + new_thread_sub addr_in addr_out in (Thread.create f) () in @@ -236,7 +263,7 @@ let () = let input_thread = - let f () = + let f () = let addr_out = match arg with | EZFIO _ -> None @@ -248,22 +275,22 @@ let () = Printf.sprintf "tcp://*:%d" (localport+9) in - let socket_in = + let socket_in = create_socket Zmq.Socket.rep Zmq.Socket.bind addr_in in let socket_out = - match addr_out with + match addr_out with | Some addr_out -> Some ( create_socket Zmq.Socket.req Zmq.Socket.connect addr_out) | None -> None in - let temp_file = + let temp_file = Filename.temp_file "qp_tunnel" ".tar.gz" in - let get_ezfio_filename () = + let get_ezfio_filename () = match arg with | EZFIO x -> x | ADDRESS _ -> @@ -277,9 +304,9 @@ let () = end in - let get_input () = + let get_input () = match arg with - | EZFIO x -> + | EZFIO x -> begin Printf.sprintf "tar --exclude=\"*.gz.*\" -zcf %s %s" temp_file x |> Sys.command |> ignore; @@ -291,11 +318,11 @@ let () = in ignore @@ Unix.lseek fd 0 Unix.SEEK_SET ; let bstr = - Unix.map_file fd Bigarray.char + Unix.map_file fd Bigarray.char Bigarray.c_layout false [| len |] |> Bigarray.array1_of_genarray in - let result = + let result = String.init len (fun i -> bstr.{i}) ; in Unix.close fd; @@ -313,7 +340,7 @@ let () = end in - let () = + let () = match socket_out with | None -> () | Some socket_out -> @@ -329,7 +356,7 @@ let () = | ADDRESS _ -> begin Printf.printf "Getting input... %!"; - let ezfio_filename = + let ezfio_filename = get_ezfio_filename () in Printf.printf "%s%!" ezfio_filename; @@ -343,7 +370,7 @@ let () = |> Sys.command |> ignore ; let oc = Filename.concat (Qpackage.ezfio_work ezfio_filename) "qp_run_address" - |> open_out + |> open_out in Printf.fprintf oc "tcp://%s:%d\n" localhost localport; close_out oc; @@ -359,9 +386,9 @@ let () = let action () = match Zmq.Socket.recv socket_in with | "get_input" -> get_input () - |> Zmq.Socket.send socket_in + |> Zmq.Socket.send socket_in | "get_ezfio_filename" -> get_ezfio_filename () - |> Zmq.Socket.send socket_in + |> Zmq.Socket.send socket_in | "test" -> Zmq.Socket.send socket_in "OK" | x -> Printf.sprintf "Message '%s' not understood" x |> Zmq.Socket.send socket_in @@ -372,7 +399,7 @@ On remote hosts, create ssh tunnel using: ssh -L %d:%s:%d -L %d:%s:%d -L %d:%s:%d -L %d:%s:%d %s & Or from this host connect to clients using: ssh -R %d:localhost:%d -R %d:localhost:%d -R %d:localhost:%d -R %d:localhost:%d & -%!" +%!" (port ) localhost (localport ) (port+1) localhost (localport+1) (port+2) localhost (localport+2) @@ -392,12 +419,12 @@ Or from this host connect to clients using: match polling.(0) with | Some Zmq.Poll.In -> action () | None -> () - | Some Zmq.Poll.In_out + | Some Zmq.Poll.In_out | Some Zmq.Poll.Out -> () done; - let () = + let () = match socket_out with | Some socket_out -> Zmq.Socket.close socket_out | None -> () @@ -415,7 +442,5 @@ Or from this host connect to clients using: Thread.join ocaml_thread; Zmq.Context.terminate zmq_context; Printf.printf "qp_tunnel exited properly.\n" - -