10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-11-19 04:22:21 +01:00
QCaml/Parallel_mpi/Parallel.ml

282 lines
6.0 KiB
OCaml
Raw Normal View History

2018-10-16 00:27:58 +02:00
(** Module for handling distributed parallelism *)
2018-10-16 19:09:00 +02:00
(** Number of processes in [Mpi.comm_world]; checked to be positive. *)
let size =
  Mpi.comm_world
  |> Mpi.comm_size
  |> (fun n -> assert (n > 0); n)
2018-10-16 00:27:58 +02:00
2018-10-16 19:09:00 +02:00
(** Rank of this process in [Mpi.comm_world]; checked to be non-negative. *)
let rank =
  Mpi.comm_world
  |> Mpi.comm_rank
  |> (fun r -> assert (r >= 0); r)
2018-10-17 11:58:13 +02:00
2018-10-17 10:38:07 +02:00
(** [true] only on the process of world rank 0. *)
let master =
  match rank with
  | 0 -> true
  | _ -> false
2018-10-16 19:09:00 +02:00
2018-10-17 11:58:13 +02:00
2018-10-16 19:09:00 +02:00
(** Block until every process of [Mpi.comm_world] has reached this call. *)
let barrier () =
  Mpi.comm_world |> Mpi.barrier
2019-03-29 17:38:19 +01:00
(** [broadcast_generic broadcast x] forces the lazy value [x] on the master
    process only, then distributes it over [Mpi.comm_world] by calling
    [broadcast payload 0 Mpi.comm_world] (root = rank 0).
    Every process returns the master's value. *)
let broadcast_generic broadcast x =
  let payload = if master then Some (Lazy.force x) else None in
  match broadcast payload 0 Mpi.comm_world with
  | None -> assert false (* the root (master) always sends [Some _] *)
  | Some value -> value
2018-10-17 11:58:13 +02:00
2019-03-29 17:38:19 +01:00
(** Broadcast a lazily computed value from rank 0 to every process of
    [Mpi.comm_world]; the value is only forced on rank 0. *)
let broadcast lazy_value =
  broadcast_generic Mpi.broadcast lazy_value
2018-10-17 11:58:13 +02:00
(** Broadcast an integer from rank 0 over [Mpi.comm_world] and return the
    value obtained from [Mpi.broadcast_int]. *)
let broadcast_int n =
  let root = 0 in
  Mpi.broadcast_int n root Mpi.comm_world
(** Overwrite [x] in place with the contents held by rank 0 and return it.
    The array is expected to have the same length on every process. *)
let broadcast_int_array x =
  let () = Mpi.broadcast_int_array x 0 Mpi.comm_world in
  x
2018-10-17 11:58:13 +02:00
(** Broadcast a float from rank 0 over [Mpi.comm_world] and return the
    value obtained from [Mpi.broadcast_float]. *)
let broadcast_float value =
  let root = 0 in
  Mpi.broadcast_float value root Mpi.comm_world
(** Overwrite [x] in place with the contents held by rank 0 and return it.
    The array is expected to have the same length on every process. *)
let broadcast_float_array x =
  let () = Mpi.broadcast_float_array x 0 Mpi.comm_world in
  x
2018-10-17 11:58:13 +02:00
(** Broadcast a Lacaml vector from rank 0: the vector is copied to a float
    array, broadcast, and rebuilt into a fresh vector. The vector is
    expected to have the same dimension on every process. *)
let broadcast_vec x =
  x
  |> Lacaml.D.Vec.to_array
  |> broadcast_float_array
  |> Lacaml.D.Vec.of_array
2019-03-29 23:15:57 +01:00
(** Communication restricted to the processes running on the same host. *)
module Node = struct

  (** Host name of the machine running this process. *)
  let name = Unix.gethostname ()

  (** Communicator grouping the processes of [Mpi.comm_world] that share
      this host name. The split color is the lowest world rank found on the
      host, so two different hosts never share a color. *)
  let comm =
    let _, color =
      Mpi.allgather (name, rank) Mpi.comm_world
      |> Array.to_list
      |> List.sort compare
      |> List.find (fun (host, _) -> host = name)
    in
    (* Same key (0) for everybody: MPI orders ties by parent rank. *)
    Mpi.(comm_split comm_world color 0)

  (** Rank of this process inside the node communicator. *)
  let rank = Mpi.comm_rank comm

  (** [true] on the process of node rank 0. *)
  let master = rank = 0

  (** Force the lazy value [x] on the node master only and broadcast it to
      the other processes of the same node. *)
  let broadcast_generic broadcast x =
    let x =
      if master then Some (Lazy.force x)
      else None
    in
    match broadcast x 0 comm with
    | Some x -> x
    | None -> assert false (* the root always sends [Some _] *)

  (** Node-local broadcast of a lazily computed value. *)
  let broadcast x = broadcast_generic Mpi.broadcast x

  (** Synchronise the processes of this node. *)
  let barrier () = Mpi.barrier comm

end
(** Communication between hosts: only the lowest-ranked process of each
    host (in [Mpi.comm_world]) takes part. *)
module InterNode = struct

  (** Communicator linking the lowest-ranked process of every host, or
      [None] on processes that are not such a host leader.
      [Mpi.comm_create] is called by every process of [Mpi.comm_world]. *)
  let comm =
    let leaders =
      (* Entries are sorted by (host, world rank): keep the first world
         rank of each host. [previous] starts as [None] so that no host
         name (not even "") is mistaken for an already-seen one. *)
      let rec first_per_host accu previous = function
        | [] -> List.rev accu
        | (host, r) :: rest when previous <> Some host ->
          first_per_host (r :: accu) (Some host) rest
        | (host, _) :: rest -> first_per_host accu (Some host) rest
      in
      Mpi.allgather (Node.name, rank) Mpi.comm_world
      |> Array.to_list
      |> List.sort compare
      |> first_per_host [] None
    in
    let world_group = Mpi.comm_group Mpi.comm_world in
    let new_group = Mpi.group_incl world_group (Array.of_list leaders) in
    let g = Mpi.comm_create Mpi.comm_world new_group in
    if List.mem rank leaders then Some g else None

  (** Rank in the inter-node communicator, or [-1] when this process is
      not a host leader. *)
  let rank =
    match comm with
    | Some comm -> Mpi.comm_rank comm
    | None -> -1

  (** [true] on the process of inter-node rank 0. *)
  let master = rank = 0

  (** Force [x] on the inter-node master and broadcast it to the other host
      leaders. Processes outside the inter-node communicator simply force
      [x] locally. *)
  let broadcast_generic broadcast x =
    match comm with
    | Some comm ->
      begin
        let x =
          if master then Some (Lazy.force x)
          else None
        in
        match broadcast x 0 comm with
        | Some x -> x
        | None -> assert false (* the root always sends [Some _] *)
      end
    | None -> Lazy.force x

  (** Inter-node broadcast of a lazily computed value. *)
  let broadcast x = broadcast_generic Mpi.broadcast x

  (** Synchronise the host leaders; a no-op on the other processes. *)
  let barrier () =
    match comm with
    | Some comm -> Mpi.barrier comm
    | None -> ()

end
2018-10-17 11:58:13 +02:00
2018-10-16 19:09:00 +02:00
(** Float vectors distributed over the processes of [Mpi.comm_world].
    Each process owns a contiguous slice of 1-based global indices. *)
module Vec = struct

  type t =
    {
      global_first : int ; (* Lower index in the global array *)
      global_last  : int ; (* Higher index in the global array *)
      local_first  : int ; (* Lower index in the local array *)
      local_last   : int ; (* Higher index in the local array *)
      data : Lacaml.D.vec ; (* Lacaml vector containing the data; may be
                               zero-padded past [local_last] (see [of_array]) *)
    }

  (** Global number of elements. *)
  let dim vec =
    vec.global_last - vec.global_first + 1

  let local_first vec = vec.local_first
  let local_last vec = vec.local_last
  let global_first vec = vec.global_first
  let global_last vec = vec.global_last
  let data vec = vec.data

  (** Debug printer showing the index bounds and the local data. *)
  let pp ppf v =
    Format.fprintf ppf "@[<2>";
    Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
    Format.fprintf ppf "@[ gl : %d@]@;" v.global_last;
    Format.fprintf ppf "@[ lf : %d@]@;" v.local_first;
    Format.fprintf ppf "@[ ll : %d@]@;" v.local_last;
    Format.fprintf ppf "@[ data : %a@]@;" (Lacaml.Io.pp_lfvec ()) v.data;
    Format.fprintf ppf "@]@.";
    ()

  (** [create n] allocates the local slice of an [n]-element vector.
      For [n >= 1] each process gets [ceil (n / size)] indices; the last
      processes may get fewer, or none. The local contents are whatever
      [Lacaml.D.Vec.create] provides (not initialised here). *)
  let create n =
    let step = (n-1) / size + 1 in
    let local_first = step * rank + 1 in
    let local_last = min (local_first + step - 1) n in
    {
      global_first = 1 ;
      global_last = n ;
      local_first ;
      local_last ;
      data = Lacaml.D.Vec.create (max 0 (local_last - local_first + 1))
    }

  (** [make n x]: distributed [n]-element vector whose local elements are
      all [x]. *)
  let make n x =
    let result = create n in
    { result with data =
                    Lacaml.D.Vec.make
                      (Lacaml.D.Vec.dim result.data)
                      x
    }

  (** [make0 n]: distributed [n]-element vector of zeros. *)
  let make0 n =
    make n 0.

  (** [init n f]: distributed vector whose element at global (1-based)
      index [i] is [f i]; each process only evaluates [f] on its slice. *)
  let init n f =
    let result = create n in
    { result with data =
                    Lacaml.D.Vec.init
                      (Lacaml.D.Vec.dim result.data)
                      (fun i -> f (i+result.local_first-1))
    }

  (** [of_array a] scatters the array held by rank 0 over all processes.
      Every process must call it with an array of the same length (the
      index bounds are computed from the local length). The array is
      zero-padded to a multiple of [size], so [data] has [ceil (n / size)]
      elements on every process, even past [local_last]. *)
  let of_array a =
    let length_a = Array.length a in
    let a =
      let n = length_a mod size in
      if n > 0 then
        Array.concat [ a ; Array.make (size-n) 0. ]
      else
        a
    in
    let result = create length_a in
    let a_local = Array.make ((Array.length a)/size) 0. in
    let () = Mpi.scatter_float_array a a_local 0 Mpi.comm_world in
    { result with data = Lacaml.D.Vec.of_array a_local }

  (** [to_array vec] gathers the distributed vector on rank 0. Only rank 0
      receives the values; the other processes get an array of zeros of the
      same length. *)
  let to_array vec =
    let final_size = dim vec in
    (* The gather needs the same element count from every process. Vectors
       built by [create]/[make]/[init] have shorter (or empty) slices on the
       last processes while [of_array] pads with zeros, so normalise the
       local chunk to ceil(final_size / size) elements. *)
    let chunk = (final_size + size - 1) / size in
    let local = Lacaml.D.Vec.to_array vec.data in
    let local_len = Array.length local in
    let data =
      if local_len = chunk then local
      else Array.init chunk (fun i -> if i < local_len then local.(i) else 0.)
    in
    let buffer_size = chunk * size in
    let buffer = Array.make buffer_size 0. in
    let () = Mpi.gather_float_array data buffer 0 Mpi.comm_world in
    if final_size = buffer_size then
      buffer
    else
      Array.sub buffer 0 final_size

  (** Distribute a Lacaml vector held by rank 0 (see [of_array]). *)
  let of_vec a =
    Lacaml.D.Vec.to_array a
    |> of_array

  (** Gather the distributed vector on rank 0 (see [to_array]). *)
  let to_vec v =
    to_array v
    |> Lacaml.D.Vec.of_array

end
2018-10-17 11:44:28 +02:00
(** [dot v1 v2] is the scalar product of two distributed vectors, reduced
    with [Mpi.Float_sum] onto rank 0 of [Mpi.comm_world]. Presumably only
    rank 0 receives the meaningful sum (MPI reduce semantics) — confirm.
    @raise Invalid_argument if the global dimensions differ. *)
let dot v1 v2 =
  if Vec.dim v1 <> Vec.dim v2 then
    invalid_arg "Incompatible dimensions";
  (* Vectors with the same global dimension built by [Vec] share the local
     range [local_first .. local_last]. Restrict the local product to that
     range: [Vec.of_array] zero-pads [data] on the last processes while
     [Vec.create]/[make]/[init] do not, so the raw [data] lengths may
     differ. *)
  let n = max 0 (Vec.local_last v1 - Vec.local_first v1 + 1) in
  let local_dot =
    if n = 0 then 0.
    else Lacaml.D.dot ~n (Vec.data v1) (Vec.data v2)
  in
  Mpi.reduce_float local_dot Mpi.Float_sum 0 Mpi.comm_world