mirror of
https://gitlab.com/scemama/QCaml.git
synced 2024-08-30 00:03:42 +02:00
Split and join matrices
This commit is contained in:
parent
3c3b2f14ab
commit
f00a490b5e
17
CI/CI.ml
17
CI/CI.ml
@ -388,16 +388,25 @@ let make ?(n_states=1) ?(algo=`Direct) det_space =
|
|||||||
Lazy.force m_H
|
Lazy.force m_H
|
||||||
in
|
in
|
||||||
let diagonal =
|
let diagonal =
|
||||||
Vec.init (Matrix.dim1 m_H) (fun i -> Matrix.get m_H i i)
|
Parallel.broadcast (lazy (
|
||||||
|
Vec.init (Matrix.dim1 m_H) (fun i -> Matrix.get m_H i i)
|
||||||
|
))
|
||||||
in
|
in
|
||||||
let matrix_prod psi =
|
let matrix_prod psi =
|
||||||
|
(*
|
||||||
Matrix.mm ~transa:`T m_H psi
|
Matrix.mm ~transa:`T m_H psi
|
||||||
|
*)
|
||||||
|
let result =
|
||||||
|
Matrix.parallel_mm ~transa:`T psi m_H
|
||||||
|
|> Matrix.transpose
|
||||||
|
in
|
||||||
|
Parallel.broadcast (lazy result)
|
||||||
in
|
in
|
||||||
let eigenvectors, eigenvalues =
|
let eigenvectors, eigenvalues =
|
||||||
let result = lazy (
|
let result =
|
||||||
Davidson.make ~threshold:1.e-6 ~n_states diagonal matrix_prod
|
Davidson.make ~threshold:1.e-6 ~n_states diagonal matrix_prod
|
||||||
) in
|
in
|
||||||
Parallel.broadcast result
|
Parallel.broadcast (lazy result)
|
||||||
in
|
in
|
||||||
let eigenvalues = Vec.map (fun x -> x +. e_shift) eigenvalues in
|
let eigenvalues = Vec.map (fun x -> x +. e_shift) eigenvalues in
|
||||||
eigenvectors, eigenvalues
|
eigenvectors, eigenvalues
|
||||||
|
@ -59,7 +59,7 @@ module Node = struct
|
|||||||
|
|
||||||
let name = Unix.gethostname ()
|
let name = Unix.gethostname ()
|
||||||
|
|
||||||
let comm =
|
let comm = lazy (
|
||||||
Mpi.allgather (name, rank) Mpi.comm_world
|
Mpi.allgather (name, rank) Mpi.comm_world
|
||||||
|> Array.to_list
|
|> Array.to_list
|
||||||
|> List.filter (fun (n, r) -> name = n)
|
|> List.filter (fun (n, r) -> name = n)
|
||||||
@ -67,9 +67,10 @@ module Node = struct
|
|||||||
|> Array.of_list
|
|> Array.of_list
|
||||||
|> Mpi.(group_incl (comm_group comm_world))
|
|> Mpi.(group_incl (comm_group comm_world))
|
||||||
|> Mpi.(comm_create comm_world)
|
|> Mpi.(comm_create comm_world)
|
||||||
|
)
|
||||||
|
|
||||||
let rank =
|
let rank =
|
||||||
Mpi.comm_rank comm
|
Mpi.comm_rank (Lazy.force comm)
|
||||||
|
|
||||||
let master = rank = 0
|
let master = rank = 0
|
||||||
|
|
||||||
@ -78,13 +79,13 @@ module Node = struct
|
|||||||
if master then Some (Lazy.force x)
|
if master then Some (Lazy.force x)
|
||||||
else None
|
else None
|
||||||
in
|
in
|
||||||
match broadcast x 0 comm with
|
match broadcast x 0 (Lazy.force comm) with
|
||||||
| Some x -> x
|
| Some x -> x
|
||||||
| None -> assert false
|
| None -> assert false
|
||||||
|
|
||||||
let broadcast x = broadcast_generic Mpi.broadcast x
|
let broadcast x = broadcast_generic Mpi.broadcast x
|
||||||
|
|
||||||
let barrier () = Mpi.barrier comm
|
let barrier () = Mpi.barrier (Lazy.force comm )
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,9 +36,6 @@ module Node : sig
|
|||||||
val name : string
|
val name : string
|
||||||
(** Name of the current host *)
|
(** Name of the current host *)
|
||||||
|
|
||||||
val comm : Mpi.communicator
|
|
||||||
(** MPI Communicator on the current node *)
|
|
||||||
|
|
||||||
val rank : Mpi.rank
|
val rank : Mpi.rank
|
||||||
(** Rank of the current process in the node *)
|
(** Rank of the current process in the node *)
|
||||||
|
|
||||||
|
@ -9,8 +9,7 @@ let master = true
|
|||||||
|
|
||||||
let barrier () = ()
|
let barrier () = ()
|
||||||
|
|
||||||
let broadcast x =
|
let broadcast x = Lazy.force x
|
||||||
Lazy.force x
|
|
||||||
|
|
||||||
let broadcast_int x = x
|
let broadcast_int x = x
|
||||||
|
|
||||||
@ -22,6 +21,19 @@ let broadcast_float_array x = x
|
|||||||
|
|
||||||
let broadcast_vec x = x
|
let broadcast_vec x = x
|
||||||
|
|
||||||
|
module Node = struct
|
||||||
|
|
||||||
|
let name = Unix.gethostname ()
|
||||||
|
|
||||||
|
let rank = 0
|
||||||
|
|
||||||
|
let master = true
|
||||||
|
|
||||||
|
let broadcast x = Lazy.force x
|
||||||
|
|
||||||
|
let barrier () = ()
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
module Vec = struct
|
module Vec = struct
|
||||||
|
|
||||||
|
@ -31,6 +31,26 @@ val broadcast_vec : Lacaml.D.vec -> Lacaml.D.vec
|
|||||||
(** Broadcasts a Lacaml vector to all processes. *)
|
(** Broadcasts a Lacaml vector to all processes. *)
|
||||||
|
|
||||||
|
|
||||||
|
(** {5 Intra-node operations} *)
|
||||||
|
module Node : sig
|
||||||
|
val name : string
|
||||||
|
(** Name of the current host *)
|
||||||
|
|
||||||
|
val rank : int
|
||||||
|
(** Rank of the current process in the node *)
|
||||||
|
|
||||||
|
val master : bool
|
||||||
|
(** If true, master process of the node *)
|
||||||
|
|
||||||
|
val broadcast : 'a lazy_t -> 'a
|
||||||
|
(** Broadcasts data to all the processes of the current node. *)
|
||||||
|
|
||||||
|
val barrier : unit -> unit
|
||||||
|
(** Wait for all processes among the node to reach this point. *)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
(** {5 Vector operations} *)
|
(** {5 Vector operations} *)
|
||||||
module Vec : sig
|
module Vec : sig
|
||||||
|
|
||||||
|
3
Parallel_serial/SharedMemory.ml
Normal file
3
Parallel_serial/SharedMemory.ml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
let create ?(temp_dir="/dev/shm") data_type size_array =
|
||||||
|
Bigarray.Genarray.create data_type Bigarray.fortran_layout size_array
|
||||||
|
|
@ -323,25 +323,65 @@ let rec op2 dense_op sparse_op a b =
|
|||||||
| (Dense _), (Sparse _) -> op2 dense_op sparse_op (sparse_of_dense a) b
|
| (Dense _), (Sparse _) -> op2 dense_op sparse_op (sparse_of_dense a) b
|
||||||
| (Sparse _), (Dense _) -> op2 dense_op sparse_op a (sparse_of_dense b)
|
| (Sparse _), (Dense _) -> op2 dense_op sparse_op a (sparse_of_dense b)
|
||||||
| (Sparse a), (Sparse b) -> Sparse
|
| (Sparse a), (Sparse b) -> Sparse
|
||||||
{ m=a.m ; n=a.n ;
|
{ m=a.m ; n=a.n ;
|
||||||
v = Array.map2 sparse_op a.v b.v
|
v = Array.map2 sparse_op a.v b.v
|
||||||
}
|
}
|
||||||
|
|
||||||
let add = op2 (fun a b -> Mat.add a b) (fun a b -> Vector.add a b)
|
let add = op2 (fun a b -> Mat.add a b) (fun a b -> Vector.add a b)
|
||||||
let sub = op2 (fun a b -> Mat.sub a b) (fun a b -> Vector.sub a b)
|
let sub = op2 (fun a b -> Mat.sub a b) (fun a b -> Vector.sub a b)
|
||||||
|
|
||||||
let scale f = function
|
let scale f = function
|
||||||
| Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
|
| Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
|
||||||
| Sparse a -> Sparse
|
| Sparse a -> Sparse
|
||||||
{ a with
|
{ a with
|
||||||
v = if f = 1.0 then a.v
|
v = if f = 1.0 then a.v
|
||||||
else Array.map (fun v -> Vector.scale f v) a.v }
|
else Array.map (fun v -> Vector.scale f v) a.v }
|
||||||
|
|
||||||
let frobenius_norm = function
|
let frobenius_norm = function
|
||||||
| Dense a -> lange ~norm:`F a
|
| Dense a -> lange ~norm:`F a
|
||||||
| Sparse a ->
|
| Sparse a ->
|
||||||
Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
|
Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
|
||||||
|> sqrt
|
|> sqrt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
let split_cols nrows = function
|
||||||
|
| Dense a ->
|
||||||
|
begin
|
||||||
|
Mat.to_col_vecs a
|
||||||
|
|> Array.to_list
|
||||||
|
|> Util.list_pack nrows
|
||||||
|
|> List.map (fun l ->
|
||||||
|
Dense (Mat.of_col_vecs @@ Array.of_list l) )
|
||||||
|
end
|
||||||
|
| Sparse a ->
|
||||||
|
begin
|
||||||
|
Array.to_list a.v
|
||||||
|
|> Util.list_pack nrows
|
||||||
|
|> List.map Array.of_list
|
||||||
|
|> List.map (fun v -> Sparse { m=a.m ; n= Array.length v ; v })
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
let join_cols l =
|
||||||
|
let rec aux_dense accu = function
|
||||||
|
| [] -> Dense (Mat.of_col_vecs_list (List.concat accu))
|
||||||
|
| (Dense a) :: rest -> aux_dense ((Mat.to_col_vecs_list a) :: accu) rest
|
||||||
|
| _ -> assert false
|
||||||
|
|
||||||
|
and aux_sparse m n accu = function
|
||||||
|
| [] -> Sparse { m ; n ; v=Array.of_list (List.concat accu) }
|
||||||
|
| (Sparse a) :: rest -> aux_sparse a.m (n+a.n) ((Array.to_list a.v)::accu) rest
|
||||||
|
| _ -> assert false
|
||||||
|
|
||||||
|
and aux = function
|
||||||
|
| [] -> Sparse { m=0 ; n=0 ; v=[| |] }
|
||||||
|
| (Dense a) :: rest -> aux_dense [] ((Dense a) :: rest)
|
||||||
|
| (Sparse a) :: rest -> aux_sparse 0 0 [] ((Sparse a) :: rest)
|
||||||
|
|
||||||
|
in aux (List.rev l)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -396,6 +436,19 @@ let rec ax_eq_b ?(trans=`N) a b =
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
(* ------- Parallel routines ---------- *)
|
||||||
|
|
||||||
|
let parallel_mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
|
||||||
|
|
||||||
|
let n = 4 in
|
||||||
|
split_cols n b
|
||||||
|
|> Stream.of_list
|
||||||
|
|> Farm.run ~ordered:true ~f:(fun b ->
|
||||||
|
mm ~transa ~transb ~threshold a b
|
||||||
|
)
|
||||||
|
|> Util.stream_to_list
|
||||||
|
|> join_cols
|
||||||
|
|
||||||
|
|
||||||
(* ------------ Printers ------------ *)
|
(* ------------ Printers ------------ *)
|
||||||
|
|
||||||
@ -591,6 +644,16 @@ let test_case () =
|
|||||||
Alcotest.(check (float 1.e-10)) "sparse sparse 4" 0. (norm_diff (mm a x) b);
|
Alcotest.(check (float 1.e-10)) "sparse sparse 4" 0. (norm_diff (mm a x) b);
|
||||||
in
|
in
|
||||||
|
|
||||||
|
let test_split_join () =
|
||||||
|
let m1_split = split_cols 7 m1 in
|
||||||
|
let m1_s_split = split_cols 7 m1_s in
|
||||||
|
let m2 = join_cols m1_split in
|
||||||
|
let m2_s = join_cols m1_s_split in
|
||||||
|
Alcotest.(check int) "length" 6 (List.length m1_split);
|
||||||
|
Alcotest.(check int) "length" 6 (List.length m1_s_split);
|
||||||
|
Alcotest.(check bool) "join" true (m1 = m2);
|
||||||
|
Alcotest.(check bool) "join" true (m1_s = m2_s);
|
||||||
|
in
|
||||||
[
|
[
|
||||||
"Conversion", `Quick, test_conversion;
|
"Conversion", `Quick, test_conversion;
|
||||||
"Dimensions", `Quick, test_dimensions;
|
"Dimensions", `Quick, test_dimensions;
|
||||||
@ -600,5 +663,6 @@ let test_case () =
|
|||||||
"Matrix Vector", `Quick, test_mv;
|
"Matrix Vector", `Quick, test_mv;
|
||||||
"Matrix Matrix", `Quick, test_mm;
|
"Matrix Matrix", `Quick, test_mm;
|
||||||
"Linear solve", `Quick, test_solve;
|
"Linear solve", `Quick, test_solve;
|
||||||
|
"split_join", `Quick, test_split_join;
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -48,6 +48,12 @@ val sparse_of_vector_array : Vector.t array -> t
|
|||||||
val transpose : t -> t
|
val transpose : t -> t
|
||||||
(** Returns the transposed matrix. *)
|
(** Returns the transposed matrix. *)
|
||||||
|
|
||||||
|
val split_cols : int -> t -> t list
|
||||||
|
(** [split_cols n m_M] Split the matrix [m_M] by packs of [n] columns *)
|
||||||
|
|
||||||
|
val join_cols : t list -> t
|
||||||
|
(** [join_cols l] Joins the list of matrices into a single matrix (along columns). *)
|
||||||
|
|
||||||
|
|
||||||
(** {1 Operations} *)
|
(** {1 Operations} *)
|
||||||
|
|
||||||
@ -69,6 +75,12 @@ val add : t -> t -> t
|
|||||||
val sub : t -> t -> t
|
val sub : t -> t -> t
|
||||||
(** Subtract two matrices *)
|
(** Subtract two matrices *)
|
||||||
|
|
||||||
|
|
||||||
|
(** {1 Parallel routines } *)
|
||||||
|
|
||||||
|
val parallel_mm : ?transa:trans3 -> ?transb:trans3 -> ?threshold:float -> t -> t -> t
|
||||||
|
(** Matrix multiplication, parallelized by splitting b along the columns. *)
|
||||||
|
|
||||||
(** {1 Printers } *)
|
(** {1 Printers } *)
|
||||||
|
|
||||||
val pp_matrix : Format.formatter -> t -> unit
|
val pp_matrix : Format.formatter -> t -> unit
|
||||||
|
Loading…
Reference in New Issue
Block a user