10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-11-16 02:53:51 +01:00
QuantumPackage/ocaml/qp_tunnel.ml

422 lines
10 KiB
OCaml

open Qputils
open Qptypes
(* Target given on the command line: [EZFIO path] when the argument is an
   existing directory, [ADDRESS s] for any other string. *)
type ezfio_or_address = EZFIO of string | ADDRESS of string
(* Kind of forwarding thread: [REQ] relays through ROUTER/DEALER sockets,
   [SUB] relays through SUB/PUB sockets. *)
type req_or_sub = REQ | SUB
(* First local port tried when searching for a free range of ports. *)
let localport = 42379
(* NOTE(review): these two references are not read or written anywhere in the
   visible part of this file -- confirm whether they are still needed. *)
let in_time_sum = ref 1.e-9
and in_size_sum = ref 0.
let () =
(* Register the help footer and the accepted command-line arguments: an
   optional [-g]/[--get-input] flag without value, and one mandatory anonymous
   argument (EZFIO directory or address). *)
let open Command_line in
begin
  set_footer_doc
    "Creates an ssh tunnel for using slaves on another network. Launch a server on the front-end node of the cluster on which the master process runs. Then start a client on the front-end node of the distant cluster.";
  set_specs
    [ { short = 'g' ; long = "get-input" ; opt = Optional ;
        doc = "Downloads the EZFIO directory." ;
        arg = Without_arg ; } ;
      anonymous
        "(EZFIO_DIR|ADDRESS)"
        Mandatory
        "EZFIO directory or address." ;
    ]
end;
(* Classify the single anonymous argument: an existing directory is taken as
   an EZFIO directory, anything else as an address.  Any other number of
   anonymous arguments prints the help and fails. *)
let arg =
  match Command_line.anon_args () with
  | [ target ] ->
      let is_directory =
        Sys.file_exists target && Sys.is_directory target
      in
      if is_directory then EZFIO target else ADDRESS target
  | _ ->
      Command_line.help () ;
      failwith "EZFIO_FILE or ADDRESS is missing"
in
(* Forced value of [TaskServer.ip_address]; it is used below as the host part
   of the local addresses that are probed and printed. *)
let localhost = TaskServer.ip_address |> Lazy.force in
(* Full address string of the server to reach.
   - [ADDRESS x]: the string given on the command line, used as is.
   - [EZFIO x]: the first line (trimmed) of the file [qp_run_address] located
     in the directory returned by [Qpackage.ezfio_work x].
   Raises [Sys_error] if the file cannot be opened and [Failure] if it is
   empty. *)
let long_address =
  match arg with
  | ADDRESS x -> x
  | EZFIO x ->
      let address_file =
        Filename.concat (Qpackage.ezfio_work x) "qp_run_address"
      in
      let ic = open_in address_file in
      (* Close the channel on both paths, and turn the bare [End_of_file] of
         an empty file into an error that names the offending file. *)
      begin
        match input_line ic with
        | line ->
            close_in ic;
            String.trim line
        | exception End_of_file ->
            close_in_noerr ic;
            failwith @@
            Printf.sprintf "%s : Empty address file" address_file
      end
in
(* Split [long_address] on ':' into exactly three fields.  For an address of
   the form "tcp://host:1234" this gives protocol = "tcp", address = "//host"
   (the "//" prefix is kept) and port = 1234.
   Raises [Failure "<long_address> : Malformed address"] when the string does
   not have three fields or when the port field is not an integer. *)
let protocol, address, port =
  let malformed () =
    failwith @@
    Printf.sprintf "%s : Malformed address" long_address
  in
  match String.split_on_char ':' long_address with
  | [ t ; a ; p ] ->
      begin
        (* [int_of_string] would raise an uninformative
           [Failure "int_of_string"] on a non-numeric port. *)
        match int_of_string_opt p with
        | Some port_number -> t, a, port_number
        | None -> malformed ()
      end
  | _ -> malformed ()
in
(* Single ZeroMQ context used by every socket created below. *)
let zmq_context = Zmq.Context.create () in
(* Check availability of the ports.
   Look for a base port such that the ten ports [base .. base+9] can all be
   bound and unbound on [localhost] with a throw-away REP socket.  The search
   starts at the top-level [localport] and moves up by 100 after each failed
   attempt.  The value found shadows the top-level [localport].
   Raises [Failure] when no suitable range exists below port 65536. *)
let localport =
  let max_port = 65535 in
  let offsets = [ 0;1;2;3;4;5;6;7;8;9 ] in
  let dummy_socket =
    Zmq.Socket.create zmq_context Zmq.Socket.rep
  in
  (* Bind then unbind every port of the range starting at [port_number]. *)
  let probe port_number =
    List.iter (fun i ->
        let address =
          Printf.sprintf "tcp://%s:%d" localhost (port_number+i)
        in
        Zmq.Socket.bind dummy_socket address;
        Zmq.Socket.unbind dummy_socket address
      ) offsets
  in
  let rec try_new_port port_number =
    if port_number + 9 > max_port then
      begin
        (* Without this bound the search would recurse forever once the
           candidate ports stop being valid TCP ports. *)
        Zmq.Socket.close dummy_socket;
        failwith "Unable to find 10 consecutive free ports"
      end
    else
      match probe port_number with
      | () -> port_number
      | exception Unix.Unix_error _ -> try_new_port (port_number+100)
  in
  let result =
    try_new_port localport
  in
  Zmq.Socket.close dummy_socket;
  result
in
(* [create_socket sock_type bind_or_connect addr] creates a socket of kind
   [sock_type] in [zmq_context] and applies [bind_or_connect] (either
   [Zmq.Socket.bind] or [Zmq.Socket.connect]) to it with [addr].
   Returns the socket.
   Raises [Failure "Unable to establish connection to <addr>."] when
   [bind_or_connect] raises any exception. *)
let create_socket sock_type bind_or_connect addr =
  let socket =
    Zmq.Socket.create zmq_context sock_type
  in
  match bind_or_connect socket addr with
  | () -> socket
  | exception _ ->
      (* Close the unusable socket before failing: a socket left open would
         keep the ZeroMQ context from terminating. *)
      Zmq.Socket.close socket;
      failwith @@
      Printf.sprintf "Unable to establish connection to %s." addr
in
(* Handle termination *)
(* [run_status] is the flag tested by every polling loop below.  On the first
   SIGUSR1 or SIGINT it is set to [false] and the default behaviour of that
   signal is restored. *)
let run_status = ref true in
let handler =
  let request_stop signum =
    run_status := false;
    Sys.set_signal signum Sys.Signal_default
  in
  Sys.Signal_handle request_stop
in
List.iter
  (fun signal -> Sys.set_signal signal handler)
  [ Sys.sigusr1 ; Sys.sigint ];
(* [new_thread req_or_sub addr_in addr_out] relays messages between two
   sockets until [run_status] becomes false, then closes both sockets.
   - [REQ]: a ROUTER socket is bound to [addr_in] and a DEALER socket is
     connected to [addr_out]; messages are relayed in both directions.
   - [SUB]: a SUB socket is connected to [addr_in] and subscribed with the
     empty prefix, a PUB socket is bound to [addr_out]; messages are relayed
     from the SUB socket to the PUB socket only.
   Both sockets are polled for input with a timeout value of 1000. *)
let new_thread req_or_sub addr_in addr_out =
  let socket_in, socket_out =
    match req_or_sub with
    | REQ ->
        ( create_socket Zmq.Socket.router Zmq.Socket.bind addr_in,
          create_socket Zmq.Socket.dealer Zmq.Socket.connect addr_out )
    | SUB ->
        ( create_socket Zmq.Socket.sub Zmq.Socket.connect addr_in,
          create_socket Zmq.Socket.pub Zmq.Socket.bind addr_out )
  in
  begin
    match req_or_sub with
    | SUB -> Zmq.Socket.subscribe socket_in ""
    | REQ -> ()
  end;
  (* Receive every part of one message on [src] and send them all on [dst]. *)
  let relay src dst =
    Zmq.Socket.recv_all src |> Zmq.Socket.send_all dst
  in
  let action_in () = relay socket_in socket_out in
  let action_out () =
    match req_or_sub with
    | REQ -> relay socket_out socket_in
    | SUB -> ()
  in
  let pollitem =
    Zmq.Poll.mask_of
      [| (socket_in, Zmq.Poll.In) ; (socket_out, Zmq.Poll.In) |]
  in
  let readable = function
    | Some Zmq.Poll.In -> true
    | Some Zmq.Poll.In_out | Some Zmq.Poll.Out | None -> false
  in
  while !run_status do
    match Zmq.Poll.poll ~timeout:1000 pollitem with
    | [| event_in ; event_out |] ->
        (* When both sockets are readable, the outgoing side is served
           first. *)
        if readable event_out then action_out ();
        if readable event_in then action_in ()
    | _ -> ()
  done;
  Zmq.Socket.close socket_in;
  Zmq.Socket.close socket_out
in
(* REQ-type relay between local port [localport] (bound on all interfaces)
   and the remote [port]; the local entry point is then printed. *)
let ocaml_thread =
  let addr_in = Printf.sprintf "tcp://*:%d" localport
  and addr_out = Printf.sprintf "tcp:%s:%d" address port in
  Thread.create (fun () -> new_thread REQ addr_in addr_out) ()
in
Printf.printf "Connect to:\ntcp://%s:%d\n%!" localhost localport;
(* REQ-type relay between local port [localport+2] (bound on all interfaces)
   and the remote port [port+2]. *)
let fortran_thread =
  let addr_in = Printf.sprintf "tcp://*:%d" (localport+2)
  and addr_out = Printf.sprintf "tcp:%s:%d" address (port+2) in
  Thread.create (fun () -> new_thread REQ addr_in addr_out) ()
in
(* SUB-type relay: messages received from the remote port [port+1] are
   republished on local port [localport+1] (bound on all interfaces). *)
let pub_thread =
  let addr_out = Printf.sprintf "tcp://*:%d" (localport+1)
  and addr_in = Printf.sprintf "tcp:%s:%d" address (port+1) in
  Thread.create (fun () -> new_thread SUB addr_in addr_out) ()
in
let input_thread =
let f () =
(* Sockets of the input thread.
   [socket_in] is a REP socket bound on local port [localport+9].
   [socket_out] is a REQ socket connected to the remote port [port+9]; it only
   exists when an address (not an EZFIO directory) was given. *)
let addr_out =
  match arg with
  | EZFIO _ -> None
  | ADDRESS _ -> Some (Printf.sprintf "tcp:%s:%d" address (port+9))
in
let addr_in =
  Printf.sprintf "tcp://*:%d" (localport+9)
in
let socket_in =
  create_socket Zmq.Socket.rep Zmq.Socket.bind addr_in
in
let socket_out =
  match addr_out with
  | Some addr_out ->
      Some (create_socket Zmq.Socket.req Zmq.Socket.connect addr_out)
  | None -> None
in
(* Scratch file holding the tar archive of the EZFIO directory. *)
let temp_file =
  Filename.temp_file "qp_tunnel" ".tar.gz"
in
(* [Filename.temp_file] creates the file immediately; make sure it does not
   outlive the program.  The file may already have been removed, hence the
   ignored [Sys_error]. *)
let () =
  at_exit (fun () -> try Sys.remove temp_file with Sys_error _ -> ())
in
(* Name of the EZFIO directory: the command-line argument itself when it is a
   directory, otherwise the reply of the remote end to the
   "get_ezfio_filename" request. *)
let get_ezfio_filename () =
  match arg, socket_out with
  | EZFIO x, _ -> x
  | ADDRESS _, Some remote ->
      Zmq.Socket.send remote "get_ezfio_filename" ;
      Zmq.Socket.recv remote
  | ADDRESS _, None ->
      (* [socket_out] is always [Some _] when [arg] is an [ADDRESS]. *)
      assert false
in
(* Contents of a gzipped tar archive of the EZFIO directory, as a string.
   - [EZFIO x]: the archive is built with [tar] into [temp_file] (entries
     matching "*.gz.*" are excluded), read back in binary mode, and
     [temp_file] is removed.
   - [ADDRESS _]: the archive is the reply of the remote end to the
     "get_input" request.
   Raises [Sys_error] if the archive cannot be opened or read. *)
let get_input () =
  match arg with
  | EZFIO x ->
      (* Quote both paths so that spaces or shell metacharacters in them do
         not break, or get interpreted by, the shell command. *)
      let command =
        Printf.sprintf "tar --exclude=\"*.gz.*\" -zcf %s %s"
          (Filename.quote temp_file) (Filename.quote x)
      in
      let status = Sys.command command in
      (* Keep going on failure, as before, but no longer silently. *)
      if status <> 0 then
        Printf.eprintf "Warning: '%s' exited with status %d\n%!"
          command status;
      let ic = open_in_bin temp_file in
      let result =
        match really_input_string ic (in_channel_length ic) with
        | contents ->
            close_in ic;
            contents
        | exception e ->
            close_in_noerr ic;
            raise e
      in
      Sys.remove temp_file;
      result
  | ADDRESS _ ->
      begin
        match socket_out with
        | None -> assert false
        | Some socket_out ->
            Zmq.Socket.send socket_out "get_input" ;
            Zmq.Socket.recv socket_out
      end
in
(* When a remote end is configured, send it a "test" request and print its
   reply. *)
let () =
  match socket_out with
  | Some remote ->
      Zmq.Socket.send remote "test";
      let reply = Zmq.Socket.recv remote in
      Printf.printf "Communication [ %s ]\n%!" reply
  | None -> ()
in
(* Download input if asked *)
(* Only meaningful when an address was given: the archive is fetched from the
   remote end, extracted in the current directory, and the file
   [qp_run_address] of the extracted EZFIO directory is rewritten with the
   local entry point of this tunnel. *)
if Command_line.get_bool "get-input" then
  begin
    match arg with
    | EZFIO _ -> ()
    | ADDRESS _ ->
        Printf.printf "Getting input... %!";
        let ezfio_filename =
          get_ezfio_filename ()
        in
        Printf.printf "%s%!" ezfio_filename;
        (* Fetch before opening the file, so that a failed download leaves no
           open channel; the archive is binary data. *)
        let archive = get_input () in
        let oc = open_out_bin temp_file in
        output_string oc archive;
        close_out oc;
        let command =
          Printf.sprintf "tar -zxf %s" (Filename.quote temp_file)
        in
        let status = Sys.command command in
        if status <> 0 then
          Printf.eprintf "\nWarning: '%s' exited with status %d\n%!"
            command status;
        (* The archive is not needed once extracted. *)
        (try Sys.remove temp_file with Sys_error _ -> ());
        let oc =
          Filename.concat (Qpackage.ezfio_work ezfio_filename) "qp_run_address"
          |> open_out
        in
        Printf.fprintf oc "tcp://%s:%d\n" localhost localport;
        close_out oc;
        Printf.printf " ...done\n%!"
  end;
(* Main loop *)
let pollitem =
  Zmq.Poll.mask_of [| (socket_in, Zmq.Poll.In) |]
in
(* Receive one request on [socket_in] and send back the matching reply. *)
let action () =
  let reply =
    match Zmq.Socket.recv socket_in with
    | "get_input" -> get_input ()
    | "get_ezfio_filename" -> get_ezfio_filename ()
    | "test" -> "OK"
    | x -> Printf.sprintf "Message '%s' not understood" x
  in
  Zmq.Socket.send socket_in reply
in
(* Print the ssh commands mapping ports [port], [port+1], [port+2] and
   [port+9] to the corresponding local ports. *)
Printf.printf "
On remote hosts, create ssh tunnel using:
ssh -L %d:%s:%d -L %d:%s:%d -L %d:%s:%d -L %d:%s:%d %s &
Or from this host connect to clients using:
ssh -R %d:localhost:%d -R %d:localhost:%d -R %d:localhost:%d -R %d:localhost:%d <host> &
%!"
  (port  ) localhost (localport  )
  (port+1) localhost (localport+1)
  (port+2) localhost (localport+2)
  (port+9) localhost (localport+9)
  (Unix.gethostname ())
  (port  ) (localport  )
  (port+1) (localport+1)
  (port+2) (localport+2)
  (port+9) (localport+9);
Printf.printf "Ready\n%!";
(* Serve requests until a termination signal clears [run_status]. *)
while !run_status do
  match (Zmq.Poll.poll ~timeout:1000 pollitem).(0) with
  | Some Zmq.Poll.In -> action ()
  | Some Zmq.Poll.In_out
  | Some Zmq.Poll.Out
  | None -> ()
done;
begin
  match socket_out with
  | None -> ()
  | Some socket_out -> Zmq.Socket.close socket_out
end;
Zmq.Socket.close socket_in
in
(Thread.create f) ()
in
(* Termination *)
(* Wait for the four worker threads in a fixed order, then terminate the
   ZeroMQ context. *)
List.iter Thread.join
  [ input_thread ; fortran_thread ; pub_thread ; ocaml_thread ];
Zmq.Context.terminate zmq_context;
Printf.printf "qp_tunnel exited properly.\n"