From 2f28baf1257de8195593c3d053d155aa63795838 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Wed, 27 Feb 2019 23:54:53 +0100 Subject: [PATCH] Sparse matrix product OK --- Utils/Matrix.ml | 229 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 193 insertions(+), 36 deletions(-) diff --git a/Utils/Matrix.ml b/Utils/Matrix.ml index 656bad2..7f3aeaa 100644 --- a/Utils/Matrix.ml +++ b/Utils/Matrix.ml @@ -101,46 +101,156 @@ let transpose = function end -(* let outer_product ?(threshold=epsilon) v1 v2 = - let v = - Array.init (Vector.dim v1) (fun i -> - Vector.scale (Vector.get v1 i) - *) - + match Vector.(is_dense v1, is_dense v2) with + | (true, true) -> + let v1 = Vector.to_vec v1 + and v2 = Vector.to_vec v2 + in + let a = Mat.create (Vec.dim v1) (Vec.dim v2) in + ger v1 v2 a; + Dense a + | (true, false) -> + let v = Vector.to_vec v1 + and v' = Vector.to_vec v2 + in + let v = + Array.init (Vector.dim v2) (fun j -> + Vec.map (fun x -> x *. v'.{j+1}) v + |> Vector.sparse_of_vec) + in + Sparse {m=Vector.dim v1 ; n=Vector.dim v2 ; v} + | (false, true) + | (false, false) -> + let v = Vector.to_assoc_list v1 + and v' = Vector.to_vec v2 + in + let v = + Array.init (Vector.dim v2) (fun j -> + List.map (fun (i, x) -> + let z = x *. v'.{j+1} in + if abs_float z < threshold then + None + else + Some (i, z) + ) v + |> Util.list_some + |> Vector.sparse_of_assoc_list (Vector.dim v1) + ) + in + Sparse {m=Vector.dim v1 ; n=Vector.dim v2 ; v} -(* -let mm ?(threshold=epsilon) a b = - if dim2 a <> dim1 b then - invalid_arg "Inconsistent dimensions"; +let mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b = - let spmm {m ; n ; v} b = - let n = Mat.dim2 b in + let f, f' = + match transa, transb with + | `N, `N -> dim2, dim1 + | `T, `N -> dim1, dim1 + | `T, `T -> dim1, dim2 + | `N, `T -> dim2, dim2 + in + if f a <> f' b then + invalid_arg "Inconsistent dimensions"; + + let mmsp transa transb a b = + let a = + match transa with + | `N -> Mat.transpose_copy a + | `T -> a + in + let m' = Mat.dim2 a in + let a = + Mat.to_col_vecs a + |> Array.map (fun v -> Vector.dense_of_vec v) + in + let {m ; n ; v} = + if transb = `T then + match transpose (Sparse b) with + | Sparse x -> x + | _ -> assert false + else + b + in + let v' = + Array.map (fun b_j -> + Array.map (fun a_i -> + Vector.dot a_i b_j) a + |> Vec.of_array + |> Vector.sparse_of_vec + ) v + in + Sparse {m=m' ; n ; v=v'} + in + + let spmm transa transb a b = let b = + match transb with + | `N -> b + | `T -> Mat.transpose_copy b + in + let n' = Mat.dim2 b in + let b = Mat.to_col_vecs b |> Array.map (fun v -> Vector.dense_of_vec v) in - let v' = - Array.map (fun a_i -> - Vec.init n (fun j -> - Vector.dot a_i b.(j-1)) - |> Vector.sparse_of_vec ~threshold - ) v + let m, n, v = + if transa = `N then + match transpose (Sparse a) with + | Sparse {m ; n ; v} -> n, m, v + | _ -> assert false + else + match Sparse a with + | Sparse {m ; n ; v} -> n, m, v + | _ -> assert false in - Sparse {m ; n ; v=v'} + let v' = + Array.map (fun b_j -> + Array.map (fun a_i -> + Vector.dot a_i b_j) v + |> Vec.of_array + |> Vector.sparse_of_vec + ) b + in + Sparse {m ; n=n' ; v=v'} in - let mmsp a {m ; n ; v} = - + let mmspmm transa transb a b = + let {m ; n ; v} = + if transb = `T then + match transpose (Sparse b) with + | Sparse x -> x + | _ -> assert false + else + b + in + let m', n', v' = + if transa = `N then + match transpose (Sparse a) with + | Sparse {m ; n ; v} -> n, m, v + | _ -> assert false + else + match Sparse a with + | Sparse {m ; n ; v} -> n, m, v + | _ -> assert false + in + let v'' = + Array.map (fun b_j -> + Array.map (fun a_i -> + Vector.dot a_i b_j) v' + |> Vec.of_array + |> Vector.sparse_of_vec + ) v + in + Sparse {m=m' ; n=n ; v=v''} + in match a, b with - | (Dense a), (Dense b) -> Dense (gemm a b) - | (Sparse a), (Dense b) -> spmm a b - | (Dense a), (Sparse b) -> mmsp a b - | (Sparse a), (Sparse b) -> mmspmm a b - *) + | (Dense a), (Dense b) -> Dense (gemm ~transa ~transb a b) + | (Sparse a), (Dense b) -> spmm transa transb a b + | (Dense a), (Sparse b) -> mmsp transa transb a b + | (Sparse a), (Sparse b) -> mmspmm transa transb a b let rec pp_matrix ppf = function @@ -150,8 +260,13 @@ let rec pp_matrix ppf = function let test_case () = - let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random 3 4) - and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random 4 5) + let d1 = 3 + and d2 = 4 + and d3 = 5 + in + + let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random d1 d2) + and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d2 d3) in let m1 = dense_of_mat x1 @@ -163,14 +278,14 @@ let test_case () = in let test_dimensions () = - Alcotest.(check int) "dim1 1" 3 (dim1 m1 ); - Alcotest.(check int) "dim1 2" 3 (dim1 m1_s); - Alcotest.(check int) "dim2 3" 4 (dim2 m1 ); - Alcotest.(check int) "dim2 4" 4 (dim2 m1_s); - Alcotest.(check int) "dim1 5" 4 (dim1 m2 ); - Alcotest.(check int) "dim1 6" 4 (dim1 m2_s); - Alcotest.(check int) "dim2 7" 5 (dim2 m2 ); - Alcotest.(check int) "dim2 8" 5 (dim2 m2_s); + Alcotest.(check int) "dim1 1" d1 (dim1 m1 ); + Alcotest.(check int) "dim1 2" d1 (dim1 m1_s); + Alcotest.(check int) "dim2 3" d2 (dim2 m1 ); + Alcotest.(check int) "dim2 4" d2 (dim2 m1_s); + Alcotest.(check int) "dim1 5" d2 (dim1 m2 ); + Alcotest.(check int) "dim1 6" d2 (dim1 m2_s); + Alcotest.(check int) "dim2 7" d3 (dim2 m2 ); + Alcotest.(check int) "dim2 8" d3 (dim2 m2_s); in let test_conversion () = @@ -189,9 +304,51 @@ let test_case () = Alcotest.(check bool) "sparse 1" true (transpose m1_s = sparse_of_dense m1t); Alcotest.(check bool) "sparse 2" true (transpose m2_s = sparse_of_dense m2t); in + + let test_outer () = + let x1 = Vec.init 3 (fun i -> float_of_int i) + and x2 = Vec.init 4 (fun i -> float_of_int i -. 3.) + in + let v1 = Vector.dense_of_vec x1 + and v2 = Vector.dense_of_vec x2 + in + let v1_s = Vector.sparse_of_vec x1 + and v2_s = Vector.sparse_of_vec x2 + in + let m1 = + Dense (Mat.init_cols 3 4 (fun i j -> (float_of_int i) *. (float_of_int j -. 3.))) + in + let m1_s = + sparse_of_dense m1 + in + Alcotest.(check bool) "dense dense " true (m1 = outer_product v1 v2); + Alcotest.(check bool) "sparse dense " true (m1_s = outer_product v1_s v2); + Alcotest.(check bool) "dense sparse" true (m1_s = outer_product v1 v2_s); + Alcotest.(check bool) "sparse sparse" true (m1_s = outer_product v1_s v2_s); + in + + let test_mm () = + let x3 = gemm x1 x2 in + let m3 = dense_of_mat x3 + and m3_s = sparse_of_mat x3 + in + let norm_diff m1 m2 = + (Mat.sub (to_mat m1) (to_mat m2) + |> Mat.syrk_trace) + in + Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm m1 m2) m3); + Alcotest.(check (float 1.e-10)) "dense sparse 2" 0. (norm_diff (mm m1 m2_s) m3_s); + Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s); + Alcotest.(check (float 1.e-10)) "sparse dense 4" 0. (norm_diff (mm m1_s m2) m3_s); + Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s); + Alcotest.(check (float 1.e-10)) "sparse sparse 6" 0. (norm_diff (mm m1_s m2_s) m3_s); + Alcotest.(check (float 1.e-10)) "sparse sparse 7" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s); + in [ "Conversion", `Quick, test_conversion; "Dimensions", `Quick, test_dimensions; "Transposition", `Quick, test_transpose; + "Oueter product", `Quick, test_outer; + "Matrix Matrix", `Quick, test_mm; ]