10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-11-07 06:33:39 +01:00

Changed data structure of sparse_vector

This commit is contained in:
Anthony Scemama 2019-10-24 23:08:24 +02:00
parent 0462e19117
commit a1349ea69e
2 changed files with 48 additions and 35 deletions

View File

@ -270,6 +270,12 @@ let stream_fold f init stream =
in
aux init
(** {2 Array functions} *)
let array_range first last =
if last < first then [| |] else
Array.init (last-first+1) (fun i -> i+first)
(** {2 Linear algebra} *)

View File

@ -11,7 +11,7 @@ type index_value =
type sparse_vector =
{
n: int;
v: index_value list
v: index_value array
}
type t =
@ -36,7 +36,7 @@ let get = function
| Sparse { n ; v } -> (fun i ->
if i < 1 || i > n then invalid_arg "index out of bounds";
try
List.iter (fun {index ; value} ->
Array.iter (fun {index ; value} ->
if index=i then
raise (Found value)) v;
raise Not_found
@ -56,7 +56,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
| Sparse _ -> invalid_arg "Expected a dense vector"
| Dense v ->
let rec aux accu = function
| 0 -> accu
| 0 -> accu |> Array.of_list
| i ->
let x = v.{i} in
if abs_float x < threshold then
@ -69,7 +69,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
let rec to_assoc_list ?(threshold=epsilon) = function
| Sparse {n ; v} -> List.map (fun {index ; value} -> (index, value)) v
| Sparse {n ; v} -> Array.map (fun {index ; value} -> (index, value)) v |> Array.to_list
| Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v)
@ -77,7 +77,7 @@ let dense_of_sparse = function
| Dense _ -> invalid_arg "Expected a sparse vector"
| Sparse {n ; v} ->
let v' = Vec.make0 n in
List.iter (fun {index ; value} -> v'.{index} <- value) v;
Array.iter (fun {index ; value} -> v'.{index} <- value) v;
Dense v'
@ -92,6 +92,7 @@ let sparse_of_vec ?(threshold=epsilon) v =
let sparse_of_assoc_list n v =
Sparse { n ;
v = List.map (fun (index, value) -> {index ; value}) v
|> Array.of_list
}
@ -104,19 +105,13 @@ let rec to_vec = function
let scale ?(threshold=epsilon) x = function
| Dense v -> let v' = copy v in (scal x v'; Dense v')
| Sparse {n ; v} ->
Sparse {n ; v=List.map (fun {index ; value} ->
let z = x *. value in
if abs_float z > threshold then
Some {index ; value=z}
else
None
) v |> Util.list_some }
Sparse {n ; v = Array.map (fun {index ; value} -> { index ; value=x *. value} ) v }
let rec neg = function
| Dense v -> Dense (Vec.neg v)
| Sparse {n ; v} ->
Sparse {n ; v=List.map (fun {index ; value} -> {index ; value = -. value}) v}
Sparse {n ; v = Array.map (fun {index ; value} -> {index ; value = -. value}) v}
@ -129,22 +124,25 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
| Sparse {n ; v}, Dense y ->
begin
let v' = copy y in
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
sparse_of_vec ~threshold v'
end
| Dense x , Sparse {n ; v} ->
begin
let v' = copy x in
scal alpha v';
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
sparse_of_vec ~threshold v'
end
| Sparse {n ; v}, Sparse {n=n' ; v=v'} ->
begin
let rec aux accu v1 v2 =
match v1, v2 with
| ({index=i ; value=x}::r1), ({index=j ; value=y}::r2) ->
let rec aux accu k l =
match k < Array.length v, l < Array.length v' with
| true, true ->
begin
let {index=i ; value=x} = v.(k)
and {index=j ; value=y} = v'.(l)
in
match compare i j with
| -1 ->
let z = alpha *. x in
@ -153,14 +151,14 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
{index=i ; value=z} :: accu
else
accu
in (aux [@tailcall]) new_accu r1 v2
in (aux [@tailcall]) new_accu (k+1) l
| 1 ->
let new_accu =
if abs_float y > threshold then
{index=j ; value=y} :: accu
else
accu
in (aux [@tailcall]) new_accu v1 r2
in (aux [@tailcall]) new_accu k (l+1)
| 0 ->
let z = alpha *. x +. y in
let new_accu =
@ -168,14 +166,20 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
{index=i ; value=z} :: accu
else
accu
in (aux [@tailcall]) new_accu r1 r2
in (aux [@tailcall]) new_accu (k+1) (l+1)
| _ -> assert false
end
| ({index=i ; value=x}::r1), [] -> (aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) r1 []
| [] , ({index=j ; value=y}::r2) -> (aux [@tailcall]) ({index=j ; value=y}::accu) [] r2
| [] , [] -> {n ; v=List.rev accu}
| true, false ->
let {index=i ; value=x} = v.(k) in
(aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) (k+1) l
| false, true ->
(aux [@tailcall]) (v'.(l)::accu) k (l+1)
| false, false -> {n ; v=List.rev accu |> Array.of_list}
in
Sparse (aux [] v v')
Sparse (aux [] 0 0)
end
let add = axpy ~alpha:1.
@ -187,7 +191,7 @@ let pp_vector ppf = function
| Sparse {n ; v} ->
begin
Format.fprintf ppf "@[[ %d | " n;
List.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
Array.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
Format.fprintf ppf "]@]"
end
@ -202,26 +206,29 @@ let dot v v' =
let d_sp v' {n ; v} =
if n <> Vec.dim v' then
invalid_arg "Inconsistent dimensions";
List.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
Array.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
in
let sp_sp {n ; v} {n=n' ; v=v'} =
if n <> n' then
invalid_arg "Inconsistent dimensions";
let rec aux accu = function
| (({index=i ; value=v1} :: r1) as s1), (({index=j ; value=v2}::r2) as s2)->
let rec aux accu k l =
match Array.length v > k, Array.length v' > l with
| true, true ->
let {index=i ; value=x} = v.(k)
and {index=j ; value=y} = v'.(l)
in
begin
match compare i j with
| -1 -> (aux [@tailcall]) accu (r1, s2)
| 1 -> (aux [@tailcall]) accu (s1, r2)
| 0 -> (aux [@tailcall]) (accu +. v1 *. v2) (r1, r2)
| -1 -> (aux [@tailcall]) accu (k+1) l
| 1 -> (aux [@tailcall]) accu k (l+1)
| 0 -> (aux [@tailcall]) (accu +. x *. y) (k+1) (l+1)
| _ -> assert false
end
| ([], _ )
| (_ , []) -> accu
| _ -> accu
in
aux 0. (v, v')
aux 0. 0 0
in
match v, v' with