10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-11-07 06:33:39 +01:00

Working on MPI

This commit is contained in:
Anthony Scemama 2020-01-23 21:24:05 +01:00
parent b0135a48f4
commit 7808e28ae7
2 changed files with 68 additions and 30 deletions

View File

@ -60,15 +60,16 @@ module Node = struct
let name = Unix.gethostname () let name = Unix.gethostname ()
let comm = let comm =
Mpi.allgather (name, rank) Mpi.comm_world let _, color =
|> Array.to_list Mpi.allgather (name, rank) Mpi.comm_world
|> List.filter (fun (n, r) -> name = n) |> Array.to_list
|> List.map snd |> List.sort compare
|> Array.of_list |> List.find (fun (n, r) -> n = name)
|> Mpi.(group_incl (comm_group comm_world)) in
|> Mpi.(comm_create comm_world) Mpi.(comm_split comm_world color 0)
let rank = let rank =
Printf.printf "Node: %d %d\n%!" rank (Mpi.comm_rank comm);
Mpi.comm_rank comm Mpi.comm_rank comm
let master = rank = 0 let master = rank = 0
@ -85,44 +86,81 @@ module Node = struct
let broadcast x = broadcast_generic Mpi.broadcast x let broadcast x = broadcast_generic Mpi.broadcast x
let barrier () = Mpi.barrier comm let barrier () = Mpi.barrier comm
end end
module InterNode = struct module InterNode = struct
let comm = let comm =
let rec aux accu name = function
| [] -> List.rev accu let ranks =
| (newname, rank) :: rest when newname = name -> aux accu name rest let name = Unix.gethostname () in
| (newname, rank) :: rest -> aux (rank :: accu) newname rest
let rec aux accu old_name = function
| [] -> List.rev accu |> Array.of_list
| (new_name, r) :: rest when new_name <> old_name ->
aux (r::accu) new_name rest
| (new_name, r) :: rest -> aux accu new_name rest
in
Mpi.allgather (name, rank) Mpi.comm_world
|> Array.to_list
|> List.sort compare
|> aux [] ""
in in
let name = Unix.gethostname () in let world_group =
Mpi.allgather (name, rank) Mpi.comm_world Mpi.comm_group Mpi.comm_world
|> Array.to_list in
|> List.sort compare
|> aux [] "" let new_group =
|> Array.of_list Mpi.group_incl world_group ranks
|> Mpi.(group_incl (comm_group comm_world)) in
|> Mpi.(comm_create comm_world)
let result =
let g =
Mpi.comm_create Mpi.comm_world new_group
in
try
ignore @@ List.find (fun x -> x = rank) @@ Array.to_list ranks;
Some g
with Not_found -> None
in
result
let rank = let rank =
Mpi.comm_rank comm match comm with
| Some comm ->
Printf.printf "InterNode: %d %d\n%!" rank (Mpi.comm_rank comm);
Mpi.comm_rank comm
| None -> -1
let master = rank = 0 let master = rank = 0
let broadcast_generic broadcast x = let broadcast_generic broadcast x =
let x = match comm with
if master then Some (Lazy.force x) | Some comm ->
else None begin
in let x =
match broadcast x 0 comm with if master then Some (Lazy.force x)
| Some x -> x else None
| None -> assert false in
match broadcast x 0 comm with
| Some x -> x
| None -> assert false
end
| None -> Lazy.force x
let broadcast x = broadcast_generic Mpi.broadcast x let broadcast x = broadcast_generic Mpi.broadcast x
let barrier () = Mpi.barrier comm let barrier () =
match comm with
| Some comm -> Mpi.barrier comm
| None -> ()
end end

View File

@ -40,7 +40,7 @@ module Node : sig
val name : string val name : string
(** Name of the current host *) (** Name of the current host *)
val comm : Mpi.communicator val comm : Mpi.communicator
(** MPI Communicator containing the processes of the current node *) (** MPI Communicator containing the processes of the current node *)
val rank : Mpi.rank val rank : Mpi.rank
@ -60,7 +60,7 @@ end
(** {5 Inter-node operations} *) (** {5 Inter-node operations} *)
module InterNode : sig module InterNode : sig
val comm : Mpi.communicator val comm : Mpi.communicator option
(** MPI Communicator among the master processes of the each node *) (** MPI Communicator among the master processes of the each node *)
val rank : Mpi.rank val rank : Mpi.rank