mirror of
https://gitlab.com/scemama/QCaml.git
synced 2024-12-22 12:23:31 +01:00
Sparse matrix products OK
This commit is contained in:
parent
2f28baf125
commit
3de8e31a18
@ -154,6 +154,7 @@ let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
|
|||||||
if f a <> f' b then
|
if f a <> f' b then
|
||||||
invalid_arg "Inconsistent dimensions";
|
invalid_arg "Inconsistent dimensions";
|
||||||
|
|
||||||
|
(* Dense x sparse *)
|
||||||
let mmsp transa transb a b =
|
let mmsp transa transb a b =
|
||||||
let a =
|
let a =
|
||||||
match transa with
|
match transa with
|
||||||
@ -184,6 +185,7 @@ let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
|
|||||||
Sparse {m=m' ; n ; v=v'}
|
Sparse {m=m' ; n ; v=v'}
|
||||||
in
|
in
|
||||||
|
|
||||||
|
(* Sparse x dense *)
|
||||||
let spmm transa transb a b =
|
let spmm transa transb a b =
|
||||||
let b =
|
let b =
|
||||||
match transb with
|
match transb with
|
||||||
@ -216,6 +218,7 @@ let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
|
|||||||
Sparse {m ; n=n' ; v=v'}
|
Sparse {m ; n=n' ; v=v'}
|
||||||
in
|
in
|
||||||
|
|
||||||
|
(* Sparse x Sparse *)
|
||||||
let mmspmm transa transb a b =
|
let mmspmm transa transb a b =
|
||||||
let {m ; n ; v} =
|
let {m ; n ; v} =
|
||||||
if transb = `T then
|
if transb = `T then
|
||||||
@ -331,18 +334,34 @@ let test_case () =
|
|||||||
let x3 = gemm x1 x2 in
|
let x3 = gemm x1 x2 in
|
||||||
let m3 = dense_of_mat x3
|
let m3 = dense_of_mat x3
|
||||||
and m3_s = sparse_of_mat x3
|
and m3_s = sparse_of_mat x3
|
||||||
|
and m4 = dense_of_mat x1 |> transpose
|
||||||
|
and m4_s = sparse_of_mat x1 |> transpose
|
||||||
|
and m5 = dense_of_mat x2 |> transpose
|
||||||
|
and m5_s = sparse_of_mat x2 |> transpose
|
||||||
in
|
in
|
||||||
let norm_diff m1 m2 =
|
let norm_diff m1 m2 =
|
||||||
(Mat.sub (to_mat m1) (to_mat m2)
|
(Mat.sub (to_mat m1) (to_mat m2)
|
||||||
|> Mat.syrk_trace)
|
|> Mat.syrk_trace)
|
||||||
in
|
in
|
||||||
Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm m1 m2) m3);
|
Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm m1 m2) m3);
|
||||||
Alcotest.(check (float 1.e-10)) "dense sparse 2" 0. (norm_diff (mm m1 m2_s) m3_s);
|
Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (mm ~transa:`T m4 m2) m3);
|
||||||
Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s);
|
Alcotest.(check (float 1.e-10)) "dense dense 3" 0. (norm_diff (mm ~transb:`T m1 m5) m3);
|
||||||
Alcotest.(check (float 1.e-10)) "sparse dense 4" 0. (norm_diff (mm m1_s m2) m3_s);
|
Alcotest.(check (float 1.e-10)) "dense dense 4" 0. (norm_diff (mm ~transa:`T ~transb:`T m2 m1) (transpose m3));
|
||||||
Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s);
|
|
||||||
Alcotest.(check (float 1.e-10)) "sparse sparse 6" 0. (norm_diff (mm m1_s m2_s) m3_s);
|
Alcotest.(check (float 1.e-10)) "dense sparse 5" 0. (norm_diff (mm m1 m2_s) m3_s);
|
||||||
Alcotest.(check (float 1.e-10)) "sparse sparse 7" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s);
|
Alcotest.(check (float 1.e-10)) "dense sparse 6" 0. (norm_diff (mm ~transa:`T m4 m2_s) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "dense sparse 7" 0. (norm_diff (mm ~transb:`T m1 m5_s) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "dense sparse 8" 0. (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s);
|
||||||
|
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse dense 9" 0. (norm_diff (mm m1_s m2) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse dense 10" 0. (norm_diff (mm ~transa:`T m4_s m2) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse dense 11" 0. (norm_diff (mm ~transb:`T m1_s m5) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse dense 12" 0. (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s);
|
||||||
|
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse sparse 13" 0. (norm_diff (mm m1_s m2_s) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse sparse 14" 0. (norm_diff (mm ~transa:`T m4_s m2_s) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse sparse 15" 0. (norm_diff (mm ~transb:`T m1_s m5_s) m3_s);
|
||||||
|
Alcotest.(check (float 1.e-10)) "sparse sparse 16" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s);
|
||||||
in
|
in
|
||||||
[
|
[
|
||||||
"Conversion", `Quick, test_conversion;
|
"Conversion", `Quick, test_conversion;
|
||||||
|
Loading…
Reference in New Issue
Block a user