mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 01:55:40 +01:00
Sparse matrix product OK
This commit is contained in:
parent
9a46fe36a4
commit
2f28baf125
219
Utils/Matrix.ml
219
Utils/Matrix.ml
@ -101,46 +101,156 @@ let transpose = function
|
||||
end
|
||||
|
||||
|
||||
(*
|
||||
let outer_product ?(threshold=epsilon) v1 v2 =
|
||||
match Vector.(is_dense v1, is_dense v2) with
|
||||
| (true, true) ->
|
||||
let v1 = Vector.to_vec v1
|
||||
and v2 = Vector.to_vec v2
|
||||
in
|
||||
let a = Mat.create (Vec.dim v1) (Vec.dim v2) in
|
||||
ger v1 v2 a;
|
||||
Dense a
|
||||
| (true, false) ->
|
||||
let v = Vector.to_vec v1
|
||||
and v' = Vector.to_vec v2
|
||||
in
|
||||
let v =
|
||||
Array.init (Vector.dim v1) (fun i ->
|
||||
Vector.scale (Vector.get v1 i)
|
||||
*)
|
||||
Array.init (Vector.dim v2) (fun j ->
|
||||
Vec.map (fun x -> x *. v'.{j+1}) v
|
||||
|> Vector.sparse_of_vec)
|
||||
in
|
||||
Sparse {m=Vector.dim v1 ; n=Vector.dim v2 ; v}
|
||||
| (false, true)
|
||||
| (false, false) ->
|
||||
let v = Vector.to_assoc_list v1
|
||||
and v' = Vector.to_vec v2
|
||||
in
|
||||
let v =
|
||||
Array.init (Vector.dim v2) (fun j ->
|
||||
List.map (fun (i, x) ->
|
||||
let z = x *. v'.{j+1} in
|
||||
if abs_float z < threshold then
|
||||
None
|
||||
else
|
||||
Some (i, z)
|
||||
) v
|
||||
|> Util.list_some
|
||||
|> Vector.sparse_of_assoc_list (Vector.dim v1)
|
||||
)
|
||||
in
|
||||
Sparse {m=Vector.dim v1 ; n=Vector.dim v2 ; v}
|
||||
|
||||
|
||||
|
||||
(*
|
||||
let mm ?(threshold=epsilon) a b =
|
||||
let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
|
||||
|
||||
if dim2 a <> dim1 b then
|
||||
let f, f' =
|
||||
match transa, transb with
|
||||
| `N, `N -> dim2, dim1
|
||||
| `T, `N -> dim1, dim1
|
||||
| `T, `T -> dim1, dim2
|
||||
| `N, `T -> dim2, dim2
|
||||
in
|
||||
if f a <> f' b then
|
||||
invalid_arg "Inconsistent dimensions";
|
||||
|
||||
let spmm {m ; n ; v} b =
|
||||
let n = Mat.dim2 b in
|
||||
let mmsp transa transb a b =
|
||||
let a =
|
||||
match transa with
|
||||
| `N -> Mat.transpose_copy a
|
||||
| `T -> a
|
||||
in
|
||||
let m' = Mat.dim2 a in
|
||||
let a =
|
||||
Mat.to_col_vecs a
|
||||
|> Array.map (fun v -> Vector.dense_of_vec v)
|
||||
in
|
||||
let {m ; n ; v} =
|
||||
if transb = `T then
|
||||
match transpose (Sparse b) with
|
||||
| Sparse x -> x
|
||||
| _ -> assert false
|
||||
else
|
||||
b
|
||||
in
|
||||
let v' =
|
||||
Array.map (fun b_j ->
|
||||
Array.map (fun a_i ->
|
||||
Vector.dot a_i b_j) a
|
||||
|> Vec.of_array
|
||||
|> Vector.sparse_of_vec
|
||||
) v
|
||||
in
|
||||
Sparse {m=m' ; n ; v=v'}
|
||||
in
|
||||
|
||||
let spmm transa transb a b =
|
||||
let b =
|
||||
match transb with
|
||||
| `N -> b
|
||||
| `T -> Mat.transpose_copy b
|
||||
in
|
||||
let n' = Mat.dim2 b in
|
||||
let b =
|
||||
Mat.to_col_vecs b
|
||||
|> Array.map (fun v -> Vector.dense_of_vec v)
|
||||
in
|
||||
let m, n, v =
|
||||
if transa = `N then
|
||||
match transpose (Sparse a) with
|
||||
| Sparse {m ; n ; v} -> n, m, v
|
||||
| _ -> assert false
|
||||
else
|
||||
match Sparse a with
|
||||
| Sparse {m ; n ; v} -> n, m, v
|
||||
| _ -> assert false
|
||||
in
|
||||
let v' =
|
||||
Array.map (fun b_j ->
|
||||
Array.map (fun a_i ->
|
||||
Vec.init n (fun j ->
|
||||
Vector.dot a_i b.(j-1))
|
||||
|> Vector.sparse_of_vec ~threshold
|
||||
Vector.dot a_i b_j) v
|
||||
|> Vec.of_array
|
||||
|> Vector.sparse_of_vec
|
||||
) b
|
||||
in
|
||||
Sparse {m ; n=n' ; v=v'}
|
||||
in
|
||||
|
||||
let mmspmm transa transb a b =
|
||||
let {m ; n ; v} =
|
||||
if transb = `T then
|
||||
match transpose (Sparse b) with
|
||||
| Sparse x -> x
|
||||
| _ -> assert false
|
||||
else
|
||||
b
|
||||
in
|
||||
let m', n', v' =
|
||||
if transa = `N then
|
||||
match transpose (Sparse a) with
|
||||
| Sparse {m ; n ; v} -> n, m, v
|
||||
| _ -> assert false
|
||||
else
|
||||
match Sparse a with
|
||||
| Sparse {m ; n ; v} -> n, m, v
|
||||
| _ -> assert false
|
||||
in
|
||||
let v'' =
|
||||
Array.map (fun b_j ->
|
||||
Array.map (fun a_i ->
|
||||
Vector.dot a_i b_j) v'
|
||||
|> Vec.of_array
|
||||
|> Vector.sparse_of_vec
|
||||
) v
|
||||
in
|
||||
Sparse {m ; n ; v=v'}
|
||||
Sparse {m=m' ; n=n ; v=v''}
|
||||
in
|
||||
|
||||
let mmsp a {m ; n ; v} =
|
||||
|
||||
|
||||
match a, b with
|
||||
| (Dense a), (Dense b) -> Dense (gemm a b)
|
||||
| (Sparse a), (Dense b) -> spmm a b
|
||||
| (Dense a), (Sparse b) -> mmsp a b
|
||||
| (Sparse a), (Sparse b) -> mmspmm a b
|
||||
*)
|
||||
| (Dense a), (Dense b) -> Dense (gemm ~transa ~transb a b)
|
||||
| (Sparse a), (Dense b) -> spmm transa transb a b
|
||||
| (Dense a), (Sparse b) -> mmsp transa transb a b
|
||||
| (Sparse a), (Sparse b) -> mmspmm transa transb a b
|
||||
|
||||
|
||||
let rec pp_matrix ppf = function
|
||||
@ -150,8 +260,13 @@ let rec pp_matrix ppf = function
|
||||
|
||||
let test_case () =
|
||||
|
||||
let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random 3 4)
|
||||
and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random 4 5)
|
||||
let d1 = 3
|
||||
and d2 = 4
|
||||
and d3 = 5
|
||||
in
|
||||
|
||||
let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random d1 d2)
|
||||
and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d2 d3)
|
||||
in
|
||||
|
||||
let m1 = dense_of_mat x1
|
||||
@ -163,14 +278,14 @@ let test_case () =
|
||||
in
|
||||
|
||||
let test_dimensions () =
|
||||
Alcotest.(check int) "dim1 1" 3 (dim1 m1 );
|
||||
Alcotest.(check int) "dim1 2" 3 (dim1 m1_s);
|
||||
Alcotest.(check int) "dim2 3" 4 (dim2 m1 );
|
||||
Alcotest.(check int) "dim2 4" 4 (dim2 m1_s);
|
||||
Alcotest.(check int) "dim1 5" 4 (dim1 m2 );
|
||||
Alcotest.(check int) "dim1 6" 4 (dim1 m2_s);
|
||||
Alcotest.(check int) "dim2 7" 5 (dim2 m2 );
|
||||
Alcotest.(check int) "dim2 8" 5 (dim2 m2_s);
|
||||
Alcotest.(check int) "dim1 1" d1 (dim1 m1 );
|
||||
Alcotest.(check int) "dim1 2" d1 (dim1 m1_s);
|
||||
Alcotest.(check int) "dim2 3" d2 (dim2 m1 );
|
||||
Alcotest.(check int) "dim2 4" d2 (dim2 m1_s);
|
||||
Alcotest.(check int) "dim1 5" d2 (dim1 m2 );
|
||||
Alcotest.(check int) "dim1 6" d2 (dim1 m2_s);
|
||||
Alcotest.(check int) "dim2 7" d3 (dim2 m2 );
|
||||
Alcotest.(check int) "dim2 8" d3 (dim2 m2_s);
|
||||
in
|
||||
|
||||
let test_conversion () =
|
||||
@ -189,9 +304,51 @@ let test_case () =
|
||||
Alcotest.(check bool) "sparse 1" true (transpose m1_s = sparse_of_dense m1t);
|
||||
Alcotest.(check bool) "sparse 2" true (transpose m2_s = sparse_of_dense m2t);
|
||||
in
|
||||
|
||||
let test_outer () =
|
||||
let x1 = Vec.init 3 (fun i -> float_of_int i)
|
||||
and x2 = Vec.init 4 (fun i -> float_of_int i -. 3.)
|
||||
in
|
||||
let v1 = Vector.dense_of_vec x1
|
||||
and v2 = Vector.dense_of_vec x2
|
||||
in
|
||||
let v1_s = Vector.sparse_of_vec x1
|
||||
and v2_s = Vector.sparse_of_vec x2
|
||||
in
|
||||
let m1 =
|
||||
Dense (Mat.init_cols 3 4 (fun i j -> (float_of_int i) *. (float_of_int j -. 3.)))
|
||||
in
|
||||
let m1_s =
|
||||
sparse_of_dense m1
|
||||
in
|
||||
Alcotest.(check bool) "dense dense " true (m1 = outer_product v1 v2);
|
||||
Alcotest.(check bool) "sparse dense " true (m1_s = outer_product v1_s v2);
|
||||
Alcotest.(check bool) "dense sparse" true (m1_s = outer_product v1 v2_s);
|
||||
Alcotest.(check bool) "sparse sparse" true (m1_s = outer_product v1_s v2_s);
|
||||
in
|
||||
|
||||
let test_mm () =
|
||||
let x3 = gemm x1 x2 in
|
||||
let m3 = dense_of_mat x3
|
||||
and m3_s = sparse_of_mat x3
|
||||
in
|
||||
let norm_diff m1 m2 =
|
||||
(Mat.sub (to_mat m1) (to_mat m2)
|
||||
|> Mat.syrk_trace)
|
||||
in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm m1 m2) m3);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 2" 0. (norm_diff (mm m1 m2_s) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse dense 4" 0. (norm_diff (mm m1_s m2) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse sparse 6" 0. (norm_diff (mm m1_s m2_s) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse sparse 7" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s);
|
||||
in
|
||||
[
|
||||
"Conversion", `Quick, test_conversion;
|
||||
"Dimensions", `Quick, test_dimensions;
|
||||
"Transposition", `Quick, test_transpose;
|
||||
"Oueter product", `Quick, test_outer;
|
||||
"Matrix Matrix", `Quick, test_mm;
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user