diff --git a/Parallel_mpi/Parallel.ml b/Parallel_mpi/Parallel.ml index df0fa7a..21ed192 100644 --- a/Parallel_mpi/Parallel.ml +++ b/Parallel_mpi/Parallel.ml @@ -60,15 +60,16 @@ module Node = struct let name = Unix.gethostname () let comm = - Mpi.allgather (name, rank) Mpi.comm_world - |> Array.to_list - |> List.filter (fun (n, r) -> name = n) - |> List.map snd - |> Array.of_list - |> Mpi.(group_incl (comm_group comm_world)) - |> Mpi.(comm_create comm_world) + let _, color = + Mpi.allgather (name, rank) Mpi.comm_world + |> Array.to_list + |> List.sort compare + |> List.find (fun (n, r) -> n = name) + in + Mpi.(comm_split comm_world color 0) let rank = + Printf.printf "Node: %d %d\n%!" rank (Mpi.comm_rank comm); Mpi.comm_rank comm let master = rank = 0 @@ -85,44 +86,81 @@ module Node = struct let broadcast x = broadcast_generic Mpi.broadcast x let barrier () = Mpi.barrier comm + + end module InterNode = struct let comm = - let rec aux accu name = function - | [] -> List.rev accu - | (newname, rank) :: rest when newname = name -> aux accu name rest - | (newname, rank) :: rest -> aux (rank :: accu) newname rest + + let ranks = + let name = Unix.gethostname () in + + let rec aux accu old_name = function + | [] -> List.rev accu |> Array.of_list + | (new_name, r) :: rest when new_name <> old_name -> + aux (r::accu) new_name rest + | (new_name, r) :: rest -> aux accu new_name rest + in + + Mpi.allgather (name, rank) Mpi.comm_world + |> Array.to_list + |> List.sort compare + |> aux [] "" in - let name = Unix.gethostname () in - Mpi.allgather (name, rank) Mpi.comm_world - |> Array.to_list - |> List.sort compare - |> aux [] "" - |> Array.of_list - |> Mpi.(group_incl (comm_group comm_world)) - |> Mpi.(comm_create comm_world) + let world_group = + Mpi.comm_group Mpi.comm_world + in + + let new_group = + Mpi.group_incl world_group ranks + in + + let result = + let g = + Mpi.comm_create Mpi.comm_world new_group + in + try + ignore @@ List.find (fun x -> x = rank) @@ Array.to_list ranks; + Some g + with Not_found -> None + in + result + let rank = - Mpi.comm_rank comm + match comm with + | Some comm -> + Printf.printf "InterNode: %d %d\n%!" rank (Mpi.comm_rank comm); + Mpi.comm_rank comm + | None -> -1 let master = rank = 0 let broadcast_generic broadcast x = - let x = - if master then Some (Lazy.force x) - else None - in - match broadcast x 0 comm with - | Some x -> x - | None -> assert false + match comm with + | Some comm -> + begin + let x = + if master then Some (Lazy.force x) + else None + in + match broadcast x 0 comm with + | Some x -> x + | None -> assert false + end + | None -> Lazy.force x let broadcast x = broadcast_generic Mpi.broadcast x - let barrier () = Mpi.barrier comm + let barrier () = + match comm with + | Some comm -> Mpi.barrier comm + | None -> () + end diff --git a/Parallel_mpi/Parallel.mli b/Parallel_mpi/Parallel.mli index 3ffc9aa..381656e 100644 --- a/Parallel_mpi/Parallel.mli +++ b/Parallel_mpi/Parallel.mli @@ -40,7 +40,7 @@ module Node : sig val name : string (** Name of the current host *) - val comm : Mpi.communicator + val comm : Mpi.communicator (** MPI Communicator containing the processes of the current node *) val rank : Mpi.rank @@ -60,7 +60,7 @@ end (** {5 Inter-node operations} *) module InterNode : sig - val comm : Mpi.communicator + val comm : Mpi.communicator option (** MPI Communicator among the master processes of the each node *) val rank : Mpi.rank