10
0
mirror of https://github.com/LCPQ/quantum_package synced 2025-01-05 11:00:10 +01:00
quantum_package/ocaml/TaskServer.ml

620 lines
16 KiB
OCaml

open Core.Std
open Qptypes
type t =
{
queue : Queuing_system.t ;
state : Message.State.t option ;
address_tcp : Address.Tcp.t option ;
address_inproc : Address.Inproc.t option ;
psi : Message.Psi.t option;
progress_bar : Progress_bar.t option ;
running : bool;
}
let debug_env =
match Sys.getenv "QP_TASK_DEBUG" with
| Some x -> x <> ""
| None -> false
let debug str =
if debug_env then
Printf.printf "TASK : %s%!" str
let zmq_context =
ZMQ.Context.create ()
let bind_socket ~socket_type ~socket ~address =
let rec loop = function
| 0 -> failwith @@ Printf.sprintf
"Unable to bind the %s socket : %s "
socket_type address
| -1 -> ()
| i ->
try
ZMQ.Socket.bind socket address;
loop (-1)
with
| Unix.Unix_error _ -> (Time.pause @@ Time.Span.of_float 1. ; loop (i-1) )
| other_exception -> raise other_exception
in loop 10
let hostname = lazy (
try
Unix.gethostname ()
with
| _ -> "localhost"
)
let ip_address = lazy (
match Sys.getenv "QP_NIC" with
| None ->
begin
try
Lazy.force hostname
|> Unix.Inet_addr.of_string_or_getbyname
|> Unix.Inet_addr.to_string
with
| Unix.Unix_error _ ->
failwith "Unable to find IP address from host name."
end
| Some interface ->
begin
try
ok_exn Linux_ext.get_ipv4_address_for_interface interface
with
| Unix.Unix_error _ ->
Lazy.force hostname
|> Unix.Inet_addr.of_string_or_getbyname
|> Unix.Inet_addr.to_string
end
)
let reply_ok rep_socket =
Message.Ok_msg.create ()
|> Message.Ok_msg.to_string
|> ZMQ.Socket.send rep_socket
let reply_wrong_state rep_socket =
Printf.printf "WRONG STATE\n%!";
Message.Error_msg.create "Wrong state"
|> Message.Error_msg.to_string
|> ZMQ.Socket.send rep_socket
let stop ~port =
debug "STOP";
let req_socket =
ZMQ.Socket.create zmq_context ZMQ.Socket.req
and address =
Printf.sprintf "tcp://%s:%d" (Lazy.force ip_address) port
in
ZMQ.Socket.set_linger_period req_socket 1_000_000;
ZMQ.Socket.connect req_socket address;
Message.Terminate (Message.Terminate_msg.create ())
|> Message.to_string
|> ZMQ.Socket.send req_socket ;
let msg =
ZMQ.Socket.recv req_socket
|> Message.of_string
in
let () =
match msg with
| Message.Ok _ -> ()
| _ -> failwith "Problem in termination"
in
ZMQ.Socket.set_linger_period req_socket 1_000;
ZMQ.Socket.close req_socket
let new_job msg program_state rep_socket =
let state =
msg.Message.Newjob_msg.state
in
let progress_bar =
Progress_bar.init
~start_value:0.
~end_value:1.
~bar_length:20
~title:(Message.State.to_string state)
in
let result =
{ program_state with
state = Some state ;
progress_bar = Some progress_bar ;
address_tcp = Some msg.Message.Newjob_msg.address_tcp;
address_inproc = Some msg.Message.Newjob_msg.address_inproc;
}
in
reply_ok rep_socket;
result
let end_job msg program_state rep_socket =
let failure () =
reply_wrong_state rep_socket;
program_state
and success state =
reply_ok rep_socket;
{ program_state with
state = None ;
progress_bar = None ;
}
in
match program_state.state with
| None -> failure ()
| Some state ->
begin
if (msg.Message.Endjob_msg.state = state) then
success state
else
failure ()
end
let connect msg program_state rep_socket =
let state =
match program_state.state with
| Some state -> state
| None -> assert false
in
let push_address =
match msg with
| Message.Connect_msg.Tcp ->
begin
match program_state.address_tcp with
| Some address -> Address.Tcp address
| None -> failwith "Error: No TCP address"
end
| Message.Connect_msg.Inproc ->
begin
match program_state.address_inproc with
| Some address -> Address.Inproc address
| None -> failwith "Error: No inproc address"
end
| Message.Connect_msg.Ipc -> assert false
in
let new_queue, client_id =
Queuing_system.add_client program_state.queue
in
Message.ConnectReply (Message.ConnectReply_msg.create
~state:state ~client_id ~push_address)
|> Message.to_string
|> ZMQ.Socket.send rep_socket ;
{ program_state with
queue = new_queue
}
let disconnect msg program_state rep_socket =
let state, client_id =
msg.Message.Disconnect_msg.state,
msg.Message.Disconnect_msg.client_id
in
let failure () =
reply_wrong_state rep_socket;
program_state
and success () =
let new_program_state =
{ program_state with
queue = Queuing_system.del_client ~client_id program_state.queue
}
in
Message.DisconnectReply (Message.DisconnectReply_msg.create ~state)
|> Message.to_string
|> ZMQ.Socket.send rep_socket ;
new_program_state
in
match program_state.state with
| None -> assert false
| Some state' ->
begin
if (state = state') then
success ()
else
failure ()
end
let del_task msg program_state rep_socket =
let state, task_id =
msg.Message.DelTask_msg.state,
msg.Message.DelTask_msg.task_id
in
let failure () =
reply_wrong_state rep_socket;
program_state
and success () =
let new_program_state =
{ program_state with
queue = Queuing_system.del_task ~task_id program_state.queue
}
in
let more =
(Queuing_system.number_of_queued new_program_state.queue +
Queuing_system.number_of_running new_program_state.queue) > 0
in
Message.DelTaskReply (Message.DelTaskReply_msg.create ~task_id ~more)
|> Message.to_string
|> ZMQ.Socket.send ~block:true rep_socket ; (** /!\ Has to be blocking *)
new_program_state
in
match program_state.state with
| None -> assert false
| Some state' ->
begin
if (state = state') then
success ()
else
failure ()
end
let add_task msg program_state rep_socket =
let state, task =
msg.Message.AddTask_msg.state,
msg.Message.AddTask_msg.task
in
let increment_progress_bar = function
| Some bar -> Some (Progress_bar.increment_end bar)
| None -> None
in
let rec add_task_triangle program_state imax = function
| 0 -> program_state
| i ->
let task =
Printf.sprintf "%d %d" i imax
in
let new_program_state =
{ program_state with
queue = Queuing_system.add_task ~task program_state.queue ;
progress_bar = increment_progress_bar program_state.progress_bar ;
}
in
add_task_triangle new_program_state imax (i-1)
in
let rec add_task_range program_state i = function
| j when (j < i) -> program_state
| j ->
let task =
Printf.sprintf "%d" j
in
let new_program_state =
{ program_state with
queue = Queuing_system.add_task ~task program_state.queue ;
progress_bar = increment_progress_bar program_state.progress_bar ;
}
in
add_task_range new_program_state i (j-1)
in
let new_program_state = function
| "triangle" :: i_str :: [] ->
let imax =
Int.of_string i_str
in
add_task_triangle program_state imax imax
| "range" :: i_str :: j_str :: [] ->
let i, j =
Int.of_string i_str,
Int.of_string j_str
in
add_task_range program_state i j
| _ ->
{ program_state with
queue = Queuing_system.add_task ~task program_state.queue ;
progress_bar = increment_progress_bar program_state.progress_bar ;
}
in
let result =
String.split ~on:' ' task
|> List.filter ~f:(fun x -> x <> "")
|> new_program_state
in
reply_ok rep_socket;
result
let get_task msg program_state rep_socket =
let state, client_id =
msg.Message.GetTask_msg.state,
msg.Message.GetTask_msg.client_id
in
let failure () =
reply_wrong_state rep_socket;
program_state
and success () =
let new_queue, task_id, task =
Queuing_system.pop_task ~client_id program_state.queue
in
let new_program_state =
{ program_state with
queue = new_queue
}
in
match (task, task_id) with
| Some task, Some task_id ->
begin
Message.GetTaskReply (Message.GetTaskReply_msg.create ~task ~task_id)
|> Message.to_string
|> ZMQ.Socket.send rep_socket ;
new_program_state
end
| _ ->
begin
Message.Terminate (Message.Terminate_msg.create ())
|> Message.to_string
|> ZMQ.Socket.send rep_socket ;
program_state
end
in
match program_state.state with
| None -> assert false
| Some state' ->
begin
if (state = state') then
success ()
else
failure ()
end
let task_done msg program_state rep_socket =
let state, client_id, task_id =
msg.Message.TaskDone_msg.state,
msg.Message.TaskDone_msg.client_id,
msg.Message.TaskDone_msg.task_id
in
let increment_progress_bar = function
| Some bar -> Some (Progress_bar.increment_cur bar)
| None -> None
in
let failure () =
reply_wrong_state rep_socket;
program_state
and success () =
let result =
{ program_state with
queue = Queuing_system.end_task ~task_id ~client_id program_state.queue ;
progress_bar = increment_progress_bar program_state.progress_bar ;
}
in
reply_ok rep_socket;
result
in
match program_state.state with
| None -> assert false
| Some state' ->
begin
if (state = state') then
success ()
else
failure ()
end
let put_psi msg rest_of_msg program_state rep_socket =
let psi_local =
match msg.Message.PutPsi_msg.psi with
| Some x -> x
| None ->
begin
let psi_det, psi_coef =
match rest_of_msg with
| [ x ; y ] -> x, y
| _ -> failwith "Badly formed put_psi message"
in
Message.Psi.create
~n_state:msg.Message.PutPsi_msg.n_state
~n_det:msg.Message.PutPsi_msg.n_det
~psi_det_size:msg.Message.PutPsi_msg.psi_det_size
~n_det_generators:msg.Message.PutPsi_msg.n_det_generators
~n_det_selectors:msg.Message.PutPsi_msg.n_det_selectors
~psi_det
~psi_coef
end
in
let new_program_state =
{ program_state with
psi = Some psi_local
}
and client_id =
msg.Message.PutPsi_msg.client_id
in
Message.PutPsiReply (Message.PutPsiReply_msg.create ~client_id)
|> Message.to_string
|> ZMQ.Socket.send rep_socket;
new_program_state
let get_psi msg program_state rep_socket =
let client_id =
msg.Message.GetPsi_msg.client_id
in
match program_state.psi with
| None -> failwith "No wave function saved in TaskServer"
| Some psi ->
Message.GetPsiReply (Message.GetPsiReply_msg.create ~client_id ~psi)
|> Message.to_string_list
|> ZMQ.Socket.send_all rep_socket;
program_state
let terminate program_state rep_socket =
reply_ok rep_socket;
{ program_state with
running = false
}
let error msg program_state rep_socket =
Printf.printf "%s\n%!" msg;
Message.Error (Message.Error_msg.create msg)
|> Message.to_string
|> ZMQ.Socket.send rep_socket ;
program_state
let run ~port =
(** Bind REP socket *)
let rep_socket =
ZMQ.Socket.create zmq_context ZMQ.Socket.rep
and address =
Printf.sprintf "tcp://%s:%d" (Lazy.force ip_address) port
in
bind_socket "REP" rep_socket address;
ZMQ.Socket.set_linger_period rep_socket 1_000_000;
let initial_program_state =
{ queue = Queuing_system.create () ;
running = true ;
psi = None;
state = None;
address_tcp = None;
address_inproc = None;
progress_bar = None ;
}
in
(** ZMR polling item *)
let pollitem =
ZMQ.Poll.mask_of
[| (rep_socket, ZMQ.Poll.In) |]
in
Printf.printf "Task server running : %s\n%!" address;
(** Main loop *)
let rec main_loop program_state = function
| false -> ()
| true ->
let polling =
ZMQ.Poll.poll ~timeout:1000 pollitem
in
if (polling.(0) <> Some ZMQ.Poll.In) then
main_loop program_state true
else
begin
let program_state =
match program_state.progress_bar with
| None -> program_state
| Some bar ->
if bar.Progress_bar.dirty then
{ program_state with
progress_bar = Some (Progress_bar.display bar)
}
else
program_state
in
(** Extract message *)
let raw_message, rest =
match ZMQ.Socket.recv_all rep_socket with
| x :: rest -> x, rest
| [] -> failwith "Badly formed message"
in
let message =
Message.of_string raw_message
in
(** Debug input *)
Printf.sprintf "%d %d : %s\n%!"
(Queuing_system.number_of_queued program_state.queue)
(Queuing_system.number_of_running program_state.queue)
(Message.to_string message)
|> debug;
let new_program_state =
try
match program_state.state, message with
| _ , Message.Terminate _ -> terminate program_state rep_socket
| _ , Message.PutPsi x -> put_psi x rest program_state rep_socket
| _ , Message.GetPsi x -> get_psi x program_state rep_socket
| None , Message.Newjob x -> new_job x program_state rep_socket
| _ , Message.Newjob _ -> error "A job is already running" program_state rep_socket
| Some _, Message.Endjob x -> end_job x program_state rep_socket
| None , _ -> error "No job is running" program_state rep_socket
| Some _, Message.Connect x -> connect x program_state rep_socket
| Some _, Message.Disconnect x -> disconnect x program_state rep_socket
| Some _, Message.AddTask x -> add_task x program_state rep_socket
| Some _, Message.DelTask x -> del_task x program_state rep_socket
| Some _, Message.GetTask x -> get_task x program_state rep_socket
| Some _, Message.TaskDone x -> task_done x program_state rep_socket
| _ , _ ->
error ("Invalid message : "^(Message.to_string message)) program_state rep_socket
with
| Failure f ->
error (f^" : "^raw_message) program_state rep_socket
| Assert_failure (f,i,j) ->
error (Printf.sprintf "%s:%d:%d : %s" f i j raw_message) program_state rep_socket
in
main_loop new_program_state new_program_state.running
end
in main_loop initial_program_state true;