(** Module for handling distributed parallelism *)

(** Number of processes in [Mpi.comm_world]. Asserted to be strictly
    positive. *)
let size =
  let result = Mpi.comm_size Mpi.comm_world in
  assert (result > 0);
  result

(** Rank of the current process in [Mpi.comm_world]. Asserted to be
    non-negative. *)
let rank =
  let result = Mpi.comm_rank Mpi.comm_world in
  assert (result >= 0);
  result

(** [true] on the process of rank 0, which is used as the root of every
    collective operation in this module. *)
let master = rank = 0

(** [barrier ()] calls [Mpi.barrier] on [Mpi.comm_world]. *)
let barrier () = Mpi.barrier Mpi.comm_world

(** [broadcast x] forces the lazy value [x] on the master only, and
    broadcasts the result from rank 0. [x] is never forced on the other
    processes. *)
let broadcast x =
  let x = if master then Some (Lazy.force x) else None in
  match Mpi.broadcast x 0 Mpi.comm_world with
  | Some x -> x
  | None ->
      (* The root always sends [Some _], so [None] is not expected here. *)
      assert false

(** [broadcast_int x] returns the result of [Mpi.broadcast_int] with
    root 0. *)
let broadcast_int x = Mpi.broadcast_int x 0 Mpi.comm_world

(** [broadcast_int_array x] calls [Mpi.broadcast_int_array] on [x] with
    root 0 and returns the same array [x].
    NOTE(review): relies on the MPI binding filling [x] in place on the
    non-root processes. *)
let broadcast_int_array x =
  Mpi.broadcast_int_array x 0 Mpi.comm_world;
  x

(** [broadcast_float x] returns the result of [Mpi.broadcast_float] with
    root 0. *)
let broadcast_float x = Mpi.broadcast_float x 0 Mpi.comm_world

(** [broadcast_float_array x] calls [Mpi.broadcast_float_array] on [x] with
    root 0 and returns the same array [x].
    NOTE(review): relies on the MPI binding filling [x] in place on the
    non-root processes. *)
let broadcast_float_array x =
  Mpi.broadcast_float_array x 0 Mpi.comm_world;
  x

(** [broadcast_vec x] converts the Lacaml vector [x] to a float array,
    broadcasts it with {!broadcast_float_array} and converts the result back
    to a new Lacaml vector. *)
let broadcast_vec x =
  let a = Lacaml.D.Vec.to_array x in
  let a = broadcast_float_array a in
  Lacaml.D.Vec.of_array a

(** Vectors distributed by contiguous blocks over the MPI processes. *)
module Vec = struct
  type t = {
    global_first : int; (* Lower index in the global array *)
    global_last : int; (* Higher index in the global array *)
    local_first : int; (* Lower index in the local array *)
    local_last : int; (* Higher index in the local array *)
    data : Lacaml.D.vec; (* Lacaml vector containing the data *)
  }

  (** Global dimension: [global_last - global_first + 1]. *)
  let dim vec = vec.global_last - vec.global_first + 1

  (** Accessors for the record fields. *)

  let local_first vec = vec.local_first
  let local_last vec = vec.local_last
  let global_first vec = vec.global_first
  let global_last vec = vec.global_last
  let data vec = vec.data

  (** Pretty-printer showing the four bounds and the local data. *)
  let pp ppf v =
    Format.fprintf ppf "@[<2>";
    Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
    Format.fprintf ppf "@[ gl : %d@]@;" v.global_last;
    Format.fprintf ppf "@[ lf : %d@]@;" v.local_first;
    Format.fprintf ppf "@[ ll : %d@]@;" v.local_last;
    Format.fprintf ppf "@[ data : %a@]@;" (Lacaml.Io.pp_lfvec ()) v.data;
    Format.fprintf ppf "@]@.";
    ()

  (** [create n] describes a vector with global indices [1..n] split into
      blocks of [step = (n-1) / size + 1] consecutive indices, the process of
      rank [r] owning the indices starting at [step * r + 1]. The last
      owning process may get a shorter block and the following ones an empty
      block. The local data is allocated with [Lacaml.D.Vec.create], so its
      contents should not be relied upon. *)
  let create n =
    let step = (n - 1) / size + 1 in
    let local_first = (step * rank) + 1 in
    let local_last = min (local_first + step - 1) n in
    {
      global_first = 1;
      global_last = n;
      local_first;
      local_last;
      (* [max 0] protects processes whose block is empty
         ([local_last < local_first]). *)
      data = Lacaml.D.Vec.create (max 0 (local_last - local_first + 1));
    }

  (** [make n x] is a distributed vector of global dimension [n] whose local
      elements are all [x]. *)
  let make n x =
    let result = create n in
    { result with data = Lacaml.D.Vec.make (Lacaml.D.Vec.dim result.data) x }

  (** [make0 n] is [make n 0.]. *)
  let make0 n = make n 0.

  (** [init n f] is a distributed vector of global dimension [n]. The local
      element of Lacaml index [i] is [f (i + local_first - 1)], so [f]
      receives global indices. *)
  let init n f =
    let result = create n in
    {
      result with
      data =
        Lacaml.D.Vec.init
          (Lacaml.D.Vec.dim result.data)
          (fun i -> f (i + result.local_first - 1));
    }

  (** [of_array a] scatters [a] from rank 0 in equal chunks. [a] is first
      padded with zeros up to a multiple of [size], so the local data of
      every process has the same length and may contain trailing zeros
      beyond [local_last]. The length of [a] is read on every process to
      size the local chunk, so it has to be the same everywhere. *)
  let of_array a =
    let length_a = Array.length a in
    let a =
      let n = length_a mod size in
      if n > 0 then Array.concat [ a; Array.make (size - n) 0. ] else a
    in
    let result = create length_a in
    let a_local = Array.make (Array.length a / size) 0. in
    let () = Mpi.scatter_float_array a a_local 0 Mpi.comm_world in
    { result with data = Lacaml.D.Vec.of_array a_local }

  (** [to_array vec] gathers the local chunks on rank 0 and returns an array
      of length [dim vec]. The gather targets rank 0 only, so the contents
      of the returned array are meaningful on the master.

      Every process contributes exactly [ceil (dim vec / size)] elements:
      local data shorter than that (trailing blocks produced by [create],
      [make] or [init]) is padded with zeros before the gather, so that all
      processes send the same number of elements. *)
  let to_array vec =
    let final_size = dim vec in
    (* Common chunk length, computed from the global bounds only so that
       all processes agree on it. *)
    let chunk = if final_size <= 0 then 0 else ((final_size - 1) / size) + 1 in
    let local = Lacaml.D.Vec.to_array vec.data in
    let data = Array.make chunk 0. in
    Array.blit local 0 data 0 (min chunk (Array.length local));
    let buffer_size = chunk * size in
    let buffer = Array.make buffer_size 0. in
    let () = Mpi.gather_float_array data buffer 0 Mpi.comm_world in
    (* Drop the padding when the global dimension is not a multiple of
       [size]. *)
    if final_size = buffer_size then buffer
    else Array.init final_size (fun i -> buffer.(i))

  (** [of_vec a] is [of_array] applied to the contents of the Lacaml vector
      [a]. *)
  let of_vec a = Lacaml.D.Vec.to_array a |> of_array

  (** [to_vec v] is [to_array v] converted to a Lacaml vector. *)
  let to_vec v = to_array v |> Lacaml.D.Vec.of_array
end

(** [dot v1 v2] computes the dot product of the local chunks and sums the
    partial results with [Mpi.reduce_float] rooted at rank 0.

    The local product is restricted to the common length of the two local
    chunks: two vectors of equal global dimension may differ locally only by
    the padding added by [Vec.of_array].

    @raise Invalid_argument if the global dimensions differ. *)
let dot v1 v2 =
  if Vec.dim v1 <> Vec.dim v2 then invalid_arg "Incompatible dimensions";
  let x = Vec.data v1 in
  let y = Vec.data v2 in
  let n = min (Lacaml.D.Vec.dim x) (Lacaml.D.Vec.dim y) in
  let local_dot = Lacaml.D.dot ~n x y in
  Mpi.reduce_float local_dot Mpi.Float_sum 0 Mpi.comm_world