mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 01:55:40 +01:00
Dot product OK
This commit is contained in:
parent
224a87d422
commit
0baf509285
@ -23,7 +23,7 @@ let barrier () =
|
|||||||
|
|
||||||
module Vec = struct
|
module Vec = struct
|
||||||
|
|
||||||
type t =
|
type t =
|
||||||
{
|
{
|
||||||
global_first : int ; (* Lower index in the global array *)
|
global_first : int ; (* Lower index in the global array *)
|
||||||
global_last : int ; (* Higher index in the global array *)
|
global_last : int ; (* Higher index in the global array *)
|
||||||
@ -32,6 +32,15 @@ module Vec = struct
|
|||||||
data : Lacaml.D.vec ; (* Lacaml vector containing the data *)
|
data : Lacaml.D.vec ; (* Lacaml vector containing the data *)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let dim vec =
|
||||||
|
vec.global_last - vec.global_first + 1
|
||||||
|
|
||||||
|
let local_first vec = vec.local_first
|
||||||
|
let local_last vec = vec.local_last
|
||||||
|
let global_first vec = vec.global_first
|
||||||
|
let global_last vec = vec.global_last
|
||||||
|
let data vec = vec.data
|
||||||
|
|
||||||
let pp ppf v =
|
let pp ppf v =
|
||||||
Format.fprintf ppf "@[<2>";
|
Format.fprintf ppf "@[<2>";
|
||||||
Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
|
Format.fprintf ppf "@[ gf : %d@]@;" v.global_first;
|
||||||
@ -93,7 +102,7 @@ module Vec = struct
|
|||||||
|
|
||||||
|
|
||||||
let to_array vec =
|
let to_array vec =
|
||||||
let final_size = vec.global_last - vec.global_first + 1 in
|
let final_size = dim vec in
|
||||||
let buffer_size = (Lacaml.D.Vec.dim vec.data) * size in
|
let buffer_size = (Lacaml.D.Vec.dim vec.data) * size in
|
||||||
let buffer = Array.make buffer_size 0. in
|
let buffer = Array.make buffer_size 0. in
|
||||||
let data = Lacaml.D.Vec.to_array vec.data in
|
let data = Lacaml.D.Vec.to_array vec.data in
|
||||||
@ -113,4 +122,17 @@ module Vec = struct
|
|||||||
to_array v
|
to_array v
|
||||||
|> Lacaml.D.Vec.of_array
|
|> Lacaml.D.Vec.of_array
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
let dot v1 v2 =
|
||||||
|
if Vec.dim v1 <> Vec.dim v2 then
|
||||||
|
invalid_arg "Incompatible dimensions";
|
||||||
|
let local_dot =
|
||||||
|
Lacaml.D.dot (Vec.data v1) (Vec.data v2)
|
||||||
|
in
|
||||||
|
Mpi.reduce_float local_dot Mpi.Float_sum 0 Mpi.comm_world
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ val barrier : unit -> unit
|
|||||||
(** {5 Vector operations} *)
|
(** {5 Vector operations} *)
|
||||||
module Vec : sig
|
module Vec : sig
|
||||||
|
|
||||||
type t =
|
type t = private
|
||||||
{
|
{
|
||||||
global_first : int ; (* Lower index in the global array *)
|
global_first : int ; (* Lower index in the global array *)
|
||||||
global_last : int ; (* Higher index in the global array *)
|
global_last : int ; (* Higher index in the global array *)
|
||||||
@ -26,7 +26,7 @@ module Vec : sig
|
|||||||
|
|
||||||
val pp : Format.formatter -> t -> unit
|
val pp : Format.formatter -> t -> unit
|
||||||
|
|
||||||
(** {6 Creation/conversion of vectors and dimension accessor} *)
|
(** {6 Creation/conversion of vectors} *)
|
||||||
|
|
||||||
val create : int -> t
|
val create : int -> t
|
||||||
(** [create n] @return a distributed vector with [n] rows (not initialized). *)
|
(** [create n] @return a distributed vector with [n] rows (not initialized). *)
|
||||||
@ -54,6 +54,27 @@ module Vec : sig
|
|||||||
val to_vec : t -> Lacaml.D.vec
|
val to_vec : t -> Lacaml.D.vec
|
||||||
(** [to_vec v] @return a Lacaml vector initialized from vector [v]. *)
|
(** [to_vec v] @return a Lacaml vector initialized from vector [v]. *)
|
||||||
|
|
||||||
|
|
||||||
|
(** {6 Accessors } *)
|
||||||
|
|
||||||
|
val dim : t -> int
|
||||||
|
(** [dim v] @return the dimension of the vector [v]. *)
|
||||||
|
|
||||||
|
val global_first : t -> int
|
||||||
|
(** [global_first v] @return the index of the first element of [v]. *)
|
||||||
|
|
||||||
|
val global_last : t -> int
|
||||||
|
(** [global_last v] @return the index of the last element of [v]. *)
|
||||||
|
|
||||||
|
val local_first : t -> int
|
||||||
|
(** [local_first v] @return the index of the first element of the local piece of [v]. *)
|
||||||
|
|
||||||
|
val global_last : t -> int
|
||||||
|
(** [local_last v] @return the index of the last element of the local piece of [v]. *)
|
||||||
|
|
||||||
|
val data : t -> Lacaml.D.vec
|
||||||
|
(** [data v] @return the local Lacaml vector in which the piece of the vector [v] is stored. *)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@ -77,10 +98,10 @@ end
|
|||||||
|
|
||||||
val gemm : Mat.t -> Mat.t -> Mat.t
|
val gemm : Mat.t -> Mat.t -> Mat.t
|
||||||
(* Distributed matrix-matrix product. The result is a distributed matrix. *)
|
(* Distributed matrix-matrix product. The result is a distributed matrix. *)
|
||||||
|
*)
|
||||||
|
|
||||||
val dot : Vec.t -> Vec.t-> float
|
val dot : Vec.t -> Vec.t-> float
|
||||||
(* Dot product between distributed vectors. *)
|
(* Dot product between distributed vectors. *)
|
||||||
*)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,11 +5,18 @@ let () =
|
|||||||
(*
|
(*
|
||||||
let v = Parallel.Vec.init 47 (fun i -> float_of_int i) in
|
let v = Parallel.Vec.init 47 (fun i -> float_of_int i) in
|
||||||
*)
|
*)
|
||||||
let a = Array.init 6 (fun i -> float_of_int (i+1)) |> Lacaml.D.Vec.of_array in
|
let a = Array.init 41 (fun i -> float_of_int (i+1)) |> Lacaml.D.Vec.of_array in
|
||||||
let v = Parallel.Vec.of_vec a in
|
let b = Array.init 41 (fun i -> float_of_int (3*i+1)) |> Lacaml.D.Vec.of_array in
|
||||||
Format.printf "%a" Parallel.Vec.pp v;
|
let v1 = Parallel.Vec.of_vec a in
|
||||||
|
let v2 = Parallel.Vec.of_vec b in
|
||||||
|
let d1 = Parallel.dot v1 v2 in
|
||||||
|
let d2 = Lacaml.D.dot a b in
|
||||||
|
(*
|
||||||
let w = Parallel.Vec.to_vec v in
|
let w = Parallel.Vec.to_vec v in
|
||||||
|
Format.printf "%a" Parallel.Vec.pp v;
|
||||||
if Parallel.master then
|
if Parallel.master then
|
||||||
Format.printf "@[%a@]@;" (Lacaml.Io.pp_lfvec ()) w;
|
Format.printf "@[%a@]@;" (Lacaml.Io.pp_lfvec ()) w;
|
||||||
|
*)
|
||||||
|
Printf.printf "%f %f\n" d1 d2;
|
||||||
print_newline ();
|
print_newline ();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user