10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2025-01-03 10:05:40 +01:00
QCaml/Utils/Matrix.ml

463 lines
13 KiB
OCaml
Raw Normal View History

2019-02-27 21:28:56 +01:00
open Lacaml.D
(** Sparse representation of a matrix as an array of column vectors
    (the conversion functions of this module build [v] from, and turn it
    back into, Lacaml column vectors). *)
type sparse_matrix =
{
(* Number of rows. *)
m: int;
(* Number of columns. *)
n: int;
(* The columns of the matrix, one [Vector.t] per column. *)
v: Vector.t array;
}
(** A matrix, stored either as a dense Lacaml matrix or as a
    {!sparse_matrix}. *)
type t =
| Dense of Mat.t
| Sparse of sparse_matrix
(** Default value of the optional [threshold] argument of the functions
    of this module. *)
let epsilon = Constants.epsilon
(** [is_sparse a] is [true] when [a] uses the sparse representation. *)
let is_sparse a =
  match a with
  | Dense _ -> false
  | Sparse _ -> true
(** [is_dense a] is [true] when [a] uses the dense representation. *)
let is_dense a =
  match a with
  | Dense _ -> true
  | Sparse _ -> false
(** [dim1 a] is the number of rows of [a]. *)
let dim1 a =
  match a with
  | Sparse {m; _} -> m
  | Dense d -> Mat.dim1 d
(** [dim2 a] is the number of columns of [a]. *)
let dim2 a =
  match a with
  | Sparse {n; _} -> n
  | Dense d -> Mat.dim2 d
(** [get a i j] is the element of [a] at row [i] and column [j].
    Indices start at 1, as in the dense Lacaml representation.

    A sparse matrix is stored as an array of column vectors, so the element
    [(i, j)] is the [i]-th entry of the column stored at index [j-1]. *)
let get = function
  | Dense m -> (fun i j -> m.{i,j})
  | Sparse {v; _} -> (fun i j -> Vector.get v.(j-1) i)
(** [sparse_of_dense ~threshold a] converts the dense matrix [a] into a
    sparse one, column by column; [threshold] is forwarded to
    [Vector.sparse_of_vec].

    @raise Invalid_argument if [a] is already sparse. *)
let sparse_of_dense ?(threshold=epsilon) a =
  match a with
  | Sparse _ -> invalid_arg "Expected a dense matrix"
  | Dense d ->
    let columns =
      Array.map
        (fun col -> Vector.sparse_of_vec ~threshold col)
        (Mat.to_col_vecs d)
    in
    Sparse { m = Mat.dim1 d ; n = Mat.dim2 d ; v = columns }
(** [dense_of_sparse a] converts the sparse matrix [a] into a dense one by
    expanding each stored column into a Lacaml vector.

    @raise Invalid_argument if [a] is already dense. *)
let dense_of_sparse a =
  match a with
  | Dense _ -> invalid_arg "Expected a sparse matrix"
  | Sparse {v; _} ->
    let columns = Array.map (fun col -> Vector.to_vec col) v in
    Dense (Mat.of_col_vecs columns)
(** [dense_of_mat mat] wraps the Lacaml matrix [mat] as a dense matrix,
    without copying it. *)
let dense_of_mat mat = Dense mat
2019-02-28 12:30:20 +01:00
(** [to_vector_array ~threshold a] is the array of sparse column vectors of
    [a]. A dense matrix is first converted with {!sparse_of_dense}, using
    [threshold]; a sparse matrix returns its own column array. *)
let to_vector_array ?(threshold=epsilon) a =
  match a with
  | Sparse {v; _} -> v
  | Dense _ ->
    begin
      match sparse_of_dense ~threshold a with
      | Sparse {v; _} -> v
      | Dense _ -> assert false   (* sparse_of_dense only builds [Sparse] *)
    end
2019-02-27 21:28:56 +01:00
(** [sparse_of_mat ~threshold m] builds a sparse matrix from the Lacaml
    matrix [m]; [threshold] is forwarded to {!sparse_of_dense}. *)
let sparse_of_mat ?(threshold=epsilon) m =
  sparse_of_dense ~threshold (Dense m)
(** [sparse_of_vector_array v] builds a sparse matrix whose columns are the
    vectors of [v]. The number of rows is the dimension of the first vector.

    @raise Invalid_argument if the vectors do not all have the same
    dimension, or if [v] is empty (out-of-bounds access to [v.(0)]). *)
let sparse_of_vector_array v =
  let m = Vector.dim v.(0) in
  Array.iter (fun col ->
      if Vector.dim col <> m then
        invalid_arg "Inconsistent dimension") v;
  Sparse { m ; n = Array.length v ; v }
(** [to_mat a] is [a] as a Lacaml matrix. A dense matrix is returned as is
    (no copy); a sparse one is expanded column by column. *)
let to_mat a =
  match a with
  | Dense d -> d
  | Sparse {v; _} ->
    Array.map (fun col -> Vector.to_vec col) v
    |> Mat.of_col_vecs
(** [transpose a] is the transpose of [a], in the same representation
    (dense or sparse) as [a]. *)
let transpose a =
  match a with
  | Dense d -> Dense (Mat.transpose_copy d)
  | Sparse {m ; n ; v} ->
    (* rows.(i-1) collects the (column, value) pairs of row i. Entries are
       pushed in front while the columns are scanned in increasing order,
       hence the List.rev below. *)
    let rows = Array.make m [] in
    Array.iteri (fun j col ->
        List.iter
          (fun (i, x) -> rows.(i-1) <- (j+1, x) :: rows.(i-1))
          (Vector.to_assoc_list col)
      ) v;
    let columns_of_transpose =
      Array.map
        (fun entries -> Vector.sparse_of_assoc_list n (List.rev entries))
        rows
    in
    Sparse { m = n ; n = m ; v = columns_of_transpose }
(** [outer_product ~threshold v1 v2] is the matrix whose element [(i, j)]
    is [v1_i *. v2_j]. It has [Vector.dim v1] rows and [Vector.dim v2]
    columns.

    The result is dense when both vectors are dense, and sparse otherwise.
    [threshold] controls which products are dropped from a sparse result:
    - dense [v1], sparse [v2]: it is forwarded to [Vector.sparse_of_vec]
      (it used to be silently ignored in this case);
    - sparse [v1]: products with [abs_float z < threshold] are discarded. *)
let outer_product ?(threshold=epsilon) v1 v2 =
  let m = Vector.dim v1
  and n = Vector.dim v2
  in
  match Vector.(is_dense v1, is_dense v2) with
  | (true, true) ->
    let x = Vector.to_vec v1
    and y = Vector.to_vec v2
    in
    let a = Mat.make0 (Vec.dim x) (Vec.dim y) in
    (* Rank-1 update of the zero matrix: a <- x y^T *)
    ger x y a;
    Dense a
  | (true, false) ->
    let x = Vector.to_vec v1
    and y = Vector.to_vec v2
    in
    (* Column j is x scaled by the j-th element of v2 *)
    let v =
      Array.init n (fun j ->
          Vec.map (fun x_i -> x_i *. y.{j+1}) x
          |> Vector.sparse_of_vec ~threshold)
    in
    Sparse { m ; n ; v }
  | (false, true)
  | (false, false) ->
    (* Only the stored elements of v1 can give non-zero products *)
    let x = Vector.to_assoc_list v1
    and y = Vector.to_vec v2
    in
    let v =
      Array.init n (fun j ->
          List.map (fun (i, x_i) ->
              let z = x_i *. y.{j+1} in
              if abs_float z < threshold then
                None
              else
                Some (i, z)
            ) x
          |> Util.list_some
          |> Vector.sparse_of_assoc_list m
        )
    in
    Sparse { m ; n ; v }
2019-02-27 21:28:56 +01:00
2019-02-27 23:54:53 +01:00
(** [mm ~transa ~transb ~threshold a b] is the matrix product
    [op(a) op(b)], where [op(x)] is [x] for [`N] and the transpose of [x]
    for [`T].

    The result is dense when both [a] and [b] are dense (computed with
    [gemm]), and sparse otherwise. In the sparse cases every element is
    computed as a [Vector.dot] of a row of [op(a)] with a column of [op(b)],
    and [threshold] is forwarded to [Vector.sparse_of_vec] when the result
    columns are built (it used to be accepted but ignored).

    @raise Invalid_argument if the inner dimensions of [op(a)] and [op(b)]
    differ. *)
let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =

  (* Accessors for the contracted dimension of a and b *)
  let inner_a, inner_b =
    match transa, transb with
    | `N, `N -> dim2, dim1
    | `T, `N -> dim1, dim1
    | `T, `T -> dim1, dim2
    | `N, `T -> dim2, dim2
  in
  if inner_a a <> inner_b b then
    invalid_arg "Inconsistent dimensions";

  (* Rows of op(a) for a dense a. Only columns can be extracted from a
     Lacaml matrix, so a is transposed when op(a) = a. *)
  let dense_rows a =
    let a_t =
      match transa with
      | `N -> Mat.transpose_copy a
      | `T -> a
    in
    Mat.to_col_vecs a_t
    |> Array.map (fun x -> Vector.dense_of_vec x)
  in

  (* Columns of op(b) for a dense b *)
  let dense_cols b =
    let b_op =
      match transb with
      | `N -> b
      | `T -> Mat.transpose_copy b
    in
    Mat.to_col_vecs b_op
    |> Array.map (fun x -> Vector.dense_of_vec x)
  in

  (* Columns stored in the transpose of a sparse matrix *)
  let transposed_cols x =
    match transpose (Sparse x) with
    | Sparse {v; _} -> v
    | Dense _ -> assert false   (* transpose preserves the representation *)
  in

  (* Rows of op(a) for a sparse a: the stored columns are the rows of a^T *)
  let sparse_rows a =
    match transa with
    | `N -> transposed_cols a
    | `T -> let {v; _} = a in v
  in

  (* Columns of op(b) for a sparse b *)
  let sparse_cols b =
    match transb with
    | `N -> let {v; _} = b in v
    | `T -> transposed_cols b
  in

  (* Sparse product from the rows of op(a) and the columns of op(b):
     element (i,j) is rows.(i) . cols.(j) *)
  let product rows cols =
    let v =
      Array.map (fun col ->
          Array.map (fun row -> Vector.dot row col) rows
          |> Vec.of_array
          |> Vector.sparse_of_vec ~threshold
        ) cols
    in
    Sparse { m = Array.length rows ; n = Array.length cols ; v }
  in

  match a, b with
  | (Dense a), (Dense b) -> Dense (gemm ~transa ~transb a b)
  | (Sparse a), (Dense b) -> product (sparse_rows a) (dense_cols b)
  | (Dense a), (Sparse b) -> product (dense_rows a) (sparse_cols b)
  | (Sparse a), (Sparse b) -> product (sparse_rows a) (sparse_cols b)
2019-02-27 21:28:56 +01:00
2019-02-28 12:08:28 +01:00
(** [mv ~sparse ~trans ~threshold a b] is the matrix-vector product
    [op(a) b], where [op(a)] is [a] for [`N] and the transpose of [a] for
    [`T].

    The result is a sparse vector when [sparse] is [true] (built with
    [Vector.sparse_of_vec ~threshold]) and a dense vector otherwise.

    @raise Invalid_argument if the number of columns of [op(a)] differs from
    the dimension of [b]. *)
let mv ?(sparse=false) ?(trans=`N) ?(threshold=epsilon) a b =

  let inner =
    match trans with
    | `N -> dim2
    | `T -> dim1
  in
  if inner a <> Vector.dim b then
    invalid_arg "Inconsistent dimensions";

  (* Sparse matrix: dot every row of op(a) with b. The stored columns of a
     are the rows of a^T, so a is transposed only when op(a) = a. *)
  let spmv a b =
    let rows =
      match trans with
      | `T -> let {v; _} = a in v
      | `N ->
        begin
          match transpose (Sparse a) with
          | Sparse {v; _} -> v
          | Dense _ -> assert false  (* transpose preserves the representation *)
        end
    in
    Array.map (fun row_a -> Vector.dot row_a b) rows
    |> Vec.of_array
  in

  (* Dense matrix, sparse vector: dot every row of op(a) with b. The result
     has as many elements as op(a) has rows, i.e. dim1 a for `N and dim2 a
     for `T. *)
  let mv a b =
    let n_rows, row =
      match trans with
      | `N -> Mat.dim1 a, (fun i -> Mat.copy_row a i)
      | `T -> Mat.dim2 a, (fun i -> Mat.col a i)
    in
    Vec.init n_rows (fun i ->
        Vector.dense_of_vec (row i)
        |> Vector.dot b )
  in

  let dense_result =
    match a, Vector.is_dense b with
    | Dense a, true -> gemv ~trans a (Vector.to_vec b)
    | Dense a, false -> mv a b
    | Sparse a, true
    | Sparse a, false -> spmv a b
  in

  if sparse then
    Vector.sparse_of_vec ~threshold dense_result
  else
    Vector.dense_of_vec dense_result
2019-02-27 21:28:56 +01:00
(** [pp_matrix ppf a] prints [a] on the formatter [ppf] with
    [Util.pp_matrix]; a sparse matrix is converted to a dense one first. *)
let pp_matrix ppf a =
  Util.pp_matrix ppf (to_mat a)
2019-02-28 12:08:28 +01:00
(* ---------- Unit tests ------------ *)
2019-02-27 21:28:56 +01:00
(** [test_case ()] builds the Alcotest test cases of this module: every
    operation is run on the dense and sparse representations of the same
    random matrices, and the results are compared with each other or with
    the Lacaml reference. *)
let test_case () =

  let d1 = 30
  and d2 = 40
  and d3 = 50
  in

  (* Random matrices in which the small elements are replaced by zero *)
  let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random d1 d2)
  and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d2 d3)
  in

  let m1 = dense_of_mat x1
  and m2 = dense_of_mat x2
  in

  let m1_s = sparse_of_mat x1
  and m2_s = sparse_of_mat x2
  in

  (* [Mat.syrk_trace] applied to the element-wise difference of two matrices *)
  let norm_diff m1 m2 =
    (Mat.sub (to_mat m1) (to_mat m2)
     |> Mat.syrk_trace)
  in

  (* Checks that [x] is zero within 1.e-10 *)
  let check_zero label x =
    Alcotest.(check (float 1.e-10)) label 0. x
  in

  let test_dimensions () =
    Alcotest.(check int) "dim1 1" d1 (dim1 m1  );
    Alcotest.(check int) "dim1 2" d1 (dim1 m1_s);
    Alcotest.(check int) "dim2 3" d2 (dim2 m1  );
    Alcotest.(check int) "dim2 4" d2 (dim2 m1_s);
    Alcotest.(check int) "dim1 5" d2 (dim1 m2  );
    Alcotest.(check int) "dim1 6" d2 (dim1 m2_s);
    Alcotest.(check int) "dim2 7" d3 (dim2 m2  );
    Alcotest.(check int) "dim2 8" d3 (dim2 m2_s)
  in

  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse m1_s = m1 );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse m2_s = m2 );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense m1 = m1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense m2 = m2_s)
  in

  let test_transpose () =
    let m1t = Mat.transpose_copy x1 |> dense_of_mat
    and m2t = Mat.transpose_copy x2 |> dense_of_mat
    in
    Alcotest.(check bool) "dense 1" true (transpose m1 = m1t);
    Alcotest.(check bool) "dense 2" true (transpose m2 = m2t);
    Alcotest.(check bool) "sparse 1" true (transpose m1_s = sparse_of_dense m1t);
    Alcotest.(check bool) "sparse 2" true (transpose m2_s = sparse_of_dense m2t)
  in

  let test_outer () =
    let x1 = Vec.init d1 (fun i -> float_of_int i)
    and x2 = Vec.init d2 (fun i -> float_of_int i -. 0.3)
    in
    let v1 = Vector.dense_of_vec x1
    and v2 = Vector.dense_of_vec x2
    in
    let v1_s = Vector.sparse_of_vec x1
    and v2_s = Vector.sparse_of_vec x2
    in
    (* Reference: the outer product written element by element *)
    let m1 =
      Dense (Mat.init_cols d1 d2 (fun i j -> (float_of_int i) *. (float_of_int j -. 0.3)))
    in
    let m1_s =
      sparse_of_dense m1
    in
    check_zero "dense dense " (norm_diff m1 (outer_product v1 v2));
    check_zero "sparse dense " (norm_diff m1_s (outer_product v1_s v2));
    check_zero "dense sparse" (norm_diff m1_s (outer_product v1 v2_s));
    check_zero "sparse sparse" (norm_diff m1_s (outer_product v1_s v2_s))
  in

  let test_mv () =
    let y = Vec.random d2 in
    let z = Vec.random d1 in
    (* Lacaml references for x1 y and x1^T z *)
    let x3 = gemv x1 y in
    let x4 = gemv ~trans:`T x1 z in
    let v = Vector.dense_of_vec y in
    let v2 = Vector.dense_of_vec z in
    let v3 = Vector.dense_of_vec x3 in
    let v4 = Vector.dense_of_vec x4 in
    let v_s = Vector.sparse_of_vec y in
    let v2_s = Vector.sparse_of_vec z in
    let norm_diff v1 v2 =
      Vec.sub (Vector.to_vec v1) (Vector.to_vec v2)
      |> nrm2
    in
    check_zero "dense dense 1" (norm_diff (mv m1 v) v3);
    check_zero "dense dense 2" (norm_diff (mv ~trans:`T m1 v2) v4);
    check_zero "dense sparse 3" (norm_diff (mv m1 v_s) v3);
    check_zero "dense sparse 4" (norm_diff (mv ~trans:`T m1 v2_s) v4);
    check_zero "sparse dense 5" (norm_diff (mv m1_s v) v3);
    check_zero "sparse dense 6" (norm_diff (mv ~trans:`T m1_s v2) v4);
    check_zero "sparse sparse 7" (norm_diff (mv m1_s v_s) v3);
    check_zero "sparse sparse 8" (norm_diff (mv ~trans:`T m1_s v2_s) v4)
  in

  let test_mm () =
    (* Lacaml reference for x1 x2 *)
    let x3 = gemm x1 x2 in
    let m3 = dense_of_mat x3
    and m3_s = sparse_of_mat x3
    and m4 = dense_of_mat x1 |> transpose
    and m4_s = sparse_of_mat x1 |> transpose
    and m5 = dense_of_mat x2 |> transpose
    and m5_s = sparse_of_mat x2 |> transpose
    in
    check_zero "dense dense 1" (norm_diff (mm m1 m2) m3);
    check_zero "dense dense 2" (norm_diff (mm ~transa:`T m4 m2) m3);
    check_zero "dense dense 3" (norm_diff (mm ~transb:`T m1 m5) m3);
    check_zero "dense dense 4" (norm_diff (mm ~transa:`T ~transb:`T m2 m1) (transpose m3));
    check_zero "dense sparse 5" (norm_diff (mm m1 m2_s) m3_s);
    check_zero "dense sparse 6" (norm_diff (mm ~transa:`T m4 m2_s) m3_s);
    check_zero "dense sparse 7" (norm_diff (mm ~transb:`T m1 m5_s) m3_s);
    check_zero "dense sparse 8" (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s);
    check_zero "sparse dense 9" (norm_diff (mm m1_s m2) m3_s);
    check_zero "sparse dense 10" (norm_diff (mm ~transa:`T m4_s m2) m3_s);
    check_zero "sparse dense 11" (norm_diff (mm ~transb:`T m1_s m5) m3_s);
    check_zero "sparse dense 12" (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s);
    check_zero "sparse sparse 13" (norm_diff (mm m1_s m2_s) m3_s);
    check_zero "sparse sparse 14" (norm_diff (mm ~transa:`T m4_s m2_s) m3_s);
    check_zero "sparse sparse 15" (norm_diff (mm ~transb:`T m1_s m5_s) m3_s);
    check_zero "sparse sparse 16" (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s)
  in

  [
    "Conversion", `Quick, test_conversion;
    "Dimensions", `Quick, test_dimensions;
    "Transposition", `Quick, test_transpose;
    "Outer product", `Quick, test_outer;
    "Matrix Vector", `Quick, test_mv;
    "Matrix Matrix", `Quick, test_mm;
  ]