mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 01:55:40 +01:00
Dot product OK
This commit is contained in:
parent
224a87d422
commit
0baf509285
@ -32,6 +32,15 @@ module Vec = struct
|
||||
data : Lacaml.D.vec ; (* Lacaml vector containing the data *)
|
||||
}
|
||||
|
||||
let dim vec =
|
||||
vec.global_last - vec.global_first + 1
|
||||
|
||||
let local_first vec = vec.local_first
|
||||
let local_last vec = vec.local_last
|
||||
let global_first vec = vec.global_first
|
||||
let global_last vec = vec.global_last
|
||||
let data vec = vec.data
|
||||
|
||||
let pp ppf v =
|
||||
Format.fprintf ppf "@[<2>";
|
||||
Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
|
||||
@ -93,7 +102,7 @@ module Vec = struct
|
||||
|
||||
|
||||
let to_array vec =
|
||||
let final_size = vec.global_last - vec.global_first + 1 in
|
||||
let final_size = dim vec in
|
||||
let buffer_size = (Lacaml.D.Vec.dim vec.data) * size in
|
||||
let buffer = Array.make buffer_size 0. in
|
||||
let data = Lacaml.D.Vec.to_array vec.data in
|
||||
@ -113,4 +122,17 @@ module Vec = struct
|
||||
to_array v
|
||||
|> Lacaml.D.Vec.of_array
|
||||
|
||||
|
||||
end
|
||||
|
||||
|
||||
|
||||
let dot v1 v2 =
|
||||
if Vec.dim v1 <> Vec.dim v2 then
|
||||
invalid_arg "Incompatible dimensions";
|
||||
let local_dot =
|
||||
Lacaml.D.dot (Vec.data v1) (Vec.data v2)
|
||||
in
|
||||
Mpi.reduce_float local_dot Mpi.Float_sum 0 Mpi.comm_world
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ val barrier : unit -> unit
|
||||
(** {5 Vector operations} *)
|
||||
module Vec : sig
|
||||
|
||||
type t =
|
||||
type t = private
|
||||
{
|
||||
global_first : int ; (* Lower index in the global array *)
|
||||
global_last : int ; (* Higher index in the global array *)
|
||||
@ -26,7 +26,7 @@ module Vec : sig
|
||||
|
||||
val pp : Format.formatter -> t -> unit
|
||||
|
||||
(** {6 Creation/conversion of vectors and dimension accessor} *)
|
||||
(** {6 Creation/conversion of vectors} *)
|
||||
|
||||
val create : int -> t
|
||||
(** [create n] @return a distributed vector with [n] rows (not initialized). *)
|
||||
@ -54,6 +54,27 @@ module Vec : sig
|
||||
val to_vec : t -> Lacaml.D.vec
|
||||
(** [to_vec v] @return a Lacaml vector initialized from vector [v]. *)
|
||||
|
||||
|
||||
(** {6 Accessors } *)
|
||||
|
||||
val dim : t -> int
|
||||
(** [dim v] @return the dimension of the vector [v]. *)
|
||||
|
||||
val global_first : t -> int
|
||||
(** [global_first v] @return the index of the first element of [v]. *)
|
||||
|
||||
val global_last : t -> int
|
||||
(** [global_last v] @return the index of the last element of [v]. *)
|
||||
|
||||
val local_first : t -> int
|
||||
(** [local_first v] @return the index of the first element of the local piece of [v]. *)
|
||||
|
||||
val global_last : t -> int
|
||||
(** [local_last v] @return the index of the last element of the local piece of [v]. *)
|
||||
|
||||
val data : t -> Lacaml.D.vec
|
||||
(** [data v] @return the local Lacaml vector in which the piece of the vector [v] is stored. *)
|
||||
|
||||
end
|
||||
|
||||
|
||||
@ -77,10 +98,10 @@ end
|
||||
|
||||
val gemm : Mat.t -> Mat.t -> Mat.t
|
||||
(* Distributed matrix-matrix product. The result is a distributed matrix. *)
|
||||
*)
|
||||
|
||||
val dot : Vec.t -> Vec.t-> float
|
||||
(* Dot product between distributed vectors. *)
|
||||
*)
|
||||
|
||||
|
||||
|
||||
|
@ -5,11 +5,18 @@ let () =
|
||||
(*
|
||||
let v = Parallel.Vec.init 47 (fun i -> float_of_int i) in
|
||||
*)
|
||||
let a = Array.init 6 (fun i -> float_of_int (i+1)) |> Lacaml.D.Vec.of_array in
|
||||
let v = Parallel.Vec.of_vec a in
|
||||
Format.printf "%a" Parallel.Vec.pp v;
|
||||
let a = Array.init 41 (fun i -> float_of_int (i+1)) |> Lacaml.D.Vec.of_array in
|
||||
let b = Array.init 41 (fun i -> float_of_int (3*i+1)) |> Lacaml.D.Vec.of_array in
|
||||
let v1 = Parallel.Vec.of_vec a in
|
||||
let v2 = Parallel.Vec.of_vec b in
|
||||
let d1 = Parallel.dot v1 v2 in
|
||||
let d2 = Lacaml.D.dot a b in
|
||||
(*
|
||||
let w = Parallel.Vec.to_vec v in
|
||||
Format.printf "%a" Parallel.Vec.pp v;
|
||||
if Parallel.master then
|
||||
Format.printf "@[%a@]@;" (Lacaml.Io.pp_lfvec ()) w;
|
||||
*)
|
||||
Printf.printf "%f %f\n" d1 d2;
|
||||
print_newline ();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user