diff --git a/Parallel_mpi/Parallel.ml b/Parallel_mpi/Parallel.ml index ec9c76b..3c4ab70 100644 --- a/Parallel_mpi/Parallel.ml +++ b/Parallel_mpi/Parallel.ml @@ -55,6 +55,26 @@ let broadcast_vec x = Lacaml.D.Vec.of_array a +module Node = struct + + let name = Unix.gethostname () + + let comm_node = + Mpi.allgather (name, rank) Mpi.comm_world + |> Array.to_list + |> List.filter (fun (n, r) -> name = n) + |> List.map snd + |> Array.of_list + |> Mpi.(group_incl (comm_group comm_world)) + |> Mpi.(comm_create comm_world) + + let rank = + Mpi.comm_rank comm_node + + let master = rank = 0 + +end + module Vec = struct