2019-02-27 21:28:56 +01:00
|
|
|
open Lacaml.D
|
|
|
|
|
|
|
|
(** Default magnitude below which an entry is treated as negligible when a
    sparse vector is built (the default [threshold] of this module).
    Re-exported from [Constants.epsilon]. *)
let epsilon : float = Constants.epsilon
|
|
|
|
|
2019-02-28 15:50:00 +01:00
|
|
|
(** A single stored entry of a sparse vector. *)
type index_value = {
  index : int;   (** Position of the entry (1-based, as checked by [get]). *)
  value : float; (** Value stored at [index]. *)
}
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** Sparse storage: the logical dimension [n] together with the explicitly
    stored entries [v]; positions not listed in [v] hold [0.].
    The merge loops in [axpy] and [dot] walk [v] assuming it is sorted by
    increasing index, so every constructor must preserve that order. *)
type sparse_vector = {
  n : int;               (** Logical dimension of the vector. *)
  v : index_value array; (** Stored entries, sorted by increasing index. *)
}
|
|
|
|
|
|
|
|
(** A vector stored either densely (a Lacaml [Vec.t]) or sparsely. *)
type t =
  | Dense of Vec.t          (** Every entry stored explicitly. *)
  | Sparse of sparse_vector (** Only the listed entries are stored. *)
|
|
|
|
|
|
|
|
|
|
|
|
(** [is_sparse x] is [true] exactly when [x] uses the sparse representation. *)
let is_sparse x =
  match x with
  | Dense _ -> false
  | Sparse _ -> true
|
|
|
|
|
|
|
|
|
|
|
|
(** [is_dense x] is [true] exactly when [x] uses the dense representation. *)
let is_dense x =
  match x with
  | Dense _ -> true
  | Sparse _ -> false
|
|
|
|
|
|
|
|
|
2019-02-28 15:50:00 +01:00
|
|
|
(** Raised from an [Array.iter] callback (in [get]) to stop the scan at the
    first matching entry, carrying that entry's value. *)
exception Found of float
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
(** [get x] returns an accessor [fun i -> x_i] for the 1-based index [i].
    For a sparse [x], an in-range index that is not stored yields [0.]; if
    several entries share index [i], the first one in storage order wins.
    @raise Invalid_argument ["index out of bounds"] if [x] is sparse and [i]
    is outside [1..n]. *)
let get = function
  | Dense vec ->
    (fun i -> vec.{i})
  | Sparse { n ; v = entries } ->
    (fun i ->
       if i < 1 || i > n then invalid_arg "index out of bounds";
       (* Linear scan; [Found] aborts the iteration at the first match. *)
       let scan () =
         Array.iter
           (fun e -> if e.index = i then raise (Found e.value))
           entries
       in
       match scan () with
       | () -> 0.
       | exception Found x -> x)
|
|
|
|
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** [dim x] is the logical length of [x], whatever its representation. *)
let dim = function
  | Dense v -> Vec.dim v
  | Sparse { n ; _ } -> n
|
|
|
|
|
|
|
|
|
|
|
|
(** [sparse_of_dense ?threshold x] converts the dense vector [x] to sparse
    form, keeping every entry whose magnitude is not below [threshold]
    (default {!epsilon}). Stored entries come out in increasing index order.
    @raise Invalid_argument if [x] is already sparse. *)
let sparse_of_dense ?(threshold=epsilon) = function
  | Sparse _ -> invalid_arg "Expected a dense vector"
  | Dense vec ->
    let n = Vec.dim vec in
    (* Walk from the last index down so that consing leaves the entries in
       increasing index order. *)
    let entries = ref [] in
    for i = n downto 1 do
      let x = vec.{i} in
      if not (abs_float x < threshold) then
        entries := { index = i ; value = x } :: !entries
    done;
    Sparse { n ; v = Array.of_list !entries }
|
|
|
|
|
|
|
|
|
|
|
|
(** [to_assoc_list ?threshold x] lists the stored entries of [x] as
    [(index, value)] pairs, in storage order. A dense [x] is first converted
    with [sparse_of_dense ~threshold]; for a sparse [x] the stored entries are
    returned as they are and [threshold] is not consulted. *)
let rec to_assoc_list ?(threshold=epsilon) = function
  | Sparse { v ; _ } ->
    Array.fold_right (fun { index ; value } accu -> (index, value) :: accu) v []
  | Dense _ as x ->
    to_assoc_list (sparse_of_dense ~threshold x)
|
|
|
|
|
|
|
|
|
|
|
|
(** [dense_of_sparse x] expands the sparse vector [x] into a freshly
    allocated dense vector; positions not stored in [x] are [0.].
    @raise Invalid_argument if [x] is already dense. *)
let dense_of_sparse = function
  | Dense _ -> invalid_arg "Expected a sparse vector"
  | Sparse { n ; v = entries } ->
    let out = Vec.make0 n in
    for k = 0 to Array.length entries - 1 do
      let { index ; value } = entries.(k) in
      out.{index} <- value
    done;
    Dense out
|
|
|
|
|
|
|
|
|
|
|
|
(** [dense_of_vec v] wraps the Lacaml vector [v] (no copy) as a dense {!t}. *)
let dense_of_vec (v : Vec.t) : t = Dense v
|
|
|
|
|
|
|
|
|
|
|
|
(** [sparse_of_vec ?threshold v] builds a sparse {!t} from the Lacaml vector
    [v], dropping entries whose magnitude is below [threshold]
    (default {!epsilon}). *)
let sparse_of_vec ?(threshold=epsilon) v =
  sparse_of_dense ~threshold (Dense v)
|
|
|
|
|
|
|
|
|
2019-02-28 15:50:00 +01:00
|
|
|
(** [sparse_of_assoc_list n entries] builds a sparse vector of dimension [n]
    from [(index, value)] pairs. Values are stored as given (no thresholding).
    The entries are sorted by increasing index because the merge loops in
    [axpy] and [dot] rely on that order; the sort is stable, so input that is
    already sorted is stored unchanged.
    NOTE(review): indices are not checked against [1..n]. *)
let sparse_of_assoc_list n v =
  let entries =
    List.map (fun (index, value) -> { index ; value }) v
    |> Array.of_list
  in
  Array.stable_sort (fun a b -> compare a.index b.index) entries;
  Sparse { n ; v = entries }
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** [to_vec x] returns the entries of [x] as a Lacaml vector: a dense [x]
    yields its underlying vector (shared, not copied), a sparse [x] is
    expanded into a freshly allocated one. *)
let rec to_vec = function
  | Dense v -> v
  | Sparse _ as x -> to_vec (dense_of_sparse x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(** [scale ?threshold x v] returns [x * v] as a new vector with the same
    representation as [v]; [v] itself is not modified.
    For a sparse [v], products whose magnitude falls below [threshold]
    (default {!epsilon}) are dropped, using the same rule as
    [sparse_of_dense]; a dense [v] is scaled entry by entry. *)
let scale ?(threshold=epsilon) x = function
  | Dense v ->
    let v' = copy v in
    scal x v';
    Dense v'
  | Sparse { n ; v } ->
    (* fold_right keeps the increasing index order of the stored entries. *)
    let scaled =
      Array.fold_right (fun { index ; value } accu ->
          let value = x *. value in
          if abs_float value < threshold then accu
          else { index ; value } :: accu)
        v []
    in
    Sparse { n ; v = Array.of_list scaled }
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
(** [neg x] returns [-x] as a new vector with the same representation as [x];
    a sparse [x] keeps exactly the same stored indices. *)
let neg = function
  | Dense v -> Dense (Vec.neg v)
  | Sparse { n ; v } ->
    Sparse { n ; v = Array.map (fun { index ; value } -> { index ; value = -. value }) v }
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(** [axpy ?threshold ?alpha x y] computes [alpha * x + y] (default
    [alpha = 1.]) without modifying [x] or [y].
    - dense/dense gives a dense result;
    - any combination involving a sparse operand gives a sparse result, from
      which entries with magnitude below [threshold] (default {!epsilon}) are
      dropped.
    @raise Invalid_argument ["Inconsistent dimensions"] if [x] and [y] have
    different dimensions. *)
let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
  if dim x <> dim y then
    invalid_arg "Inconsistent dimensions";
  match x, y with
  | Dense x, Dense y ->
    (* Lacaml's axpy updates its last argument in place: work on a copy. *)
    let y = copy y in
    Lacaml.D.axpy ~alpha x y;
    Dense y
  | Sparse { v ; _ }, Dense y ->
    let v' = copy y in
    Array.iter (fun { index ; value } ->
        v'.{index} <- v'.{index} +. alpha *. value) v;
    sparse_of_vec ~threshold v'
  | Dense x, Sparse { v ; _ } ->
    let v' = copy x in
    scal alpha v';
    Array.iter (fun { index ; value } ->
        v'.{index} <- v'.{index} +. value) v;
    sparse_of_vec ~threshold v'
  | Sparse { n ; v }, Sparse { v = v' ; _ } ->
    (* Same negligibility rule as [sparse_of_dense], applied to every result
       entry, including those past the end of the shorter operand. *)
    let keep index value accu =
      if abs_float value < threshold then accu
      else { index ; value } :: accu
    in
    let len = Array.length v and len' = Array.length v' in
    (* Merge walk over the two index-sorted entry arrays; [accu] is built in
       reverse and flipped once at the end. *)
    let rec merge accu k l =
      match k < len, l < len' with
      | true, true ->
        let { index = i ; value = a } = v.(k)
        and { index = j ; value = b } = v'.(l) in
        if i < j then (merge [@tailcall]) (keep i (alpha *. a) accu) (k+1) l
        else if i > j then (merge [@tailcall]) (keep j b accu) k (l+1)
        else (merge [@tailcall]) (keep i (alpha *. a +. b) accu) (k+1) (l+1)
      | true, false ->
        let { index ; value } = v.(k) in
        (merge [@tailcall]) (keep index (alpha *. value) accu) (k+1) l
      | false, true ->
        let { index ; value } = v'.(l) in
        (merge [@tailcall]) (keep index value accu) k (l+1)
      | false, false ->
        Sparse { n ; v = Array.of_list (List.rev accu) }
    in
    merge [] 0 0
|
|
|
|
|
|
|
|
(** [add ?threshold x y] is [x + y]; the representation of the result follows
    the rules of [axpy]. *)
let add ?(threshold=epsilon) x y = axpy ~threshold ~alpha:1. x y
|
|
|
|
|
|
|
|
(** [sub ?threshold x y] is [x - y], computed as [add ~threshold x (neg y)]. *)
let sub ?(threshold=epsilon) x y =
  neg y |> add ~threshold x
|
|
|
|
|
|
|
|
(** [pp_vector ppf x] pretty-prints [x] on [ppf]: a dense vector through
    [Util.pp_float_array], a sparse one as its dimension followed by its
    stored [(index, value)] pairs. *)
let pp_vector ppf = function
  | Dense m ->
    Vec.to_array m |> Util.pp_float_array ppf
  | Sparse { n ; v } ->
    let pp_entry { index ; value } =
      Format.fprintf ppf "@[(%d, %f); @]" index value
    in
    Format.fprintf ppf "@[[ %d | " n;
    Array.iter pp_entry v;
    Format.fprintf ppf "]@]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(** [dot x y] is the scalar product of [x] and [y], for any mix of dense and
    sparse representations.
    @raise Invalid_argument ["Inconsistent dimensions"] if [x] and [y] have
    different dimensions (checked the same way in all four cases). *)
let dot v v' =
  (* dense . dense: delegate to Lacaml after the common dimension check. *)
  let d_d v v' =
    if Vec.dim v <> Vec.dim v' then
      invalid_arg "Inconsistent dimensions";
    Lacaml.D.dot v v'
  in
  (* dense . sparse: only the stored entries of the sparse operand contribute. *)
  let d_sp v' { n ; v } =
    if n <> Vec.dim v' then
      invalid_arg "Inconsistent dimensions";
    Array.fold_left (fun accu { index ; value } -> accu +. value *. v'.{index}) 0. v
  in
  (* sparse . sparse: merge walk over the two index-sorted entry arrays,
     accumulating a product only where both operands store the same index. *)
  let sp_sp { n ; v } { n = n' ; v = v' } =
    if n <> n' then
      invalid_arg "Inconsistent dimensions";
    let len = Array.length v and len' = Array.length v' in
    let rec aux accu k l =
      if k >= len || l >= len' then accu
      else begin
        let { index = i ; value = x } = v.(k)
        and { index = j ; value = y } = v'.(l) in
        if i < j then (aux [@tailcall]) accu (k+1) l
        else if i > j then (aux [@tailcall]) accu k (l+1)
        else (aux [@tailcall]) (accu +. x *. y) (k+1) (l+1)
      end
    in
    aux 0. 0 0
  in
  match v, v' with
  | Dense v, Dense v' -> d_d v v'
  | Sparse v, Sparse v' -> sp_sp v v'
  | Dense v, Sparse v' -> d_sp v v'
  | Sparse v, Dense v' -> d_sp v' v
|
|
|
|
|
|
|
|
|
2019-03-21 16:32:41 +01:00
|
|
|
(** [norm x] is the Euclidean norm of [x], i.e. [sqrt (dot x x)]. *)
let norm v = dot v v |> sqrt
|
2019-02-27 21:28:56 +01:00
|
|
|
|
|
|
|
(** Alcotest cases for this module: conversions, arithmetic and dot products,
    each checked on random dense vectors and their sparse counterparts against
    results computed directly with Lacaml. *)
let test_case () =

  (* Random length-100 vectors with small entries zeroed, so that the sparse
     forms are genuinely sparse. *)
  let x1 = Vec.map (fun x -> if abs_float x < 0.6 then 0. else x) (Vec.random 100)
  and x2 = Vec.map (fun x -> if abs_float x < 0.3 then 0. else x) (Vec.random 100)
  in
  (* Reference results: 2 x1, x1 + x2, x1 - x2 and 3 x1 + x2. *)
  let x3 = Vec.map (fun x -> 2. *. x) x1
  and x4 = Vec.add x1 x2
  and x5 = Vec.sub x1 x2
  and x6 =
    let v = copy x2 in
    Lacaml.D.axpy ~alpha:3. x1 v;
    v
  in

  let v1 = dense_of_vec x1
  and v2 = dense_of_vec x2
  and v3 = dense_of_vec x3
  and v4 = dense_of_vec x4
  and v5 = dense_of_vec x5
  and v6 = dense_of_vec x6
  in

  let v1_s = sparse_of_vec x1
  and v2_s = sparse_of_vec x2
  and v3_s = sparse_of_vec x3
  and v4_s = sparse_of_vec x4
  and v5_s = sparse_of_vec x5
  and v6_s = sparse_of_vec x6
  in

  let zero = dense_of_vec (Vec.make0 100)
  and zero_s = sparse_of_vec (Vec.make0 100)
  in

  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse v1_s = v1 );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse v2_s = v2 );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense v1 = v1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense v2 = v2_s)
  in

  let test_operations () =
    Alcotest.(check bool) "dense scale" true (scale 2. v1 = v3);
    Alcotest.(check bool) "sparse scale" true (scale 2. v1_s = v3_s);

    (* Each mix of representations is checked exactly once. *)
    Alcotest.(check bool) "dense dense add" true (add v1 v2 = v4);
    Alcotest.(check bool) "dense sparse add" true (add v1 v2_s = v4_s);
    Alcotest.(check bool) "sparse dense add" true (add v1_s v2 = v4_s);
    Alcotest.(check bool) "sparse sparse add" true (add v1_s v2_s = v4_s);

    Alcotest.(check bool) "dense dense sub" true (sub v1 v2 = v5);
    Alcotest.(check bool) "dense sparse sub" true (sub v1 v2_s = v5_s);
    Alcotest.(check bool) "sparse dense sub" true (sub v1_s v2 = v5_s);
    Alcotest.(check bool) "sparse sparse sub" true (sub v1_s v2_s = v5_s);

    (* x - x must give the zero vector in every representation. *)
    Alcotest.(check bool) "dense dense sub self" true (sub v1 v1 = zero);
    Alcotest.(check bool) "dense sparse sub self" true (sub v1 v1_s = zero_s);
    Alcotest.(check bool) "sparse dense sub self" true (sub v1_s v1 = zero_s);
    Alcotest.(check bool) "sparse sparse sub self" true (sub v1_s v1_s = zero_s);

    (* Mixed axpy results are compared up to rounding by subtracting the
       reference with a tight threshold. *)
    Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
    Alcotest.(check bool) "dense sparse axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
    Alcotest.(check bool) "sparse dense axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
    Alcotest.(check bool) "sparse sparse axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s)
  in

  let test_dot () =
    let d1d2 = Lacaml.D.dot x1 x2
    and d1d1 = Lacaml.D.dot x1 x1
    and d2d2 = Lacaml.D.dot x2 x2
    in
    (* Alcotest.check takes the expected (reference) value before the actual. *)
    Alcotest.(check (float 1.e-10)) "sparse x dense 1" d1d2 (dot v1_s v2 );
    Alcotest.(check (float 1.e-10)) "sparse x dense 2" d1d1 (dot v1_s v1 );
    Alcotest.(check (float 1.e-10)) "sparse x dense 3" d2d2 (dot v2_s v2 );
    Alcotest.(check (float 1.e-10)) "dense x sparse 1" d1d2 (dot v1 v2_s);
    Alcotest.(check (float 1.e-10)) "dense x sparse 2" d1d1 (dot v1 v1_s);
    Alcotest.(check (float 1.e-10)) "dense x sparse 3" d2d2 (dot v2 v2_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 1" d1d2 (dot v1_s v2_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 2" d1d1 (dot v1_s v1_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 3" d2d2 (dot v2_s v2_s)
  in

  [
    "Conversion", `Quick, test_conversion;
    "Operations", `Quick, test_operations;
    "Dot product", `Quick, test_dot;
  ]
|
|
|
|
|