From 1a4677dc1944b90207e86480083cebf4b719b753 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Thu, 4 Apr 2019 09:14:15 +0200 Subject: [PATCH] 4-idx node-only --- Parallel_mpi/Farm.ml | 38 ++++++++++++++++++------------------ Parallel_mpi/Farm.mli | 4 +++- Parallel_mpi/Parallel.ml | 9 ++++----- Parallel_mpi/Parallel.mli | 3 +++ Parallel_serial/Farm.ml | 2 +- Parallel_serial/Farm.mli | 4 +++- Parallel_serial/Parallel.ml | 2 ++ Parallel_serial/Parallel.mli | 3 +++ Utils/FourIdxStorage.ml | 2 +- 9 files changed, 39 insertions(+), 28 deletions(-) diff --git a/Parallel_mpi/Farm.ml b/Parallel_mpi/Farm.ml index 3a8b7f3..5c404e2 100644 --- a/Parallel_mpi/Farm.ml +++ b/Parallel_mpi/Farm.ml @@ -32,13 +32,13 @@ type status = | Running | Done -let run_parallel_server ~ordered stream = +let run_parallel_server ~comm ~ordered stream = (* [status.(rank)] is [Initializing] if rank has not yet obtained a task, [Running] if rank is running a task and [Done] if [rank] is waiting at the barrier. *) - let status = Array.make (Mpi.comm_size Mpi.comm_world) Initializing in + let status = Array.make (Mpi.comm_size comm) Initializing in status.(0) <- Done; @@ -50,8 +50,8 @@ let run_parallel_server ~ordered stream = debug "Before receive_status"; (* Avoid busy receive *) let rec wait_and_receive () = - match Mpi.iprobe Mpi.any_source Mpi.any_tag Mpi.comm_world with - | Some _ -> Mpi.receive_status Mpi.any_source Mpi.any_tag Mpi.comm_world + match Mpi.iprobe Mpi.any_source Mpi.any_tag comm with + | Some _ -> Mpi.receive_status Mpi.any_source Mpi.any_tag comm | None -> (Unix.sleepf 0.001 ; wait_and_receive ()) in wait_and_receive () @@ -74,7 +74,7 @@ let run_parallel_server ~ordered stream = with Stream.Failure -> None in debug @@ Printf.sprintf "Sending to %d\n" client_rank; - Mpi.send task client_rank 0 Mpi.comm_world; + Mpi.send task client_rank 0 comm; debug @@ Printf.sprintf "Sent to %d : %s\n" client_rank (if task = None then "None" else "Some"); if task <> None then @@ -103,7 +103,7 @@ let run_parallel_server ~ordered stream = if all_done () then begin debug "Before barrier"; - Mpi.barrier Mpi.comm_world; + Mpi.barrier comm; debug "After barrier"; None end @@ -172,11 +172,11 @@ let run_parallel_server ~ordered stream = (* Client side *) (********************************************************************) -let run_parallel_client f = +let run_parallel_client ~comm f = (** Send a first message containing [None] to request a task *) debug "Before send None"; - Mpi.send None 0 0 Mpi.comm_world; + Mpi.send None 0 0 comm; debug "After send None"; (** Main loop. @@ -189,20 +189,20 @@ let run_parallel_client f = let message = debug "Before receive"; - Mpi.receive 0 0 Mpi.comm_world + Mpi.receive 0 0 comm in debug "After receive" ; match message with | None -> ( debug "Before barrier"; - Mpi.barrier Mpi.comm_world; + Mpi.barrier comm; debug "After barrier";) | Some (task_id, task) -> let result = f task in begin debug @@ Printf.sprintf "Before send task_id %d" task_id ; - Mpi.send (Some (task_id, result)) 0 0 Mpi.comm_world; + Mpi.send (Some (task_id, result)) 0 0 comm; debug @@ Printf.sprintf "After send task_id %d" task_id ; run () end @@ -217,28 +217,28 @@ let run_parallel_client f = -let run_parallel ~ordered f stream = - match Mpi.comm_rank Mpi.comm_world with - | 0 -> run_parallel_server ~ordered stream - | _ -> run_parallel_client f +let run_parallel ~comm ~ordered f stream = + match Mpi.comm_rank comm with + | 0 -> run_parallel_server ~comm ~ordered stream + | _ -> run_parallel_client ~comm f let nested = ref false -let run ?(ordered=true) ~f stream = +let run ?(ordered=true) ?(comm=Mpi.comm_world) ~f stream = if !nested then begin let message = "Nested parallel regions are not supported by Farm.ml" in Printf.eprintf "%s\n%!" message ; - exit 1 + failwith message end; nested := true; let result = - match Mpi.comm_size Mpi.comm_world with + match Mpi.comm_size comm with | 1 -> run_sequential f stream - | _ -> run_parallel ~ordered f stream + | _ -> run_parallel ~comm ~ordered f stream in nested := false; result diff --git a/Parallel_mpi/Farm.mli b/Parallel_mpi/Farm.mli index 9ac9298..5dc93c6 100644 --- a/Parallel_mpi/Farm.mli +++ b/Parallel_mpi/Farm.mli @@ -4,11 +4,13 @@ The input is a stream of input data, and the output is a stream of data. *) -val run : ?ordered:bool -> f:('a -> 'b) -> 'a Stream.t -> 'b Stream.t +val run : ?ordered:bool -> ?comm:Mpi.communicator -> + f:('a -> 'b) -> 'a Stream.t -> 'b Stream.t (** Run the [f] function on every process by popping elements from the input stream, and putting the results on the output stream. If [ordered] (the default is [ordered = true], then the order of the output is kept consistent with the order of the input. + [comm], within MPI is a communicator. It describes a subgroup of processes. *) diff --git a/Parallel_mpi/Parallel.ml b/Parallel_mpi/Parallel.ml index f84835d..fdef772 100644 --- a/Parallel_mpi/Parallel.ml +++ b/Parallel_mpi/Parallel.ml @@ -59,7 +59,7 @@ module Node = struct let name = Unix.gethostname () - let comm = lazy ( + let comm = Mpi.allgather (name, rank) Mpi.comm_world |> Array.to_list |> List.filter (fun (n, r) -> name = n) @@ -67,10 +67,9 @@ module Node = struct |> Array.of_list |> Mpi.(group_incl (comm_group comm_world)) |> Mpi.(comm_create comm_world) - ) let rank = - Mpi.comm_rank (Lazy.force comm) + Mpi.comm_rank comm let master = rank = 0 @@ -79,13 +78,13 @@ module Node = struct if master then Some (Lazy.force x) else None in - match broadcast x 0 (Lazy.force comm) with + match broadcast x 0 comm with | Some x -> x | None -> assert false let broadcast x = broadcast_generic Mpi.broadcast x - let barrier () = Mpi.barrier (Lazy.force comm ) + let barrier () = Mpi.barrier comm end diff --git a/Parallel_mpi/Parallel.mli b/Parallel_mpi/Parallel.mli index e2d3e9b..1634b6e 100644 --- a/Parallel_mpi/Parallel.mli +++ b/Parallel_mpi/Parallel.mli @@ -36,6 +36,9 @@ module Node : sig val name : string (** Name of the current host *) + val comm : Mpi.communicator + (** MPI Communicator containing the processes of the current node *) + val rank : Mpi.rank (** Rank of the current process in the node *) diff --git a/Parallel_serial/Farm.ml b/Parallel_serial/Farm.ml index 6e3bee7..4f88e8b 100644 --- a/Parallel_serial/Farm.ml +++ b/Parallel_serial/Farm.ml @@ -8,6 +8,6 @@ let run_sequential f stream = with Stream.Failure -> None in Stream.from next -let run ?(ordered=true) ~f stream = +let run ?(ordered=true) ?(comm) ~f stream = run_sequential f stream diff --git a/Parallel_serial/Farm.mli b/Parallel_serial/Farm.mli index 9ac9298..52523f6 100644 --- a/Parallel_serial/Farm.mli +++ b/Parallel_serial/Farm.mli @@ -4,11 +4,13 @@ The input is a stream of input data, and the output is a stream of data. *) -val run : ?ordered:bool -> f:('a -> 'b) -> 'a Stream.t -> 'b Stream.t +val run : ?ordered:bool -> ?comm:'c -> + f:('a -> 'b) -> 'a Stream.t -> 'b Stream.t (** Run the [f] function on every process by popping elements from the input stream, and putting the results on the output stream. If [ordered] (the default is [ordered = true], then the order of the output is kept consistent with the order of the input. + In the non-parallel mode, the [comm] argument is unused. *) diff --git a/Parallel_serial/Parallel.ml b/Parallel_serial/Parallel.ml index bb1270c..937d658 100644 --- a/Parallel_serial/Parallel.ml +++ b/Parallel_serial/Parallel.ml @@ -25,6 +25,8 @@ module Node = struct let name = Unix.gethostname () + let comm = None + let rank = 0 let master = true diff --git a/Parallel_serial/Parallel.mli b/Parallel_serial/Parallel.mli index b0f6c8a..cde3ac9 100644 --- a/Parallel_serial/Parallel.mli +++ b/Parallel_serial/Parallel.mli @@ -36,6 +36,9 @@ module Node : sig val name : string (** Name of the current host *) + val comm : 'a option + (** Always [None] *) + val rank : int (** Rank of the current process in the node *) diff --git a/Utils/FourIdxStorage.ml b/Utils/FourIdxStorage.ml index 078ffca..9ba2996 100644 --- a/Utils/FourIdxStorage.ml +++ b/Utils/FourIdxStorage.ml @@ -386,7 +386,7 @@ let four_index_transform coef source = let n = ref 0 in Stream.of_list range_mo - |> Farm.run ~f:task ~ordered:false + |> Farm.run ~f:task ~ordered:false ~comm:Parallel.Node.comm |> Stream.iter (fun l -> if Parallel.master then (incr n ; Printf.eprintf "\r%d / %d%!" !n mo_num); Array.iter (fun (alpha, beta, gamma, delta, x) ->