10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-07-31 17:44:12 +02:00
QCaml/Utils/Vector.ml
2019-02-27 21:28:56 +01:00

337 lines
10 KiB
OCaml

open Lacaml.D
(** Default threshold used by the conversion and arithmetic functions below:
    entries whose magnitude falls under it are not stored in sparse vectors.
    Re-exported from [Constants]. *)
let epsilon : float = Constants.epsilon
(** Sparse storage of a vector: only explicitly listed entries are kept;
    every other index reads as [0.] (see [get]). *)
type sparse_vector =
{
n: int; (** Logical dimension; valid indices are [1 .. n] (checked in [get]). *)
v: (int*float) list (** Stored entries as [(index, value)] pairs. [sparse_of_dense]
                        produces them in increasing index order, and the
                        sparse/sparse merge in [axpy] relies on that ordering. *)
}
(** A vector held either densely or sparsely. Most operations below accept
    both forms and pick a representation for the result. *)
type t =
| Dense of Vec.t (** Every entry stored explicitly in a Lacaml vector. *)
| Sparse of sparse_vector (** Only the listed entries are stored. *)
(** [is_sparse x] is [true] exactly when [x] uses the sparse representation. *)
let is_sparse x =
  match x with
  | Dense _ -> false
  | Sparse _ -> true
(** [is_dense x] is [true] exactly when [x] uses the dense representation. *)
let is_dense x =
  match x with
  | Dense _ -> true
  | Sparse _ -> false
(** [get x] returns an accessor [fun i -> x_i] using 1-based indices.
    - Sparse: an index outside [1 .. n] raises [Invalid_argument
      "index out of bounds"], and an index with no stored entry reads as [0.].
    - Dense: the index is passed straight to [.{i}]; no extra check is done here. *)
let get x =
  match x with
  | Dense d -> fun i -> d.{i}
  | Sparse { n ; v = entries } ->
    fun i ->
      if i < 1 || i > n then invalid_arg "index out of bounds";
      (match List.assoc i entries with
       | value -> value
       | exception Not_found -> 0.)
(** [dim x] is the logical dimension of [x], whatever its representation. *)
let dim x =
  match x with
  | Dense d -> Vec.dim d
  | Sparse s -> s.n
(** [sparse_of_dense ?threshold x] converts a dense vector to sparse form.
    Entries with [abs_float x_i < threshold] are dropped; all others,
    including NaN, are kept in increasing index order.
    @raise Invalid_argument if [x] is already sparse. *)
let sparse_of_dense ?(threshold=epsilon) = function
  | Sparse _ -> invalid_arg "Expected a dense vector"
  | Dense vec ->
    let size = Vec.dim vec in
    let kept = ref [] in
    (* Walk downward so that consing yields entries in increasing index order. *)
    for i = size downto 1 do
      let x = vec.{i} in
      (* [not (a < b)] rather than [a >= b] so NaN entries are kept, as before. *)
      if not (abs_float x < threshold) then kept := (i, x) :: !kept
    done;
    Sparse { n = size ; v = !kept }
(** [to_assoc_list ?threshold x] returns the stored [(index, value)] entries
    of [x]. A dense [x] is first sparsified with [threshold]; a sparse [x]
    is returned as-is and [threshold] is ignored. *)
let rec to_assoc_list ?(threshold=epsilon) x =
  match x with
  | Sparse { v = entries ; _ } -> entries
  | Dense _ -> to_assoc_list ~threshold (sparse_of_dense ~threshold x)
(** [dense_of_sparse x] materialises a sparse vector as a dense one, with
    zeros at every index that has no stored entry.
    @raise Invalid_argument if [x] is already dense. *)
let dense_of_sparse = function
  | Dense _ -> invalid_arg "Expected a sparse vector"
  | Sparse { n ; v = entries } ->
    let out = Vec.make0 n in
    entries |> List.iter (fun (i, x) -> out.{i} <- x);
    Dense out
(** [dense_of_vec vec] wraps a Lacaml vector as a dense [t] (no copy is made). *)
let dense_of_vec vec =
  Dense vec
(** [sparse_of_vec ?threshold v] builds a sparse [t] from a Lacaml vector,
    dropping entries with magnitude below [threshold]. *)
let sparse_of_vec ?(threshold=epsilon) v =
  sparse_of_dense ~threshold (Dense v)
(** [sparse_of_assoc_list n entries] builds a sparse vector of dimension [n]
    from [(index, value)] pairs. The list is stored as given: indices are not
    bounds-checked, sorted or deduplicated here.
    NOTE(review): [axpy]/[add]/[sub] merge sparse entry lists assuming
    increasing index order — callers presumably must supply sorted lists. *)
let sparse_of_assoc_list n entries =
  Sparse { n ; v = entries }
(** [to_vec x] returns [x] as a Lacaml vector. A dense [x] yields its
    underlying vector (no copy); a sparse [x] is expanded into a fresh
    zero-filled vector. *)
let to_vec = function
  | Dense d -> d
  | Sparse { n ; v = entries } ->
    let out = Vec.make0 n in
    List.iter (fun (i, x) -> out.{i} <- x) entries;
    out
(** [scale ?threshold x vec] returns [x *. vec] without modifying [vec].
    - Dense: every entry is scaled.
    - Sparse: scaled entries with [abs_float z <= threshold] are removed. *)
let scale ?(threshold=epsilon) x = function
  | Dense d ->
    let scaled = copy d in
    scal x scaled;
    Dense scaled
  | Sparse { n ; v = entries } ->
    let keep (i, y) =
      let z = x *. y in
      if abs_float z > threshold then Some (i, z) else None
    in
    Sparse { n ; v = Util.list_some (List.map keep entries) }
(** [neg x] returns [-x], keeping the representation of [x]. *)
let neg x =
  match x with
  | Dense d -> Dense (Vec.neg d)
  | Sparse { n ; v = entries } ->
    Sparse { n ; v = List.map (fun (i, y) -> (i, -. y)) entries }
(*
let rec add ?(threshold=epsilon) x y =
if dim x <> dim y then
invalid_arg "Inconsistent dimensions";
match x, y with
| Dense x , Dense y -> Dense (Vec.add x y)
| Sparse {n ; v}, Dense y ->
let v' = copy y in
List.iter (fun (i, x) -> v'.{i} <- v'.{i} +. x) v;
sparse_of_vec ~threshold v'
| Sparse {n ; v}, Sparse {n=n' ; v=v'} ->
begin
let rec aux accu v1 v2 =
match v1, v2 with
| [], [] -> {n ; v=List.rev accu}
| ((i, x)::v1), [] ->
aux ((i, x)::accu) v1 []
| [], ((j, y)::v2) ->
aux ((j, y)::accu) [] v2
| ((i, x)::v1), ((j, y)::v2) ->
if i = j then
begin
let z = x +. y in
if abs_float z > threshold then
aux ((i, (x +. y))::accu) v1 v2
else
aux accu v1 v2
end
else if i < j then
begin
if abs_float x > threshold then
aux ((i, x)::accu) v1 ((j, y)::v2)
else
aux accu v1 ((j, y)::v2)
end
else
begin
if abs_float y > threshold then
aux ((j, y)::accu) ((i, x)::v1) v2
else
aux accu ((i, x)::v1) v2
end
in
Sparse (aux [] v v')
end
| x, y -> add ~threshold y x
let sub ?(threshold=epsilon) x y = add ~threshold x (neg y)
*)
(** [axpy ?threshold ?alpha x y] returns [alpha *. x +. y] (default
    [alpha = 1.]) without modifying [x] or [y].

    The result is [Dense] only when both inputs are dense.
    - One input sparse: the sum is computed densely, then sparsified with
      [threshold].
    - Both inputs sparse: entries are merged and any resulting entry with
      [abs_float z <= threshold] is dropped.

    The sparse/sparse merge assumes both entry lists are sorted by increasing
    index, as produced by [sparse_of_dense].
    NOTE(review): [sparse_of_assoc_list] does not enforce this ordering.

    @raise Invalid_argument if [x] and [y] have different dimensions. *)
let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
  if dim x <> dim y then
    invalid_arg "Inconsistent dimensions";
  (* Prepend [(i, z)] to [accu] unless [z] is negligible. Every sparse/sparse
     entry goes through here so all merge branches filter consistently. *)
  let push i z accu =
    if abs_float z > threshold then (i, z) :: accu else accu
  in
  match x, y with
  | Dense dx, Dense dy ->
    let result = copy dy in
    Lacaml.D.axpy ~alpha dx result;
    Dense result
  | Sparse { v = xs ; _ }, Dense dy ->
    let result = copy dy in
    List.iter (fun (i, xi) -> result.{i} <- result.{i} +. alpha *. xi) xs;
    sparse_of_vec ~threshold result
  | Dense dx, Sparse { v = ys ; _ } ->
    let result = copy dx in
    scal alpha result;
    List.iter (fun (i, yi) -> result.{i} <- result.{i} +. yi) ys;
    sparse_of_vec ~threshold result
  | Sparse { n ; v = xs }, Sparse { v = ys ; _ } ->
    (* Tail-recursive merge of two index-sorted entry lists. Entries of [x]
       are scaled by [alpha] in every branch, including when [ys] is
       exhausted (previously that tail was copied unscaled). *)
    let rec merge accu xs ys =
      match xs, ys with
      | [], [] -> List.rev accu
      | (i, xi) :: xs', [] -> merge (push i (alpha *. xi) accu) xs' []
      | [], (j, yj) :: ys' -> merge (push j yj accu) [] ys'
      | (i, xi) :: xs', (j, yj) :: ys' ->
        if i = j then merge (push i (alpha *. xi +. yj) accu) xs' ys'
        else if i < j then merge (push i (alpha *. xi) accu) xs' ys
        else merge (push j yj accu) xs ys'
    in
    Sparse { n ; v = merge [] xs ys }
(** [add ?threshold x y] returns [x +. y]; the representation and sparsity
    rules are those of [axpy] with [alpha = 1.]. *)
let add ?(threshold=epsilon) x y =
  axpy ~threshold ~alpha:1. x y
(** [sub ?threshold x y] returns [x -. y], computed as [add x (neg y)]. *)
let sub ?(threshold=epsilon) x y =
  let minus_y = neg y in
  add ~threshold x minus_y
(** [pp_vector ppf x] pretty-prints [x].
    - Dense: printed as a float array via [Util.pp_float_array].
    - Sparse: printed as [[ n | (i, x_i); ... ]]. *)
let pp_vector ppf x =
  match x with
  | Dense d -> Util.pp_float_array ppf (Vec.to_array d)
  | Sparse { n ; v = entries } ->
    Format.fprintf ppf "@[[ %d | " n;
    List.iter (fun (i, value) -> Format.fprintf ppf "@[(%d, %f); @]" i value) entries;
    Format.fprintf ppf "]@]"
(** [dot v v'] is the inner product of [v] and [v'], in any mix of
    representations.
    - Dense/dense delegates to Lacaml's [dot].
    - Mixed cases sum over the sparse entries only.
    - Sparse/sparse looks up each entry of [v] in [v'] with [List.assoc_opt].
    @raise Invalid_argument if the dimensions differ. For the dense/dense case
    this check is left to Lacaml. *)
let dot v v' =
  let dense_dot_sparse d { n ; v = entries } =
    if n <> Vec.dim d then
      invalid_arg "Inconsistent dimensions";
    List.fold_left (fun acc (i, x) -> acc +. x *. d.{i}) 0. entries
  in
  let sparse_dot_sparse { n ; v = left } { n = n' ; v = right } =
    if n <> n' then
      invalid_arg "Inconsistent dimensions";
    List.fold_left
      (fun acc (i, x) ->
        match List.assoc_opt i right with
        | None -> acc
        | Some y -> acc +. x *. y)
      0. left
  in
  match v, v' with
  | Dense a, Dense b -> Lacaml.D.dot a b
  | Dense a, Sparse s | Sparse s, Dense a -> dense_dot_sparse a s
  | Sparse s, Sparse s' -> sparse_dot_sparse s s'
(** Alcotest suite for this module. Random 100-element vectors, with many
    entries zeroed, are built with Lacaml. The dense/sparse conversions,
    arithmetic and dot products of this module are then compared against the
    equivalent Lacaml computations. Returns the [(name, speed, test)] triples
    Alcotest expects. *)
let test_case () =
  let x1 = Vec.map (fun x -> if abs_float x < 0.6 then 0. else x) (Vec.random 100)
  and x2 = Vec.map (fun x -> if abs_float x < 0.3 then 0. else x) (Vec.random 100)
  in
  let x3 = Vec.map (fun x -> 2. *. x) x1
  and x4 = Vec.add x1 x2
  and x5 = Vec.sub x1 x2
  and x6 =
    let v = copy x2 in
    Lacaml.D.axpy ~alpha:3. x1 v;
    v
  in
  let v1 = dense_of_vec x1
  and v2 = dense_of_vec x2
  and v3 = dense_of_vec x3
  and v4 = dense_of_vec x4
  and v5 = dense_of_vec x5
  and v6 = dense_of_vec x6
  in
  let v1_s = sparse_of_vec x1
  and v2_s = sparse_of_vec x2
  and v3_s = sparse_of_vec x3
  and v4_s = sparse_of_vec x4
  and v5_s = sparse_of_vec x5
  and v6_s = sparse_of_vec x6
  in
  let zero = dense_of_vec (Vec.make0 100)
  and zero_s = sparse_of_vec (Vec.make0 100)
  in
  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse v1_s = v1 );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse v2_s = v2 );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense v1 = v1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense v2 = v2_s)
  in
  let test_operations () =
    Alcotest.(check bool) "dense scale" true (scale 2. v1 = v3);
    Alcotest.(check bool) "sparse scale" true (scale 2. v1_s = v3_s);
    Alcotest.(check bool) "dense dense add" true (add v1 v2 = v4);
    Alcotest.(check bool) "dense sparse add" true (add v1 v2_s = v4_s);
    Alcotest.(check bool) "sparse dense add" true (add v1_s v2 = v4_s);
    (* Swapped operands: exercises the sparse/dense branch with x2 as the
       sparse input (this line used to repeat "dense sparse add"). *)
    Alcotest.(check bool) "sparse dense add (swapped)" true (add v2_s v1 = v4_s);
    Alcotest.(check bool) "sparse sparse add" true (add v1_s v2_s = v4_s);
    Alcotest.(check bool) "dense dense sub" true (sub v1 v2 = v5);
    Alcotest.(check bool) "dense sparse sub" true (sub v1 v2_s = v5_s);
    Alcotest.(check bool) "sparse dense sub" true (sub v1_s v2 = v5_s);
    (* x2 - x1 = -(x1 - x2) exactly in IEEE arithmetic, so compare to [neg v5_s]
       (this line used to repeat "dense sparse sub"). *)
    Alcotest.(check bool) "sparse dense sub (swapped)" true (sub v2_s v1 = neg v5_s);
    Alcotest.(check bool) "sparse sparse sub" true (sub v1_s v2_s = v5_s);
    Alcotest.(check bool) "dense dense sub self" true (sub v1 v1 = zero);
    Alcotest.(check bool) "dense sparse sub self" true (sub v1 v1_s = zero_s);
    Alcotest.(check bool) "sparse dense sub self" true (sub v1_s v1 = zero_s);
    Alcotest.(check bool) "sparse sparse sub self" true (sub v1_s v1_s = zero_s);
    Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
    Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
    Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
    Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s)
  in
  let test_dot () =
    let d1d2 = Lacaml.D.dot x1 x2
    and d1d1 = Lacaml.D.dot x1 x1
    and d2d2 = Lacaml.D.dot x2 x2
    in
    (* Alcotest.check takes the expected value first, then the actual one. *)
    Alcotest.(check (float 1.e-10)) "sparse x dense 1" d1d2 (dot v1_s v2 );
    Alcotest.(check (float 1.e-10)) "sparse x dense 2" d1d1 (dot v1_s v1 );
    Alcotest.(check (float 1.e-10)) "sparse x dense 3" d2d2 (dot v2_s v2 );
    Alcotest.(check (float 1.e-10)) "dense x sparse 1" d1d2 (dot v1 v2_s);
    Alcotest.(check (float 1.e-10)) "dense x sparse 2" d1d1 (dot v1 v1_s);
    Alcotest.(check (float 1.e-10)) "dense x sparse 3" d2d2 (dot v2 v2_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 1" d1d2 (dot v1_s v2_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 2" d1d1 (dot v1_s v1_s);
    Alcotest.(check (float 1.e-10)) "sparse x sparse 3" d2d2 (dot v2_s v2_s)
  in
  [
    "Conversion", `Quick, test_conversion;
    "Operations", `Quick, test_operations;
    "Dot product", `Quick, test_dot;
  ]