10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-07-25 20:27:28 +02:00
QCaml/Parallel/Parallel.ml

98 lines
2.3 KiB
OCaml
Raw Normal View History

2018-10-16 00:27:58 +02:00
(** Module for handling distributed parallelism *)
2018-10-16 19:09:00 +02:00
(** Number of processes in [Mpi.comm_world], evaluated once when the module
    is initialised. An assertion checks that the value is strictly
    positive. *)
let size =
  let n_procs = Mpi.comm_size Mpi.comm_world in
  let () = assert (n_procs > 0) in
  n_procs
2018-10-16 00:27:58 +02:00
2018-10-16 19:09:00 +02:00
(** Rank of the current process in [Mpi.comm_world], evaluated once when the
    module is initialised. An assertion checks that the value is
    non-negative. *)
let rank =
  let my_rank = Mpi.comm_rank Mpi.comm_world in
  let () = assert (my_rank >= 0) in
  my_rank
(** [barrier ()] calls [Mpi.barrier] on [Mpi.comm_world]. *)
let barrier () =
  let world = Mpi.comm_world in
  Mpi.barrier world
(** Vectors whose elements are split in contiguous blocks among the
    processes of [Mpi.comm_world]. Indices are 1-based: the global range
    starts at 1, and each process stores the elements whose global indices
    lie in [local_first .. local_last]. *)
module Vec = struct

  type t =
    {
      global_first : int ; (* Lower index in the global array *)
      global_last  : int ; (* Higher index in the global array *)
      local_first  : int ; (* Lower index in the local array *)
      local_last   : int ; (* Higher index in the local array *)
      data         : Lacaml.D.vec ; (* Lacaml vector containing the data *)
    }

  (** [pp ppf v] prints the four index bounds of [v] followed by the data
      held by the current process, then flushes the formatter. *)
  let pp ppf v =
    Format.fprintf ppf "@[<2>";
    Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
    Format.fprintf ppf "@[ gl : %d@]@;" v.global_last;
    Format.fprintf ppf "@[ lf : %d@]@;" v.local_first;
    Format.fprintf ppf "@[ ll : %d@]@;" v.local_last;
    Format.fprintf ppf "@[ data : %a@]@;" (Lacaml.Io.pp_lfvec ()) v.data;
    Format.fprintf ppf "@]@.";
    ()

  (** [create n] builds a distributed vector with global indices [1 .. n].
      The global range is cut into blocks of [step = (n-1)/size + 1]
      elements; the current process owns block number [rank]. The local
      storage is allocated with [Lacaml.D.Vec.create] and is not filled
      here. *)
  let create n =
    let step = (n-1) / size + 1 in
    let local_first = step * rank + 1 in
    (* The last non-empty block is cut at [n]. *)
    let local_last = min (local_first + step - 1) n in
    {
      global_first = 1 ;
      global_last  = n ;
      local_first ;
      local_last ;
      (* [max 0] covers processes whose block starts beyond [n]: they hold
         an empty vector. *)
      data = Lacaml.D.Vec.create (max 0 (local_last - local_first + 1))
    }

  (** [make n x] is a distributed vector of global size [n] whose local
      elements are all equal to [x]. *)
  let make n x =
    let result = create n in
    { result with data =
        Lacaml.D.Vec.make
          (Lacaml.D.Vec.dim result.data)
          x
    }

  (** [make0 n] is [make n 0.]. *)
  let make0 n =
    make n 0.

  (** [init n f] is a distributed vector of global size [n] in which the
      element with global index [i] is [f i]. Each process calls [f] only
      on the indices of its own block. *)
  let init n f =
    let result = create n in
    { result with data =
        Lacaml.D.Vec.init
          (Lacaml.D.Vec.dim result.data)
          (* Shift the local index to the corresponding global index. *)
          (fun i -> f (i+result.local_first-1))
    }

  (** [of_array a] distributes the array [a] of process 0 among all the
      processes, using the same block layout as [create (Array.length a)].

      The length of [a] is read on every process, so all processes are
      expected to pass arrays of the same length; only the contents on
      process 0 are passed to the scatter as the source. *)
  let of_array a =
    let la = Array.length a in
    (* Pad with zeros up to a multiple of [size], so that the scatter deals
       with chunks of equal length on every process. *)
    let a =
      let n = la mod size in
      if n > 0 then
        Array.concat [ a ; Array.make (size-n) 0. ]
      else
        a
    in
    let result = create la in
    let () = Mpi.barrier Mpi.comm_world in
    let a_local = Array.make ((Array.length a)/size) 0. in
    let () = Mpi.scatter_float_array a a_local 0 Mpi.comm_world in
    (* Discard the zero padding: keep only the elements that belong to
       [local_first .. local_last], so that the dimension of [data] agrees
       with the bounds computed by [create]. The received chunk is never
       shorter than that dimension. *)
    let n_local = Lacaml.D.Vec.dim result.data in
    { result with data =
        Lacaml.D.Vec.of_array (Array.sub a_local 0 n_local)
    }

end