open Lacaml.D

(** A sparse matrix stored column by column: [v.(j-1)] holds column [j]
    (indices are 1-based, as in Lacaml). The matrix has [m] rows and
    [n] columns, so every vector in [v] has dimension [m]. *)
type sparse_matrix =
  {
    m: int;
    n: int;
    v: Vector.t array;
  }

(** A matrix in either dense (Lacaml) or sparse (column-wise) storage. *)
type t =
  | Dense  of Mat.t
  | Sparse of sparse_matrix

(** Default [threshold] passed to [Vector.sparse_of_vec] when converting
    to sparse storage (presumably entries smaller than it are dropped —
    see [Vector]). *)
let epsilon = Constants.epsilon

let is_sparse = function
  | Sparse _ -> true
  | Dense  _ -> false

let is_dense = function
  | Sparse _ -> false
  | Dense  _ -> true

(** Number of rows. *)
let dim1 = function
  | Dense m -> Mat.dim1 m
  | Sparse { m ; _ } -> m

(** Number of columns. *)
let dim2 = function
  | Dense m -> Mat.dim2 m
  | Sparse { n ; _ } -> n

(** [get a i j] returns the element at row [i], column [j] (1-based).
    Sparse storage is column-wise, so column [j] is the vector
    [v.(j-1)], indexed by the row [i]. *)
let get = function
  | Dense m -> (fun i j -> m.{i,j})
  | Sparse { v ; _ } -> (fun i j -> Vector.get v.(j-1) i)

(** Converts a dense matrix to sparse storage, one column at a time.
    @raise Invalid_argument if the argument is already sparse. *)
let sparse_of_dense ?(threshold=epsilon) = function
  | Sparse _ -> invalid_arg "Expected a dense matrix"
  | Dense m' ->
    let m = Mat.dim1 m'
    and n = Mat.dim2 m'
    and v =
      Mat.to_col_vecs m'
      |> Array.map (fun v -> Vector.sparse_of_vec ~threshold v)
    in
    Sparse { m ; n ; v }

(** Converts a sparse matrix to dense storage.
    @raise Invalid_argument if the argument is already dense. *)
let dense_of_sparse = function
  | Dense _ -> invalid_arg "Expected a sparse matrix"
  | Sparse { v ; _ } ->
    let m' =
      Array.map (fun v -> Vector.to_vec v) v
      |> Mat.of_col_vecs
    in
    Dense m'

(** Wraps a Lacaml matrix as a dense [t] (the matrix is not copied). *)
let dense_of_mat m = Dense m

(** Converts a Lacaml matrix directly to sparse storage. *)
let sparse_of_mat ?(threshold=epsilon) m =
  dense_of_mat m
  |> sparse_of_dense ~threshold

(** Builds a sparse matrix whose columns are the vectors of [v], stored
    as given (they are not converted to sparse vectors here).
    @raise Invalid_argument if [v] is empty or if the vectors do not all
    have the same dimension. *)
let sparse_of_vector_array v =
  (* Guard explicitly: the dimension is read from v.(0) below. *)
  if Array.length v = 0 then
    invalid_arg "Empty vector array";
  let m =
    Array.fold_left (fun accu v' ->
        if Vector.dim v' <> accu then
          invalid_arg "Inconsistent dimension"
        else
          accu)
      (Vector.dim v.(0)) v
  and n = Array.length v
  in
  Sparse { m ; n ; v }

(** Returns the underlying Lacaml matrix; sparse matrices are densified. *)
let rec to_mat = function
  | Dense m -> m
  | Sparse m -> dense_of_sparse (Sparse m) |> to_mat

(** Returns the transpose, in the same storage kind as the input.
    For sparse storage, row [i] of the input becomes column [i] of the
    result. *)
let transpose = function
  | Dense m -> Dense (Mat.transpose_copy m)
  | Sparse { m ; n ; v } ->
    (* cols.(i-1) accumulates the (column index, value) pairs of row i of
       the input. Columns are visited in increasing order and pairs are
       prepended, so each list is in decreasing order until reversed. *)
    let cols = Array.make m [] in
    Array.iteri (fun j v_j ->
        Vector.to_assoc_list v_j
        |> List.iter (fun (i, v_ij) ->
            cols.(i-1) <- (j+1, v_ij) :: cols.(i-1)))
      v;
    let v' =
      Array.map (fun l -> Vector.sparse_of_assoc_list n (List.rev l)) cols
    in
    Sparse { m = n ; n = m ; v = v' }

(* NOTE: the two drafts below are incomplete and kept commented out. *)

(*
let outer_product ?(threshold=epsilon) v1 v2 =
  let v = Array.init (Vector.dim v1) (fun i -> Vector.scale (Vector.get v1 i)
*)

(*
let mm ?(threshold=epsilon) a b =
  if dim2 a <> dim1 b then
    invalid_arg "Inconsistent dimensions";

  let spmm {m ; n ; v} b =
    let n = Mat.dim2 b in
    let b = Mat.to_col_vecs b |> Array.map (fun v -> Vector.dense_of_vec v) in
    let v' =
      Array.map (fun a_i ->
        Vec.init n (fun j -> Vector.dot a_i b.(j-1))
        |> Vector.sparse_of_vec ~threshold
      ) v
    in
    Sparse {m ; n ; v=v'}
  in

  let mmsp a {m ; n ; v} =

  match a, b with
  | (Dense a), (Dense b) -> Dense (gemm a b)
  | (Sparse a), (Dense b) -> spmm a b
  | (Dense a), (Sparse b) -> mmsp a b
  | (Sparse a), (Sparse b) -> mmspmm a b
*)

(** Pretty-prints the matrix; sparse matrices are densified first. *)
let rec pp_matrix ppf = function
  | Dense m -> Util.pp_matrix ppf m
  | Sparse m -> pp_matrix ppf @@ dense_of_sparse (Sparse m)

(** Alcotest cases comparing dense and sparse storage on random matrices
    with some entries forced to zero. *)
let test_case () =

  let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random 3 4)
  and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random 4 5)
  in

  let m1 = dense_of_mat x1
  and m2 = dense_of_mat x2
  in

  let m1_s = sparse_of_mat x1
  and m2_s = sparse_of_mat x2
  in

  let test_dimensions () =
    Alcotest.(check int) "dim1 1" 3 (dim1 m1  );
    Alcotest.(check int) "dim1 2" 3 (dim1 m1_s);
    Alcotest.(check int) "dim2 3" 4 (dim2 m1  );
    Alcotest.(check int) "dim2 4" 4 (dim2 m1_s);
    Alcotest.(check int) "dim1 5" 4 (dim1 m2  );
    Alcotest.(check int) "dim1 6" 4 (dim1 m2_s);
    Alcotest.(check int) "dim2 7" 5 (dim2 m2  );
    Alcotest.(check int) "dim2 8" 5 (dim2 m2_s)
  in

  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse m1_s = m1  );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse m2_s = m2  );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense m1   = m1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense m2   = m2_s)
  in

  let test_transpose () =
    let m1t = Mat.transpose_copy x1 |> dense_of_mat
    and m2t = Mat.transpose_copy x2 |> dense_of_mat
    in
    Alcotest.(check bool) "dense 1" true (transpose m1 = m1t);
    Alcotest.(check bool) "dense 2" true (transpose m2 = m2t);
    Alcotest.(check bool) "sparse 1" true (transpose m1_s = sparse_of_dense m1t);
    Alcotest.(check bool) "sparse 2" true (transpose m2_s = sparse_of_dense m2t)
  in

  let test_get () =
    (* Element access must agree with the Lacaml matrix for both
       storages, at every (row, column) position. *)
    let check_all name x a =
      for i = 1 to Mat.dim1 x do
        for j = 1 to Mat.dim2 x do
          Alcotest.(check (float 0.))
            (Printf.sprintf "%s (%d,%d)" name i j)
            x.{i,j} (get a i j)
        done
      done
    in
    check_all "dense 1"  x1 m1;
    check_all "sparse 1" x1 m1_s;
    check_all "dense 2"  x2 m2;
    check_all "sparse 2" x2 m2_s
  in

  [
    "Conversion", `Quick, test_conversion;
    "Dimensions", `Quick, test_dimensions;
    "Transposition", `Quick, test_transpose;
    "Get", `Quick, test_get;
  ]