2019-02-27 21:28:56 +01:00
|
|
|
open Lacaml.D
|
|
|
|
|
|
|
|
(** Column-oriented sparse matrix.
    Element (i, j) (1-based) is read as [Vector.get v.(j-1) i], so each entry
    of [v] stores one column. *)
type sparse_matrix =
  {
    m: int;            (* Number of rows (returned by [dim1]). *)
    n: int;            (* Number of columns (returned by [dim2]). *)
    v: Vector.t array; (* Columns: [v.(j-1)] is column [j]; expected to have length [n]. *)
  }
|
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** Lazily computed matrix: elements are produced on demand by [f].
    Indices passed to [f] are 1-based, with [1 <= i <= m] and [1 <= j <= n]
    (see [check_bounds] as used by [get]). *)
type computed =
  {
    m: int;                   (* Number of rows. *)
    n: int;                   (* Number of columns. *)
    f: int -> int -> float;   (* [f i j] returns element (i, j). *)
  }
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** A matrix in one of three storage schemes. *)
type t =
  | Dense of Mat.t             (* Lacaml dense matrix. *)
  | Sparse of sparse_matrix    (* Column-wise sparse storage. *)
  | Computed of computed       (* Elements generated on demand by a function. *)
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
(** Default magnitude threshold used by the conversions to sparse storage
    in this module (the default of their [?threshold] argument). *)
let epsilon = Constants.epsilon
|
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** [is_computed a] is [true] iff [a] is stored as a {!computed} matrix. *)
let is_computed a =
  match a with
  | Computed _ -> true
  | Dense _ | Sparse _ -> false
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** [is_sparse a] is [true] iff [a] uses column-wise sparse storage. *)
let is_sparse a =
  match a with
  | Sparse _ -> true
  | Dense _ | Computed _ -> false
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
(** [is_dense a] is [true] iff [a] wraps a Lacaml dense matrix. *)
let is_dense a =
  match a with
  | Dense _ -> true
  | Sparse _ | Computed _ -> false
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** Number of rows of a matrix, whatever its storage. *)
let dim1 a =
  match a with
  | Dense mat -> Mat.dim1 mat
  | Sparse s -> s.m
  | Computed c -> c.m
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** Number of columns of a matrix, whatever its storage. *)
let dim2 a =
  match a with
  | Dense mat -> Mat.dim2 mat
  | Sparse s -> s.n
  | Computed c -> c.n
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** [check_bounds m n i j] checks that the 1-based index pair [(i, j)] lies
    inside an [m] x [n] matrix.
    @raise Invalid_argument if [i] is not in [1..m] or [j] not in [1..n]. *)
let check_bounds m n i j =
  let within upper x = 1 <= x && x <= upper in
  if not (within m i && within n j) then
    invalid_arg "Index out of bounds"
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
(** [get a] returns an accessor [fun i j -> a(i,j)] using 1-based indices.
    Sparse and computed matrices validate the indices with [check_bounds]
    (dense matrices rely on Bigarray's own bounds checking).
    @raise Invalid_argument when the returned accessor is called with
    an index outside the matrix. *)
let get = function
  | Dense m -> (fun i j -> m.{i,j})
  | Sparse { m ; n ; v } ->
    (* Check before indexing: a sparse column may otherwise answer for a row
       index beyond its dimension, or [v.(j-1)] fails with a generic error. *)
    (fun i j -> check_bounds m n i j ; Vector.get v.(j-1) i)
  | Computed { m ; n ; f } -> (fun i j -> check_bounds m n i j ; f i j)
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** Converts a [Dense] matrix to column-wise sparse storage, converting each
    column with [Vector.sparse_of_vec ~threshold].
    @raise Invalid_argument if the argument is not dense. *)
let sparse_of_dense ?(threshold=epsilon) = function
  | Dense mat ->
    let columns =
      Array.map (fun col -> Vector.sparse_of_vec ~threshold col)
        (Mat.to_col_vecs mat)
    in
    Sparse { m = Mat.dim1 mat ; n = Mat.dim2 mat ; v = columns }
  | Sparse _ | Computed _ -> invalid_arg "Expected a dense matrix"
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** Converts a [Sparse] matrix to a [Dense] one by densifying every column.
    @raise Invalid_argument if the argument is not sparse. *)
let dense_of_sparse = function
  | Sparse { v ; _ } ->
    let dense_columns = Array.map (fun col -> Vector.to_vec col) v in
    Dense (Mat.of_col_vecs dense_columns)
  | Dense _ | Computed _ -> invalid_arg "Expected a sparse matrix"
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** Evaluates every element of a [Computed] matrix, column by column, and
    keeps only the entries with [abs_float x > threshold].
    @raise Invalid_argument if the argument is not a computed matrix. *)
let sparse_of_computed ?(threshold=epsilon) = function
  | Computed { m ; n ; f } ->
    (* [column col] builds column [col] (1-based) as a sparse vector. *)
    let column col =
      Util.list_range 1 m
      |> List.map (fun row ->
          let x = f row col in
          if abs_float x > threshold then Some (row, x) else None)
      |> Util.list_some
      |> Vector.sparse_of_assoc_list m
    in
    Sparse { m ; n ; v = Array.init n (fun j -> column (j + 1)) }
  | Dense _ | Sparse _ -> invalid_arg "Expected a computed matrix"
|
|
|
|
|
|
|
|
(** Materializes a [Computed] matrix as [Dense], going through the sparse
    representation (entries below the default threshold become 0). *)
let dense_of_computed x =
  let as_sparse = sparse_of_computed x in
  dense_of_sparse as_sparse
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** Wraps a Lacaml matrix (without copying it). *)
let dense_of_mat mat = Dense mat
|
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** [of_fun m n f] is the [m] x [n] computed matrix whose element (i, j)
    (1-based) is [f i j]. *)
let of_fun m n f = Computed { f ; n ; m }
|
2019-03-21 16:32:41 +01:00
|
|
|
|
2019-02-28 12:30:20 +01:00
|
|
|
(** Returns the columns of a matrix as an array of [Vector.t].
    Dense and computed matrices are first converted to sparse storage
    using [threshold]. *)
let to_vector_array ?(threshold=epsilon) a =
  let as_sparse =
    match a with
    | Sparse _ -> a
    | Dense _ -> sparse_of_dense ~threshold a
    | Computed _ -> sparse_of_computed ~threshold a
  in
  match as_sparse with
  | Sparse { v ; _ } -> v
  | Dense _ | Computed _ -> assert false
|
2019-02-28 12:30:20 +01:00
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
|
2019-03-21 16:32:41 +01:00
|
|
|
(** [identity n] is the [n] x [n] identity matrix in sparse storage. *)
let identity n =
  let unit_column k = Vector.sparse_of_assoc_list n [ (k + 1, 1.0) ] in
  Sparse { m = n ; n ; v = Array.init n unit_column }
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** Converts a Lacaml matrix directly to sparse storage. *)
let sparse_of_mat ?(threshold=epsilon) m =
  let wrapped = Dense m in
  sparse_of_dense ~threshold wrapped
|
|
|
|
|
|
|
|
|
|
|
|
(** Builds a sparse matrix whose columns are the vectors of [v].
    An empty array yields a 0 x 0 matrix, matching [join_cols []].
    @raise Invalid_argument "Inconsistent dimension" if the vectors do not
    all have the same dimension. *)
let sparse_of_vector_array v =
  let n = Array.length v in
  let m =
    if n = 0 then
      (* Previously [v.(0)] raised an unrelated out-of-bounds error here. *)
      0
    else
      let expected = Vector.dim v.(0) in
      Array.iter (fun v' ->
          if Vector.dim v' <> expected then
            invalid_arg "Inconsistent dimension") v;
      expected
  in
  Sparse { m ; n ; v }
|
|
|
|
|
|
|
|
|
|
|
|
(** Returns the Lacaml dense matrix for any storage: dense matrices are
    returned as-is, sparse and computed ones are densified. *)
let to_mat a =
  let unwrap = function
    | Dense mat -> mat
    | Sparse _ | Computed _ -> assert false
  in
  match a with
  | Dense mat -> mat
  | Sparse _ -> unwrap (dense_of_sparse a)
  | Computed _ -> unwrap (dense_of_sparse (sparse_of_computed a))
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
(** Returns the transpose of a matrix, keeping its storage kind.
    Dense: copied with [Mat.transpose_copy]. Sparse: the column lists are
    redistributed into rows. Computed: the arguments of [f] are swapped. *)
let transpose = function
  | Dense m -> Dense (Mat.transpose_copy m)
  | Sparse { m ; n ; v } ->
    (* rows.(i-1) accumulates the entries of row i, most recent first. *)
    let rows = Array.make m [] in
    Array.iteri (fun j column ->
        List.iter (fun (i, x) -> rows.(i-1) <- (j + 1, x) :: rows.(i-1))
          (Vector.to_assoc_list column)) v;
    let columns =
      Array.map (fun entries -> Vector.sparse_of_assoc_list n (List.rev entries)) rows
    in
    Sparse { m = n ; n = m ; v = columns }
  | Computed { m ; n ; f } ->
    Computed { m = n ; n = m ; f = (fun i j -> f j i) }
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** [outer_product v1 v2] is the matrix [v1 * v2^T].
    The result is dense when both vectors are dense, sparse otherwise.
    Entries of a sparse result whose magnitude is below [threshold] are
    dropped, in every branch. *)
let outer_product ?(threshold=epsilon) v1 v2 =
  match Vector.(is_dense v1, is_dense v2) with
  | (true, true) ->
    let v1 = Vector.to_vec v1
    and v2 = Vector.to_vec v2
    in
    let a = Mat.make0 (Vec.dim v1) (Vec.dim v2) in
    ger v1 v2 a;
    Dense a
  | (true, false) ->
    let x = Vector.to_vec v1
    and y = Vector.to_vec v2
    in
    let v =
      Array.init (Vector.dim v2) (fun j ->
          let y_j = y.{j+1} in
          Vec.map (fun x_i -> x_i *. y_j) x
          (* Pass the caller's threshold: it was previously ignored here. *)
          |> Vector.sparse_of_vec ~threshold)
    in
    Sparse { m = Vector.dim v1 ; n = Vector.dim v2 ; v }
  | (false, true)
  | (false, false) ->
    (* v1 is sparse: iterate over its non-zeros only; v2 is densified. *)
    let x = Vector.to_assoc_list v1
    and y = Vector.to_vec v2
    in
    let v =
      Array.init (Vector.dim v2) (fun j ->
          let y_j = y.{j+1} in
          List.map (fun (i, x_i) ->
              let z = x_i *. y_j in
              if abs_float z < threshold then
                None
              else
                Some (i, z)
            ) x
          |> Util.list_some
          |> Vector.sparse_of_assoc_list (Vector.dim v1))
    in
    Sparse { m = Vector.dim v1 ; n = Vector.dim v2 ; v }
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
2019-04-03 18:09:13 +02:00
|
|
|
(** [mm ?transa ?transb ?threshold a b] computes [op(a) * op(b)], where
    [op(x)] is [x] or its transpose according to [transa]/[transb].

    Result storage:
    - dense x dense: dense (Lacaml [gemm]);
    - any product involving sparse operands: sparse, with entries below
      [threshold] dropped;
    - computed x computed: computed (lazy);
    - computed x dense: dense;
    - non-computed x computed: obtained through [(a b)^T = b^T a^T].

    @raise Invalid_argument if the inner dimensions of [op(a)] and [op(b)]
    differ. *)
let rec mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =

  (* Inner dimensions of op(a) and op(b) must agree. *)
  let f, f' =
    match transa, transb with
    | `N, `N -> dim2, dim1
    | `T, `N -> dim1, dim1
    | `T, `T -> dim1, dim2
    | `N, `T -> dim2, dim2
  in
  if f a <> f' b then
    Printf.sprintf "%d %d : Inconsistent dimensions" (f a) (f' b)
    |> invalid_arg;

  (* Sparse result columns are built with the caller's [threshold] so the
     optional argument is honored on every sparse code path. *)
  let sparse_col x = Vector.sparse_of_vec ~threshold x in

  (* [rows_of_op trans a] returns the number of rows of op(a) and the rows
     of op(a), each stored as a Vector.t. *)
  let rows_of_op trans (a : sparse_matrix) =
    if trans = `N then
      match transpose (Sparse a) with
      | Sparse t -> a.m, t.v
      | _ -> assert false
    else
      a.n, a.v
  in

  (* [cols_of_op trans b] returns op(b) in sparse storage. *)
  let cols_of_op trans (b : sparse_matrix) : sparse_matrix =
    if trans = `T then
      match transpose (Sparse b) with
      | Sparse t -> t
      | _ -> assert false
    else
      b
  in

  (* [computed_of_op trans c] returns op(c) as a computed matrix. *)
  let computed_of_op trans (c : computed) : computed =
    if trans = `T then
      match transpose (Computed c) with
      | Computed t -> t
      | _ -> assert false
    else
      c
  in

  (* Column j of the product: dot product of every row of op(a) with
     column j of op(b). *)
  let product rows cols =
    Array.map (fun b_j ->
        Array.map (fun a_i -> Vector.dot a_i b_j) rows
        |> Vec.of_array
        |> sparse_col) cols
  in

  (* Dense x sparse *)
  let mmsp transa transb a b =
    (* The columns of the (possibly transposed) copy are the rows of op(a). *)
    let a =
      match transa with
      | `N -> Mat.transpose_copy a
      | `T -> a
    in
    let m' = Mat.dim2 a in
    let rows_a = Mat.to_col_vecs a |> Array.map (fun v -> Vector.dense_of_vec v) in
    let op_b = cols_of_op transb b in
    Sparse { m = m' ; n = op_b.n ; v = product rows_a op_b.v }
  in

  (* Sparse x dense *)
  let spmm transa transb a b =
    let b =
      match transb with
      | `N -> b
      | `T -> Mat.transpose_copy b
    in
    let n' = Mat.dim2 b in
    let cols_b = Mat.to_col_vecs b |> Array.map (fun v -> Vector.dense_of_vec v) in
    let m, rows_a = rows_of_op transa a in
    Sparse { m ; n = n' ; v = product rows_a cols_b }
  in

  (* Sparse x sparse *)
  let mmspmm transa transb a b =
    let op_b = cols_of_op transb b in
    let m', rows_a = rows_of_op transa a in
    Sparse { m = m' ; n = op_b.n ; v = product rows_a op_b.v }
  in

  (* Computed x computed: the result stays lazy. *)
  let mmcc transa transb a b =
    let op_b = computed_of_op transb b in
    let op_a = computed_of_op transa a in
    if op_a.n <> op_b.m then
      invalid_arg "Inconsistent dimensions";
    let g i j =
      let result = ref 0. in
      for k = 1 to op_b.m do
        let b_kj = op_b.f k j in
        (* Skip the call to op_a.f when the factor is exactly zero. *)
        if b_kj <> 0. then
          result := !result +. (op_a.f i k) *. b_kj
      done;
      !result
    in
    Computed { m = op_a.m ; n = op_b.n ; f = g }
  in

  (* Computed x dense: the result is dense. *)
  let mmccde transa transb a b =
    let op_a = computed_of_op transa a in
    let m, n =
      match transb with
      | `N -> Mat.dim1 b , Mat.dim2 b
      | `T -> Mat.dim2 b , Mat.dim1 b
    in
    if op_a.n <> m then
      invalid_arg "Inconsistent dimensions";
    let rows = op_a.m in
    (* [column j] is column j+1 of the product, built as a linear
       combination of the columns of op(a). *)
    let column j =
      let bj =
        if transb = `T then
          Mat.copy_row b (j+1)
        else
          (* Read only column j+1; calling Mat.to_col_vecs on the whole
             matrix for every j made this loop quadratic. *)
          Mat.col b (j+1)
      in
      let accu = Vec.make0 rows in
      let v = Vec.make0 rows in
      Array.iteri (fun k b_kj ->
          if b_kj <> 0. then
            begin
              for i = 1 to rows do
                Bigarray.Array1.unsafe_set v i (op_a.f i (k+1))
              done;
              axpy ~alpha:b_kj v accu
            end
        ) (Vec.to_array bj);
      accu
    in
    Dense (Array.init n column |> Mat.of_col_vecs)
  in

  match a, b with
  | (Dense a), (Dense b) -> Dense (gemm ~transa ~transb a b)
  | (Sparse a), (Dense b) -> spmm transa transb a b
  | (Dense a), (Sparse b) -> mmsp transa transb a b
  | (Sparse a), (Sparse b) -> mmspmm transa transb a b
  | (Computed a), (Computed b) -> mmcc transa transb a b
  | (Computed a), (Dense b) -> mmccde transa transb a b
  | (Computed a), (Sparse _) ->
    (* View the sparse operand as a computed one, then materialize. *)
    let b = { m = dim1 b ; n = dim2 b ; f = get b } in
    mmcc transa transb a b
    |> sparse_of_computed ~threshold
  | _, (Computed _) ->
    (* op(a) op(b) = (op(b)^T op(a)^T)^T, which puts the computed operand
       on the left where it is handled above. *)
    begin
      match transa, transb with
      | `N, `N -> mm ~transa:`T ~transb:`T ~threshold b a
      | `N, `T -> mm ~transa:`N ~transb:`T ~threshold b a
      | `T, `N -> mm ~transa:`T ~transb:`N ~threshold b a
      | `T, `T -> mm ~transa:`N ~transb:`N ~threshold b a
    end |> transpose
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
2019-02-28 12:08:28 +01:00
|
|
|
(** [mv ?sparse ?trans ?threshold a b] computes [op(a) * b], where [op(a)] is
    [a] or its transpose according to [trans].
    The result is returned as a dense [Vector.t] unless [sparse] is [true];
    in that case entries below [threshold] are dropped.
    @raise Invalid_argument if the dimensions of [op(a)] and [b] mismatch. *)
let mv ?(sparse=false) ?(trans=`N) ?(threshold=epsilon) a b =

  let f =
    match trans with
    | `N -> dim2
    | `T -> dim1
  in
  if f a <> Vector.dim b then
    invalid_arg "Inconsistent dimensions";

  (* Sparse matrix: dot product of every row of op(a) with b. *)
  let spmv (a : sparse_matrix) b =
    let rows =
      if trans = `N then
        match transpose (Sparse a) with
        | Sparse x -> x.v
        | _ -> assert false
      else
        a.v
    in
    Array.map (fun row_a -> Vector.dot row_a b) rows
    |> Vec.of_array
  in

  (* Dense matrix times non-dense vector. The result has one entry per row
     of op(a): Mat.dim1 a for `N, but Mat.dim2 a for `T (previously dim1
     was used for both, which was wrong for non-square matrices). *)
  let dense_mv a b =
    let len, row_of_op =
      match trans with
      | `N -> Mat.dim1 a, (fun i -> Mat.copy_row a i)
      | `T -> Mat.dim2 a, (fun i -> Mat.col a i)
    in
    Vec.init len (fun i ->
        Vector.dense_of_vec (row_of_op i)
        |> Vector.dot b)
  in

  (* Computed matrix: explicit sums over the generating function. For `T
     the output has a.n entries and each sums over the a.m rows. *)
  let cmv (a : computed) b =
    match trans with
    | `N -> Vec.init a.m (fun i ->
        let accu = ref 0. in
        for j = 1 to a.n do
          accu := !accu +. a.f i j *. Vector.get b j
        done;
        !accu)
    | `T -> Vec.init a.n (fun i ->
        let accu = ref 0. in
        for j = 1 to a.m do
          accu := !accu +. a.f j i *. Vector.get b j
        done;
        !accu)
  in

  let dense_result =
    match a, Vector.is_dense b with
    | Dense a, true -> gemv ~trans a (Vector.to_vec b)
    | Dense a, false -> dense_mv a b
    | Sparse a, _ -> spmv a b
    | Computed a, _ -> cmv a b
  in

  if sparse then
    Vector.sparse_of_vec ~threshold dense_result
  else
    Vector.dense_of_vec dense_result
|
|
|
|
|
2019-03-21 16:32:41 +01:00
|
|
|
|
|
|
|
(** [op2 dense_op sparse_op a b] applies an element-wise binary operation.
    [dense_op] is used when both operands are dense; otherwise operands are
    brought to sparse storage and [sparse_op] is applied column by column.
    Computed operands are materialized with [sparse_of_computed].
    @raise Failure "Inconsistent dimensions" if the shapes differ. *)
let rec op2 dense_op sparse_op a b =
  if dim1 a <> dim1 b || dim2 a <> dim2 b then
    failwith "Inconsistent dimensions";

  match a, b with
  | (Dense a), (Dense b) -> Dense (dense_op a b)
  | (Dense _), (Sparse _) -> op2 dense_op sparse_op (sparse_of_dense a) b
  | (Sparse _), (Dense _) -> op2 dense_op sparse_op a (sparse_of_dense b)
  | (Sparse a), (Sparse b) ->
    Sparse
      { m = a.m ; n = a.n ;
        v = Array.map2 sparse_op a.v b.v
      }
  (* Computed operands used to fail with "Not implemented". *)
  | (Computed _), _ -> op2 dense_op sparse_op (sparse_of_computed a) b
  | _, (Computed _) -> op2 dense_op sparse_op a (sparse_of_computed b)
|
2019-03-21 16:32:41 +01:00
|
|
|
|
|
|
|
(** Element-wise sum of two matrices (see [op2] for storage rules). *)
let add a b =
  let dense_add x y = Mat.add x y
  and sparse_add x y = Vector.add x y in
  op2 dense_add sparse_add a b

(** Element-wise difference [a - b] (see [op2] for storage rules). *)
let sub a b =
  let dense_sub x y = Mat.sub x y
  and sparse_sub x y = Vector.sub x y in
  op2 dense_sub sparse_sub a b
|
|
|
|
|
|
|
|
(** [scale f a] multiplies every element of [a] by [f], keeping its storage.
    The input is never mutated: dense matrices are copied before scaling,
    and a computed matrix gets a new generating function. *)
let scale f = function
  | Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
  | Sparse a -> Sparse
                  { a with
                    v = if f = 1.0 then a.v
                      else Array.map (fun v -> Vector.scale f v) a.v }
  | Computed a ->
    (* Previously "Not implemented": wrap the generator instead. *)
    if f = 1.0 then Computed a
    else
      let factor = f and g = a.f in
      Computed { a with f = (fun i j -> factor *. g i j) }
|
2019-03-21 16:32:41 +01:00
|
|
|
|
|
|
|
(** Frobenius norm: square root of the sum of the squared elements.
    Computed matrices are evaluated element by element. *)
let frobenius_norm = function
  | Dense a -> lange ~norm:`F a
  | Sparse a ->
    Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
    |> sqrt
  | Computed { m ; n ; f } ->
    (* Previously "Not implemented". *)
    let accu = ref 0. in
    for j = 1 to n do
      for i = 1 to m do
        let x = f i j in
        accu := !accu +. x *. x
      done
    done;
    sqrt !accu
|
2019-04-02 13:54:16 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(** [split_cols nrows a] splits [a] into consecutive blocks of columns
    (despite its name, [nrows] is the number of columns per block, grouped
    with [Util.list_pack]). Each block keeps the storage kind of [a];
    computed blocks reindex into the original generator lazily. *)
let split_cols nrows a =
  let chunks l = Util.list_pack nrows l |> List.map Array.of_list in
  match a with
  | Dense d ->
    Mat.to_col_vecs d
    |> Array.to_list
    |> chunks
    |> List.map (fun cols -> Dense (Mat.of_col_vecs cols))
  | Sparse s ->
    Array.to_list s.v
    |> chunks
    |> List.map (fun cols -> Sparse { m = s.m ; n = Array.length cols ; v = cols })
  | Computed c ->
    (* Each chunk holds the 0-based indices of its columns; column j of a
       chunk (1-based) is original column j + idx.(0). *)
    Util.list_range 0 (c.n - 1)
    |> chunks
    |> List.map (fun idx ->
        Computed { m = c.m ; n = Array.length idx ;
                   f = (fun i j -> c.f i (j + idx.(0))) })
|
2019-04-02 13:54:16 +02:00
|
|
|
|
|
|
|
|
|
|
|
(** Concatenates matrices horizontally (inverse of [split_cols]).
    All elements must be dense, or all sparse, or all computed (computed
    blocks are materialized as sparse); mixing kinds trips an assertion.
    An empty list gives a 0 x 0 sparse matrix. *)
let join_cols l =
  let rec collect_dense accu = function
    | [] -> Dense (Mat.of_col_vecs_list (List.concat accu))
    | Dense a :: rest -> collect_dense (Mat.to_col_vecs_list a :: accu) rest
    | _ -> assert false
  in
  let rec collect_sparse m n accu = function
    | [] -> Sparse { m ; n ; v = Array.of_list (List.concat accu) }
    | Sparse a :: rest -> collect_sparse a.m (n + a.n) (Array.to_list a.v :: accu) rest
    | _ -> assert false
  in
  (* Walk the list backwards: prepending to [accu] restores the order. *)
  match List.rev l with
  | [] -> Sparse { m = 0 ; n = 0 ; v = [| |] }
  | Dense _ :: _ as all -> collect_dense [] all
  | Sparse _ :: _ as all -> collect_sparse 0 0 [] all
  | Computed _ :: _ as all -> collect_sparse 0 0 [] (List.map sparse_of_computed all)
|
|
|
|
|
|
|
|
|
2019-03-21 16:32:41 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(** Solves [a x = b] for one right-hand side with the conjugate gradient
    method. [a] must be symmetric positive definite.
    [x] is the initial guess (defaults to [b]). Stops when the squared
    residual norm drops below [Constants.epsilon] or after [2 * dim b]
    iterations. *)
let ax_eq_b_conj_grad ?x a b =
  let x0 =
    match x with
    | None -> b
    | Some guess -> guess
  in
  let r0 = Vector.sub b (mv a x0) in
  let rec iterate remaining rs_old r p x =
    if remaining = 0 then x
    else begin
      let ap = mv a p in
      let alpha = rs_old /. (Vector.dot p ap) in
      let x = Vector.add x (Vector.scale alpha p) in
      let r = Vector.sub r (Vector.scale alpha ap) in
      let rs_new = Vector.dot r r in
      if rs_new < Constants.epsilon then x
      else begin
        (* The 1.e-12 guards against division by a zero residual. *)
        let beta = rs_new /. (rs_old +. 1.e-12) in
        let p = Vector.add r (Vector.scale beta p) in
        (iterate [@tailcall]) (remaining - 1) rs_new r p x
      end
    end
  in
  iterate (Vector.dim b * 2) (Vector.dot r0 r0) r0 r0 x0
|
|
|
|
|
|
|
|
|
2019-03-21 11:02:58 +01:00
|
|
|
|
|
|
|
(** [ax_eq_b ?trans a b] solves [op(a) x = b] for the matrix [x], where
    [op(a)] is [a] or its transpose according to [trans].
    - dense/dense: LU solve with Lacaml [getrs] on copies of [a] and [b];
    - dense/sparse: [b] is densified first;
    - otherwise: normal equations [op(a)^T op(a) x = op(a)^T b], solved
      column by column with [ax_eq_b_conj_grad]; the result is sparse. *)
let rec ax_eq_b ?(trans=`N) a b =
  match a, b with
  | (Dense a), (Dense b) ->
    let a = lacpy a in
    let x = lacpy b in
    (getrs ~trans a x; Dense x)
  | (Dense _), (Sparse _) ->
    let b = dense_of_sparse b in
    ax_eq_b ~trans a b
  | _ ->
    (* For op(a) = a:   a^T a x = a^T b.
       For op(a) = a^T: a a^T x = a b.
       (The transposed case previously formed a*a, which is not the
       normal matrix.) *)
    let ata, atb =
      if trans = `N then
        mm ~transa:`T a a, mm ~transa:`T a b
      else
        mm ~transb:`T a a, mm a b
    in
    (* x has one row per unknown (rows of atb), not one per row of b. *)
    Sparse { m = dim1 atb ; n = dim2 atb ;
             v = Array.map (fun v -> ax_eq_b_conj_grad ata v) (to_vector_array atb)
           }
|
2019-03-21 11:02:58 +01:00
|
|
|
|
|
|
|
|
2019-04-02 13:54:16 +02:00
|
|
|
(* ------- Parallel routines ---------- *)
|
|
|
|
|
|
|
|
(** Parallel version of [mm]: the columns of [op(b)] are split into blocks
    (about 7 per worker), each block is multiplied by [op(a)] through
    [Farm.run], and the partial results are joined back in order.
    Computed x computed blocks are materialized as sparse so that
    [join_cols] receives homogeneous results. *)
let parallel_mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b =
  let b =
    match transb with
    | `T -> transpose b
    | `N -> b
  in
  (* Block size in columns of op(b), the dimension actually being split.
     Clamped to 1: a zero block size (small matrices or many workers) is
     not a valid argument for split_cols. *)
  let chunk = max 1 (dim2 b / (Parallel.size * 7)) in
  split_cols chunk b
  |> Stream.of_list
  |> Farm.run ~ordered:true ~f:(fun b ->
      match a, b with
      | Computed _, Computed _ ->
        mm ~transa ~threshold a b
        |> sparse_of_computed ~threshold
      | _ ->
        mm ~transa ~threshold a b
    )
  |> Util.stream_to_list
  |> join_cols
|
|
|
|
|
2019-03-21 11:02:58 +01:00
|
|
|
|
|
|
|
(* ------------ Printers ------------ *)
|
2019-02-28 12:08:28 +01:00
|
|
|
|
2019-12-02 14:58:48 +01:00
|
|
|
(** Pretty-prints a matrix with [Util.pp_matrix]; sparse and computed
    matrices are densified first. *)
let rec pp ppf a =
  match a with
  | Dense mat -> Util.pp_matrix ppf mat
  | Sparse _ -> pp ppf (dense_of_sparse a)
  | Computed _ -> pp ppf (dense_of_computed a)
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
2019-02-28 12:08:28 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(* ---------- Unit tests ------------ *)
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** Alcotest cases for this module. Random sparse-ish matrices are compared
    across dense, sparse and computed storage. Check labels name the storage
    kinds of the operands actually involved. *)
let test_case () =

  let d1 = 30
  and d2 = 40
  and d3 = 50
  in

  (* Random matrices with small entries zeroed out to create sparsity. *)
  let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random d1 d2)
  and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d2 d3)
  in

  let m1 = dense_of_mat x1
  and m2 = dense_of_mat x2
  in

  let m1_s = sparse_of_mat x1
  and m2_s = sparse_of_mat x2
  in

  (* Trace of (m1 - m2)^T (m1 - m2): zero iff the matrices are equal. *)
  let norm_diff m1 m2 =
    (Mat.sub (to_mat m1) (to_mat m2)
     |> Mat.syrk_trace)
  in

  let test_dimensions () =
    Alcotest.(check int) "dim1 1" d1 (dim1 m1 );
    Alcotest.(check int) "dim1 2" d1 (dim1 m1_s);
    Alcotest.(check int) "dim2 3" d2 (dim2 m1 );
    Alcotest.(check int) "dim2 4" d2 (dim2 m1_s);
    Alcotest.(check int) "dim1 5" d2 (dim1 m2 );
    Alcotest.(check int) "dim1 6" d2 (dim1 m2_s);
    Alcotest.(check int) "dim2 7" d3 (dim2 m2 );
    Alcotest.(check int) "dim2 8" d3 (dim2 m2_s);
  in

  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse m1_s = m1 );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse m2_s = m2 );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense m1 = m1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense m2 = m2_s);
  in

  let test_transpose () =
    let m1t = Mat.transpose_copy x1 |> dense_of_mat
    and m2t = Mat.transpose_copy x2 |> dense_of_mat
    in
    Alcotest.(check bool) "dense 1" true (transpose m1 = m1t);
    Alcotest.(check bool) "dense 2" true (transpose m2 = m2t);
    Alcotest.(check bool) "sparse 1" true (transpose m1_s = sparse_of_dense m1t);
    Alcotest.(check bool) "sparse 2" true (transpose m2_s = sparse_of_dense m2t);
  in

  let test_outer () =
    let x1 = Vec.init d1 (fun i -> float_of_int i)
    and x2 = Vec.init d2 (fun i -> float_of_int i -. 0.3)
    in
    let v1 = Vector.dense_of_vec x1
    and v2 = Vector.dense_of_vec x2
    in
    let v1_s = Vector.sparse_of_vec x1
    and v2_s = Vector.sparse_of_vec x2
    in
    let m1 =
      Dense (Mat.init_cols d1 d2 (fun i j -> (float_of_int i) *. (float_of_int j -. 0.3)))
    in
    let m1_s =
      sparse_of_dense m1
    in
    Alcotest.(check (float 1.e-10)) "dense dense " 0. (norm_diff m1 (outer_product v1 v2));
    Alcotest.(check (float 1.e-10)) "sparse dense " 0. (norm_diff m1_s (outer_product v1_s v2));
    Alcotest.(check (float 1.e-10)) "dense sparse" 0. (norm_diff m1_s (outer_product v1 v2_s));
    Alcotest.(check (float 1.e-10)) "sparse sparse" 0. (norm_diff m1_s (outer_product v1_s v2_s));
  in

  let test_add_sub () =
    let x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d1 d2) in
    let m2 = dense_of_mat x2 in
    let m3 = Mat.add x1 x2 |> dense_of_mat in
    let m4 = Mat.sub x1 x2 |> dense_of_mat in
    let m2_s = sparse_of_mat x2 in
    let m3_s = Mat.add x1 x2 |> sparse_of_mat in
    let m4_s = Mat.sub x1 x2 |> sparse_of_mat in
    Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (add m1 m2) m3);
    Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (sub m1 m2) m4);
    Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (add m1 m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "dense sparse 4" 0. (norm_diff (sub m1 m2_s) m4_s);
    Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (add m1_s m2) m3);
    Alcotest.(check (float 1.e-10)) "sparse dense 6" 0. (norm_diff (sub m1_s m2) m4);
    Alcotest.(check (float 1.e-10)) "sparse sparse 7" 0. (norm_diff (add m1_s m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse sparse 8" 0. (norm_diff (sub m1_s m2_s) m4_s);
    Alcotest.(check (float 1.e-10)) "frobenius 9" (frobenius_norm m1_s) (frobenius_norm m1);
  in

  let test_mv () =
    let y = Vec.random d2 in
    let z = Vec.random d1 in
    let x3 = gemv x1 y in
    let x4 = gemv ~trans:`T x1 z in
    let v = Vector.dense_of_vec y in
    let v2 = Vector.dense_of_vec z in
    let v3 = Vector.dense_of_vec x3 in
    let v4 = Vector.dense_of_vec x4 in
    let v_s = Vector.sparse_of_vec y in
    let v2_s = Vector.sparse_of_vec z in
    (* Euclidean norm of the difference of two vectors. *)
    let norm_diff v1 v2 =
      Vec.sub (Vector.to_vec v1) (Vector.to_vec v2)
      |> nrm2
    in
    Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mv m1 v) v3);
    Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (mv ~trans:`T m1 v2) v4);

    Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (mv m1 v_s) v3);
    Alcotest.(check (float 1.e-10)) "dense sparse 4" 0. (norm_diff (mv ~trans:`T m1 v2_s) v4);

    Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (mv m1_s v) v3);
    Alcotest.(check (float 1.e-10)) "sparse dense 6" 0. (norm_diff (mv ~trans:`T m1_s v2) v4);

    Alcotest.(check (float 1.e-10)) "sparse sparse 7" 0. (norm_diff (mv m1_s v_s) v3);
    Alcotest.(check (float 1.e-10)) "sparse sparse 8" 0. (norm_diff (mv ~trans:`T m1_s v2_s) v4);
  in

  let test_mm () =
    let x3 = gemm x1 x2 in
    let m3 = dense_of_mat x3
    and m3_s = sparse_of_mat x3
    and m4 = dense_of_mat x1 |> transpose
    and m4_s = sparse_of_mat x1 |> transpose
    and m5 = dense_of_mat x2 |> transpose
    and m5_s = sparse_of_mat x2 |> transpose
    in
    let c1 = of_fun (Mat.dim1 x1) (Mat.dim2 x1) (fun i j -> x1.{i,j}) in
    let c2 = of_fun (Mat.dim1 x2) (Mat.dim2 x2) (fun i j -> x2.{i,j}) in
    let c3 = of_fun (Mat.dim1 x3) (Mat.dim2 x3) (fun i j -> x3.{i,j}) in
    let c4 = of_fun (dim1 m4) (dim2 m4) (fun i j -> get m4 i j ) in
    let c5 = of_fun (dim1 m5) (dim2 m5) (fun i j -> get m5 i j ) in
    Alcotest.(check (float 1.e-10)) "dense computed 0" 0. (norm_diff m3 c3);
    Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm m1 m2) m3);
    Alcotest.(check (float 1.e-10)) "computed computed 2" 0. (norm_diff (mm c1 c2) m3);
    Alcotest.(check (float 1.e-10)) "computed dense 3" 0. (norm_diff (mm c1 m2) m3);
    Alcotest.(check (float 1.e-10)) "computed sparse 4" 0. (norm_diff (mm c1 m2_s) m3);
    Alcotest.(check (float 1.e-10)) "dense computed 5" 0. (norm_diff (mm m1 c2) m3);
    Alcotest.(check (float 1.e-10)) "sparse computed 6" 0. (norm_diff (mm m1_s c2) m3);
    Alcotest.(check (float 1.e-10)) "dense dense 7" 0. (norm_diff (mm ~transa:`T m4 m2) m3);
    Alcotest.(check (float 1.e-10)) "computed dense 8" 0. (norm_diff (mm ~transa:`T c4 m2) m3);
    Alcotest.(check (float 1.e-10)) "dense dense 9" 0. (norm_diff (mm ~transb:`T m1 m5) m3);
    Alcotest.(check (float 1.e-10)) "dense computed 10" 0. (norm_diff (mm ~transb:`T m1 c5) m3);
    Alcotest.(check (float 1.e-10)) "dense dense 11" 0. (norm_diff (mm ~transa:`T ~transb:`T m2 m1) (transpose m3));

    Alcotest.(check (float 1.e-10)) "dense sparse 12" 0. (norm_diff (mm m1 m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "dense sparse 13" 0. (norm_diff (mm ~transa:`T m4 m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "computed sparse 14" 0. (norm_diff (mm ~transa:`T c4 m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "dense sparse 15" 0. (norm_diff (mm ~transb:`T m1 m5_s) m3_s);
    Alcotest.(check (float 1.e-10)) "dense sparse 16" 0. (norm_diff (transpose (mm m2 m1_s ~transa:`T ~transb:`T)) m3_s);

    Alcotest.(check (float 1.e-10)) "sparse dense 17" 0. (norm_diff (mm m1_s m2) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse dense 18" 0. (norm_diff (mm ~transa:`T m4_s m2) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse dense 19" 0. (norm_diff (mm ~transb:`T m1_s m5) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse computed 20" 0. (norm_diff (mm ~transb:`T m1_s c5) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse dense 21" 0. (norm_diff (transpose (mm m2_s m1 ~transa:`T ~transb:`T)) m3_s);

    Alcotest.(check (float 1.e-10)) "sparse sparse 22" 0. (norm_diff (mm m1_s m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse sparse 23" 0. (norm_diff (mm ~transa:`T m4_s m2_s) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse sparse 24" 0. (norm_diff (mm ~transb:`T m1_s m5_s) m3_s);
    Alcotest.(check (float 1.e-10)) "sparse sparse 25" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s);
  in

  let test_solve () =
    let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random 30 30)
    and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random 30 5)
    in

    let m1 = dense_of_mat x1
    and m2 = dense_of_mat x2
    in

    let m1_s = sparse_of_mat x1
    and m2_s = sparse_of_mat x2
    in

    let a = m1 and b = m2 in
    let x = ax_eq_b a b in
    Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm a x) b);

    let a = m1 and b = m2_s in
    let x = ax_eq_b a b in
    Alcotest.(check (float 1.e-10)) "dense sparse 2" 0. (norm_diff (mm a x) b);

    let a = m1_s and b = m2 in
    let x = ax_eq_b a b in
    Alcotest.(check (float 1.e-10)) "sparse dense 3" 0. (norm_diff (mm a x) b);

    let a = m1_s and b = m2_s in
    let x = ax_eq_b a b in
    Alcotest.(check (float 1.e-10)) "sparse sparse 4" 0. (norm_diff (mm a x) b);
  in

  (* 40 columns split in blocks of 7 gives 6 blocks. *)
  let test_split_join () =
    let m1_split = split_cols 7 m1 in
    let m1_s_split = split_cols 7 m1_s in
    let m2 = join_cols m1_split in
    let m2_s = join_cols m1_s_split in
    Alcotest.(check int) "length dense" 6 (List.length m1_split);
    Alcotest.(check int) "length sparse" 6 (List.length m1_s_split);
    Alcotest.(check bool) "join dense" true (m1 = m2);
    Alcotest.(check bool) "join sparse" true (m1_s = m2_s);
  in

  [
    "Conversion", `Quick, test_conversion;
    "Dimensions", `Quick, test_dimensions;
    "Transposition", `Quick, test_transpose;
    "Outer product", `Quick, test_outer;
    "Add sub", `Quick, test_add_sub;
    "Matrix Vector", `Quick, test_mv;
    "Matrix Matrix", `Quick, test_mm;
    "Linear solve", `Quick, test_solve;
    "split_join", `Quick, test_split_join;
  ]
|
|
|
|
|