10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-06-26 15:12:05 +02:00

Split and join matrices

This commit is contained in:
Anthony Scemama 2019-04-02 13:54:16 +02:00
parent 3c3b2f14ab
commit f00a490b5e
8 changed files with 143 additions and 25 deletions

View File

@ -388,16 +388,25 @@ let make ?(n_states=1) ?(algo=`Direct) det_space =
Lazy.force m_H
in
let diagonal =
Vec.init (Matrix.dim1 m_H) (fun i -> Matrix.get m_H i i)
Parallel.broadcast (lazy (
Vec.init (Matrix.dim1 m_H) (fun i -> Matrix.get m_H i i)
))
in
let matrix_prod psi =
(*
Matrix.mm ~transa:`T m_H psi
*)
let result =
Matrix.parallel_mm ~transa:`T psi m_H
|> Matrix.transpose
in
Parallel.broadcast (lazy result)
in
let eigenvectors, eigenvalues =
let result = lazy (
let result =
Davidson.make ~threshold:1.e-6 ~n_states diagonal matrix_prod
) in
Parallel.broadcast result
in
Parallel.broadcast (lazy result)
in
let eigenvalues = Vec.map (fun x -> x +. e_shift) eigenvalues in
eigenvectors, eigenvalues

View File

@ -59,7 +59,7 @@ module Node = struct
let name = Unix.gethostname ()
let comm =
let comm = lazy (
Mpi.allgather (name, rank) Mpi.comm_world
|> Array.to_list
|> List.filter (fun (n, r) -> name = n)
@ -67,9 +67,10 @@ module Node = struct
|> Array.of_list
|> Mpi.(group_incl (comm_group comm_world))
|> Mpi.(comm_create comm_world)
)
let rank =
Mpi.comm_rank comm
Mpi.comm_rank (Lazy.force comm)
let master = rank = 0
@ -78,13 +79,13 @@ module Node = struct
if master then Some (Lazy.force x)
else None
in
match broadcast x 0 comm with
match broadcast x 0 (Lazy.force comm) with
| Some x -> x
| None -> assert false
let broadcast x = broadcast_generic Mpi.broadcast x
let barrier () = Mpi.barrier comm
let barrier () = Mpi.barrier (Lazy.force comm )
end

View File

@ -36,9 +36,6 @@ module Node : sig
val name : string
(** Name of the current host *)
val comm : Mpi.communicator
(** MPI Communicator on the current node *)
val rank : Mpi.rank
(** Rank of the current process in the node *)

View File

@ -9,8 +9,7 @@ let master = true
let barrier () = ()
let broadcast x =
Lazy.force x
let broadcast x = Lazy.force x
let broadcast_int x = x
@ -22,6 +21,19 @@ let broadcast_float_array x = x
let broadcast_vec x = x
module Node = struct
let name = Unix.gethostname ()
let rank = 0
let master = true
let broadcast x = Lazy.force x
let barrier () = ()
end
module Vec = struct

View File

@ -31,6 +31,26 @@ val broadcast_vec : Lacaml.D.vec -> Lacaml.D.vec
(** Broadcasts a Lacaml vector to all processes. *)
(** {5 Intra-node operations} *)
module Node : sig
val name : string
(** Name of the current host *)
val rank : int
(** Rank of the current process in the node *)
val master : bool
(** If true, master process of the node *)
val broadcast : 'a lazy_t -> 'a
(** Broadcasts data to all the processes of the current node. *)
val barrier : unit -> unit
(** Wait for all processes among the node to reach this point. *)
end
(** {5 Vector operations} *)
module Vec : sig

View File

@ -0,0 +1,3 @@
let create ?(temp_dir="/dev/shm") data_type size_array =
Bigarray.Genarray.create data_type Bigarray.fortran_layout size_array

View File

@ -323,25 +323,65 @@ let rec op2 dense_op sparse_op a b =
| (Dense _), (Sparse _) -> op2 dense_op sparse_op (sparse_of_dense a) b
| (Sparse _), (Dense _) -> op2 dense_op sparse_op a (sparse_of_dense b)
| (Sparse a), (Sparse b) -> Sparse
{ m=a.m ; n=a.n ;
v = Array.map2 sparse_op a.v b.v
}
{ m=a.m ; n=a.n ;
v = Array.map2 sparse_op a.v b.v
}
let add = op2 (fun a b -> Mat.add a b) (fun a b -> Vector.add a b)
let sub = op2 (fun a b -> Mat.sub a b) (fun a b -> Vector.sub a b)
let scale f = function
| Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
| Sparse a -> Sparse
{ a with
v = if f = 1.0 then a.v
else Array.map (fun v -> Vector.scale f v) a.v }
| Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
| Sparse a -> Sparse
{ a with
v = if f = 1.0 then a.v
else Array.map (fun v -> Vector.scale f v) a.v }
let frobenius_norm = function
| Dense a -> lange ~norm:`F a
| Sparse a ->
Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
|> sqrt
| Dense a -> lange ~norm:`F a
| Sparse a ->
Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
|> sqrt
let split_cols nrows = function
| Dense a ->
begin
Mat.to_col_vecs a
|> Array.to_list
|> Util.list_pack nrows
|> List.map (fun l ->
Dense (Mat.of_col_vecs @@ Array.of_list l) )
end
| Sparse a ->
begin
Array.to_list a.v
|> Util.list_pack nrows
|> List.map Array.of_list
|> List.map (fun v -> Sparse { m=a.m ; n= Array.length v ; v })
end
let join_cols l =
let rec aux_dense accu = function
| [] -> Dense (Mat.of_col_vecs_list (List.concat accu))
| (Dense a) :: rest -> aux_dense ((Mat.to_col_vecs_list a) :: accu) rest
| _ -> assert false
and aux_sparse m n accu = function
| [] -> Sparse { m ; n ; v=Array.of_list (List.concat accu) }
| (Sparse a) :: rest -> aux_sparse a.m (n+a.n) ((Array.to_list a.v)::accu) rest
| _ -> assert false
and aux = function
| [] -> Sparse { m=0 ; n=0 ; v=[| |] }
| (Dense a) :: rest -> aux_dense [] ((Dense a) :: rest)
| (Sparse a) :: rest -> aux_sparse 0 0 [] ((Sparse a) :: rest)
in aux (List.rev l)
@ -396,6 +436,19 @@ let rec ax_eq_b ?(trans=`N) a b =
}
(* ------- Parallel routines ---------- *)
let parallel_mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
let n = 4 in
split_cols n b
|> Stream.of_list
|> Farm.run ~ordered:true ~f:(fun b ->
mm ~transa ~transb ~threshold a b
)
|> Util.stream_to_list
|> join_cols
(* ------------ Printers ------------ *)
@ -591,6 +644,16 @@ let test_case () =
Alcotest.(check (float 1.e-10)) "sparse sparse 4" 0. (norm_diff (mm a x) b);
in
let test_split_join () =
let m1_split = split_cols 7 m1 in
let m1_s_split = split_cols 7 m1_s in
let m2 = join_cols m1_split in
let m2_s = join_cols m1_s_split in
Alcotest.(check int) "length" 6 (List.length m1_split);
Alcotest.(check int) "length" 6 (List.length m1_s_split);
Alcotest.(check bool) "join" true (m1 = m2);
Alcotest.(check bool) "join" true (m1_s = m2_s);
in
[
"Conversion", `Quick, test_conversion;
"Dimensions", `Quick, test_dimensions;
@ -600,5 +663,6 @@ let test_case () =
"Matrix Vector", `Quick, test_mv;
"Matrix Matrix", `Quick, test_mm;
"Linear solve", `Quick, test_solve;
"split_join", `Quick, test_split_join;
]

View File

@ -48,6 +48,12 @@ val sparse_of_vector_array : Vector.t array -> t
val transpose : t -> t
(** Returns the transposed matrix. *)
val split_cols : int -> t -> t list
(** [split_cols n m_M] Split the matrix [m_M] by packs of [n] columns *)
val join_cols : t list -> t
(** [join_cols l] Joins the list of matrices into a single matrix (along columns). *)
(** {1 Operations} *)
@ -69,6 +75,12 @@ val add : t -> t -> t
val sub : t -> t -> t
(** Subtract two matrices *)
(** {1 Parallel routines } *)
val parallel_mm : ?transa:trans3 -> ?transb:trans3 -> ?threshold:float -> t -> t -> t
(** Matrix multiplication, parallelized by splitting b along the columns. *)
(** {1 Printers } *)
val pp_matrix : Format.formatter -> t -> unit