diff --git a/ocaml/Message.ml b/ocaml/Message.ml
index b5de7e83..dc8369cf 100644
--- a/ocaml/Message.ml
+++ b/ocaml/Message.ml
@@ -245,7 +245,7 @@ end = struct
       (Id.Client.to_int x.client_id)
 end
 
-(** GetTaskReply : Reply to the GetTask message *)
+(** GetTaskReply : Reply to the GetTask message (at most one task) *)
 module GetTaskReply_msg : sig
   type t
   val create : task_id:Id.Task.t option -> task:string option -> t
@@ -265,6 +265,50 @@ end = struct
 end
 
 
+(** GetTasks : ask for up to [n_tasks] new tasks to do in a single request *)
+module GetTasks_msg : sig
+  type t =
+  { client_id: Id.Client.t ;
+    state: State.t ;
+    n_tasks: Strictly_positive_int.t ;
+  }
+  val create : state:string -> client_id:int -> n_tasks:int -> t
+  val to_string : t -> string
+end = struct
+  type t =
+  { client_id: Id.Client.t ;
+    state: State.t ;
+    n_tasks: Strictly_positive_int.t;
+  }
+  let create ~state ~client_id ~n_tasks =
+    { client_id = Id.Client.of_int client_id ; state = State.of_string state ;
+      n_tasks = Strictly_positive_int.of_int n_tasks }
+  let to_string x =
+    Printf.sprintf "get_tasks %s %d %d"
+      (State.to_string x.state)
+      (Id.Client.to_int x.client_id)
+      (Strictly_positive_int.to_int x.n_tasks)
+end
+
+(** GetTasksReply : Reply to the GetTasks message, a list of (task id, task) pairs *)
+module GetTasksReply_msg : sig
+  type t = (Id.Task.t * string) list
+  val create : t -> t
+  val to_string : t -> string
+  val to_string_list : t -> string list
+end = struct
+  type t = (Id.Task.t * string) list
+  let create l = l
+  let to_string _ = (* header only: the tasks themselves are produced by to_string_list *)
+    "get_tasks_reply ok"
+  let to_string_list x = (* header first, then one "<task_id> <task>" string per task, in list order *)
+    "get_tasks_reply ok" :: (
+      List.map x ~f:(fun (task_id, task) -> Printf.sprintf "%d %s" (Id.Task.to_int task_id) task)
+    )
+
+end
+
+
 (** PutData: put some data in the hash table *)
 module PutData_msg : sig
   type t =
@@ -425,7 +469,9 @@ type t =
 | Disconnect of Disconnect_msg.t
 | DisconnectReply of DisconnectReply_msg.t
 | GetTask of GetTask_msg.t
+| GetTasks of GetTasks_msg.t
 | GetTaskReply of GetTaskReply_msg.t
+| GetTasksReply of GetTasksReply_msg.t
 | DelTask of DelTask_msg.t
 | DelTaskReply of DelTaskReply_msg.t
 | AddTask of AddTask_msg.t
@@ -449,6 +495,8 @@ let of_string s =
       DelTask (DelTask_msg.create ~state ~task_ids)
     | GetTask_ { state ; client_id } ->
       GetTask (GetTask_msg.create ~state ~client_id)
+    | GetTasks_ { state ; client_id ; n_tasks } ->
+      GetTasks (GetTasks_msg.create ~state ~client_id ~n_tasks)
     | TaskDone_ { state ; task_ids ; client_id } ->
       TaskDone (TaskDone_msg.create ~state ~client_id ~task_ids)
     | Disconnect_ { state ; client_id } ->
@@ -485,7 +533,9 @@ let to_string = function
 | Disconnect x -> Disconnect_msg.to_string x
 | DisconnectReply x -> DisconnectReply_msg.to_string x
 | GetTask x -> GetTask_msg.to_string x
+| GetTasks x -> GetTasks_msg.to_string x
 | GetTaskReply x -> GetTaskReply_msg.to_string x
+| GetTasksReply x -> GetTasksReply_msg.to_string x
 | DelTask x -> DelTask_msg.to_string x
 | DelTaskReply x -> DelTaskReply_msg.to_string x
 | AddTask x -> AddTask_msg.to_string x
diff --git a/ocaml/Message_lexer.mll b/ocaml/Message_lexer.mll
index 4d5bc702..ef245270 100644
--- a/ocaml/Message_lexer.mll
+++ b/ocaml/Message_lexer.mll
@@ -9,6 +9,7 @@ type kw_type =
     | ADD_TASK
     | DEL_TASK
     | GET_TASK
+    | GET_TASKS
     | TASK_DONE
     | DISCONNECT
     | CONNECT
@@ -28,6 +29,7 @@ type state_tasks = { state : string ; tasks : string list
 type state_taskids = { state : string ; task_ids : int list ; }
 type state_taskids_clientid = { state : string ; task_ids : int list ; client_id : int ; }
 type state_clientid = { state : string ; client_id : int ; }
+type state_clientid_ntasks = { state : string ; client_id : int ; n_tasks : int}
 type state_tcp_inproc = { state : string ; push_address_tcp : string ; push_address_inproc : string ; }
 type psi = { client_id: int ; n_state: int ; n_det: int ; psi_det_size: int ; n_det_generators: int option ; n_det_selectors: int option ; }
 
@@ -37,6 +39,7 @@ type msg =
   | AddTask_ of state_tasks
   | DelTask_ of state_taskids
   | GetTask_ of state_clientid
+  | GetTasks_ of state_clientid_ntasks
   | TaskDone_ of state_taskids_clientid
   | Disconnect_ of state_clientid
   | Connect_ of string
@@ -80,6 +83,7 @@ and kw = parse
   | "add_task" { ADD_TASK }
   | "del_task" { DEL_TASK }
   | "get_task" { GET_TASK }
+  | "get_tasks" { GET_TASKS }
   | "task_done" { TASK_DONE }
   | "disconnect" { DISCONNECT }
   | "connect" { CONNECT }
@@ -155,6 +159,12 @@ and kw = parse
       let state = read_word lexbuf in
       let client_id = read_int lexbuf in
       GetTask_ { state ; client_id }
+
+    | GET_TASKS ->
+      let state = read_word lexbuf in
+      let client_id = read_int lexbuf in
+      let n_tasks = read_int lexbuf in
+      GetTasks_ { state ; client_id ; n_tasks }
 
     | TASK_DONE ->
       let state = read_word lexbuf in
@@ -218,6 +228,7 @@ and kw = parse
     "del_task state_pouet 12345" ;
     "del_task state_pouet 12345 | 6789 | 10 | 11" ;
     "get_task state_pouet 12" ;
+    "get_tasks state_pouet 12 23" ;
     "task_done state_pouet 12 12345";
     "task_done state_pouet 12 12345 | 678 | 91011";
     "connect tcp";
@@ -241,6 +252,7 @@ and kw = parse
     | AddTask_ { state ; tasks } -> Printf.sprintf "ADD_TASK state:\"%s\" tasks:{\"%s\"}" state (String.concat "\"}|{\"" tasks)
     | DelTask_ { state ; task_ids } -> Printf.sprintf "DEL_TASK state:\"%s\" task_ids:{%s}" state (String.concat "|" @@ List.map string_of_int task_ids)
     | GetTask_ { state ; client_id } -> Printf.sprintf "GET_TASK state:\"%s\" task_id:%d" state client_id
+    | GetTasks_ { state ; client_id ; n_tasks } -> Printf.sprintf "GET_TASKS state:\"%s\" client_id:%d n_tasks:%d" state client_id n_tasks
     | TaskDone_ { state ; task_ids ; client_id } -> Printf.sprintf "TASK_DONE state:\"%s\" task_ids:{%s} client_id:%d" state (String.concat "|" @@ List.map string_of_int task_ids) client_id
     | Disconnect_ { state ; client_id } -> Printf.sprintf "DISCONNECT state:\"%s\" client_id:%d" state client_id
     | Connect_ socket -> Printf.sprintf "CONNECT socket:\"%s\"" socket
diff --git a/ocaml/TaskServer.ml b/ocaml/TaskServer.ml
index 103265fd..bf07b9fd 100644
--- a/ocaml/TaskServer.ml
+++ b/ocaml/TaskServer.ml
@@ -412,6 +412,71 @@ let get_task msg program_state rep_socket pair_socket =
 
 
 
+let get_tasks msg program_state rep_socket pair_socket =
+  (* Pop up to n_tasks tasks for client_id and send them on rep_socket as one multi-part reply. *)
+  let state, client_id, n_tasks =
+    msg.Message.GetTasks_msg.state,
+    msg.Message.GetTasks_msg.client_id,
+    Strictly_positive_int.to_int msg.Message.GetTasks_msg.n_tasks
+  in
+
+  let failure () =
+    reply_wrong_state rep_socket;
+    program_state
+
+  and success () =
+    (* accu is built by consing, so it is reversed on return: tasks in pop order, "terminate" marker last. NOTE(review): assumes Id.Task.of_int accepts 0 - confirm. *)
+    let rec build_list accu queue = function
+    | 0 -> queue, List.rev accu
+    | n ->
+        let new_queue, task_id, task =
+          Queuing_system.pop_task ~client_id queue
+        in
+        match (task_id, task) with
+        | Some task_id, Some task ->
+            build_list ( (task_id, task)::accu ) new_queue (n-1)
+        | _ -> queue, List.rev ((Id.Task.of_int 0, "terminate")::accu)
+    in
+
+    let new_queue, result =
+      build_list [] program_state.queue (n_tasks)
+    in
+
+    let no_task =
+      Queuing_system.number_of_queued new_queue = 0
+    in
+
+    if no_task then
+      string_of_pub_state Waiting
+      |> ZMQ.Socket.send pair_socket
+    else
+      string_of_pub_state (Running (Message.State.to_string state))
+      |> ZMQ.Socket.send pair_socket;
+
+    let new_program_state =
+      { program_state with
+        queue = new_queue
+      }
+    in
+
+    Message.GetTasksReply (Message.GetTasksReply_msg.create result)
+    |> Message.to_string_list (* NOTE(review): requires a GetTasksReply case in Message.to_string_list, which this patch does not show - verify *)
+    |> ZMQ.Socket.send_all rep_socket ;
+    new_program_state
+  in
+
+  match program_state.state with
+  | None -> assert false
+  | Some state' ->
+    begin
+      if (state = state') then
+        success ()
+      else
+        failure ()
+    end
+
+
+
 let task_done msg program_state rep_socket =
 
   let state, client_id, task_ids =
@@ -703,6 +768,7 @@ let run ~port =
         | Some _, Message.AddTask x -> add_task x program_state rep_socket
         | Some _, Message.DelTask x -> del_task x program_state rep_socket
         | Some _, Message.GetTask x -> get_task x program_state rep_socket pair_socket
+        | Some _, Message.GetTasks x -> get_tasks x program_state rep_socket pair_socket
         | Some _, Message.TaskDone x -> task_done x program_state rep_socket
         | _ , _ ->
           error ("Invalid message : "^(Message.to_string message)) program_state rep_socket
diff --git a/ocaml/qp_run.ml b/ocaml/qp_run.ml
index 4761c3f4..57725895 100644
--- a/ocaml/qp_run.ml
+++ b/ocaml/qp_run.ml
@@ -4,7 +4,6 @@ open Qputils
 (* Environment variables :
 
    QP_PREFIX=gdb : to run gdb (or valgrind, or whatever)
-   QP_MPIRUN=mpirun: to run mpi slaves
    QP_TASK_DEBUG=1 : debug task server
 
 *)
@@ -16,8 +15,7 @@ let print_list () =
 let () =
   Random.self_init ()
 
-let run slave mpi exe ezfio_file =
-
+let run slave exe ezfio_file =
 
   (** Check availability of the ports *)
   let port_number =
@@ -31,7 +29,7 @@ let run slave mpi exe ezfio_file =
     try
       List.iter [ 0;1;2;3;4;5;6;7;8;9 ] ~f:(fun i ->
           let address =
-            Printf.sprintf "tcp://%s:%d" (Lazy.force TaskServer.ip_address) (port_number+i)
+            Printf.sprintf "tcp://*:%d" (port_number+i)
           in
           ZMQ.Socket.bind dummy_socket address;
           ZMQ.Socket.unbind dummy_socket address;
@@ -47,10 +45,15 @@ let run slave mpi exe ezfio_file =
     ZMQ.Context.terminate zmq_context;
     result
   in
+
   let time_start =
     Time.now ()
   in
 
+  let address =
+    Printf.sprintf "tcp://%s:%d" (Lazy.force TaskServer.ip_address) port_number
+  in
+
   if (not (Sys.file_exists_exn ezfio_file)) then
     failwith ("EZFIO directory "^ezfio_file^" not found");
 
@@ -100,9 +103,6 @@ let run slave mpi exe ezfio_file =
     in
     thread ();
   in
-  let address =
-    Printf.sprintf "tcp://%s:%d" (Lazy.force TaskServer.ip_address) port_number
-  in
   Unix.putenv ~key:"QP_RUN_ADDRESS" ~data:address;
   let () =
     if (not slave) then
@@ -116,18 +116,13 @@ let run slave mpi exe ezfio_file =
     match Sys.getenv "QP_PREFIX" with
     | Some x -> x^" "
     | None -> ""
-  and mpirun =
-    match (mpi, Sys.getenv "QP_MPIRUN") with
-    | (true, None) -> "mpirun "
-    | (true, Some x) -> x^" "
-    | _ -> ""
   and exe =
     match (List.find ~f:(fun (x,_) -> x = exe) executables) with
     | Some (_,x) -> x^" "
     | None -> assert false
   in
   let exit_code =
-    match (Sys.command (mpirun^prefix^exe^ezfio_file)) with
+    match (Sys.command (prefix^exe^ezfio_file)) with
     | 0 -> 0
     | i -> (Printf.printf "Program exited with code %d.\n%!" i; i)
   in
@@ -148,8 +143,6 @@ let spec =
   empty
   +> flag "slave" no_arg
      ~doc:(" Required for slave tasks")
-  +> flag "mpi" no_arg
-     ~doc:(" Required for MPI slaves")
   +> anon ("executable" %: string)
   +> anon ("ezfio_file" %: string)
 ;;
@@ -167,8 +160,8 @@ Executes a Quantum Package binary file among these:\n\n"
         )
       )
     spec
-    (fun slave mpi exe ezfio_file () ->
-      run slave mpi exe ezfio_file
+    (fun slave exe ezfio_file () ->
+      run slave exe ezfio_file
     )
   |> Command.run ~version: Git.sha1 ~build_info: Git.message
 
diff --git a/plugins/Full_CI_ZMQ/dump_fci_iterations_value.irp.f b/plugins/Full_CI_ZMQ/dump_fci_iterations_value.irp.f
index e83d627f..08ae05ac 100644
--- a/plugins/Full_CI_ZMQ/dump_fci_iterations_value.irp.f
+++ b/plugins/Full_CI_ZMQ/dump_fci_iterations_value.irp.f
@@ -20,6 +20,9 @@ subroutine dump_fci_iterations_value(n_determinants,energy,pt2)
 
   !!! Check to ensure that we should save iterations (default is Append)
   ! saveMethod: 1==Append, 2==Overwrite, 3==NoSave
+  if (N_det < N_states) then
+    return
+  endif
   call ezfio_get_full_ci_zmq_iterative_save(saveMethod)
 
   !!! Check we are saving data
diff --git a/plugins/Full_CI_ZMQ/selection_davidson_slave.irp.f b/plugins/Full_CI_ZMQ/selection_davidson_slave.irp.f
index 17a54688..6af42f33 100644
--- a/plugins/Full_CI_ZMQ/selection_davidson_slave.irp.f
+++ b/plugins/Full_CI_ZMQ/selection_davidson_slave.irp.f
@@ -22,7 +22,9 @@ subroutine run_wf
   use f77_zmq
   implicit none
 
-  include 'mpif.h'
+  IRP_IF MPI
+    include 'mpif.h'
+  IRP_ENDIF
 
   integer(ZMQ_PTR), external :: new_zmq_to_qp_run_socket
   integer(ZMQ_PTR) :: zmq_to_qp_run_socket
diff --git a/src/ZMQ/utils.irp.f b/src/ZMQ/utils.irp.f
index 704ac645..4d093127 100644
--- a/src/ZMQ/utils.irp.f
+++ b/src/ZMQ/utils.irp.f
@@ -190,10 +190,10 @@ function new_zmq_pair_socket(bind)
 !    stop 'f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_RCVHWM, 2, 4)'
 !  endif
 !
-!  rc = f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_IMMEDIATE, 1, 4)
-!  if (rc /= 0) then
-!    stop 'f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_IMMEDIATE, 1, 4)'
-!  endif
+  rc = f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_IMMEDIATE, 1, 4)
+  if (rc /= 0) then
+    stop 'f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_IMMEDIATE, 1, 4)'
+  endif
 !
 !  rc = f77_zmq_setsockopt(new_zmq_pair_socket, ZMQ_LINGER, 600000, 4)
 !  if (rc /= 0) then
@@ -849,6 +849,7 @@ subroutine get_task_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task)
   character*(64) :: reply
   integer :: rc, sze
 
+!  call get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task,1)
   write(message,*) 'get_task '//trim(zmq_state), worker_id
 
   sze = len(trim(message))
@@ -888,6 +889,88 @@ subroutine get_task_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task)
 
 end
 
+subroutine get_tasks_from_taskserver(zmq_to_qp_run_socket,worker_id,task_id,task,n_tasks)
+  use f77_zmq
+  implicit none
+  BEGIN_DOC
+  ! Get up to n_tasks tasks from the task server in a single request.
+  !
+  ! The reply is a multi-part message: a 'get_tasks_reply ok' header followed by
+  ! one '<task_id> <task>' part per task. When fewer than n_tasks tasks are
+  ! available, the last part is '0 terminate'.
+  !
+  ! Entries of task_id/task that are not filled are set to 0/'terminate'.
+  END_DOC
+  integer(ZMQ_PTR), intent(in) :: zmq_to_qp_run_socket
+  integer, intent(in) :: worker_id
+  integer, intent(in) :: n_tasks
+  integer, intent(out) :: task_id(n_tasks)
+  character*(512), intent(out) :: task(n_tasks)
+
+  character*(1024) :: message
+  character*(64) :: reply
+  integer :: rc, sze, i
+
+  ! Default every slot to the 'terminate' sentinel so that unfilled entries are defined
+  task_id(:) = 0
+  task(:) = 'terminate'
+
+  write(message,*) 'get_tasks '//trim(zmq_state), worker_id, n_tasks
+
+  sze = len(trim(message))
+  rc = f77_zmq_send(zmq_to_qp_run_socket, message, sze, 0)
+  if (rc /= sze) then
+    print *, irp_here, ':f77_zmq_send(zmq_to_qp_run_socket, trim(message), sze, 0)'
+    stop 'error'
+  endif
+
+  message = repeat(' ',512)
+  rc = f77_zmq_recv(zmq_to_qp_run_socket, message, 1024, 0)
+  rc = min(1024,rc)
+  ! List-directed read stops at the first blank: reply holds only the first word
+  read(message(1:rc),*) reply
+  if (trim(reply) == 'get_tasks_reply') then
+    continue
+  else if (trim(reply) == 'terminate') then
+    ! Single-part reply: there is nothing more to receive
+    task_id(1) = 0
+    task(1) = 'terminate'
+    return
+  else if (trim(message) == 'error No job is running') then
+    task_id(1) = 0
+    task(1) = 'terminate'
+    return
+  else
+    print *, 'Unable to get the next task'
+    print *, trim(message)
+    stop -1
+  endif
+
+  do i=1,n_tasks
+    message = repeat(' ',512)
+    rc = f77_zmq_recv(zmq_to_qp_run_socket, message, 1024, 0)
+    rc = min(1024,rc)
+    read(message(1:rc),*) task_id(i)
+    if (task_id(i) == 0) then
+      ! '0 terminate' closes a short reply: no further part will arrive
+      task(i) = 'terminate'
+      exit
+    endif
+    ! Skip the leading blanks and the task_id field; the rest is the task
+    rc = 1
+    do while (message(rc:rc) == ' ')
+      rc += 1
+    enddo
+    do while (message(rc:rc) /= ' ')
+      rc += 1
+    enddo
+    rc += 1
+    task(i) = message(rc:)
+  enddo
+
+end
+
+
 subroutine end_zmq_to_qp_run_socket(zmq_to_qp_run_socket)
   use f77_zmq
   implicit none