mirror of
https://gitlab.com/scemama/QCaml.git
synced 2024-11-14 10:03:39 +01:00
Changed data structure of sparse_vector
This commit is contained in:
parent
0462e19117
commit
a1349ea69e
@ -270,6 +270,12 @@ let stream_fold f init stream =
|
||||
in
|
||||
aux init
|
||||
|
||||
(** {2 Array functions} *)
|
||||
|
||||
let array_range first last =
|
||||
if last < first then [| |] else
|
||||
Array.init (last-first+1) (fun i -> i+first)
|
||||
|
||||
|
||||
(** {2 Linear algebra} *)
|
||||
|
||||
|
@ -11,7 +11,7 @@ type index_value =
|
||||
type sparse_vector =
|
||||
{
|
||||
n: int;
|
||||
v: index_value list
|
||||
v: index_value array
|
||||
}
|
||||
|
||||
type t =
|
||||
@ -36,7 +36,7 @@ let get = function
|
||||
| Sparse { n ; v } -> (fun i ->
|
||||
if i < 1 || i > n then invalid_arg "index out of bounds";
|
||||
try
|
||||
List.iter (fun {index ; value} ->
|
||||
Array.iter (fun {index ; value} ->
|
||||
if index=i then
|
||||
raise (Found value)) v;
|
||||
raise Not_found
|
||||
@ -56,7 +56,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
|
||||
| Sparse _ -> invalid_arg "Expected a dense vector"
|
||||
| Dense v ->
|
||||
let rec aux accu = function
|
||||
| 0 -> accu
|
||||
| 0 -> accu |> Array.of_list
|
||||
| i ->
|
||||
let x = v.{i} in
|
||||
if abs_float x < threshold then
|
||||
@ -69,7 +69,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
|
||||
|
||||
|
||||
let rec to_assoc_list ?(threshold=epsilon) = function
|
||||
| Sparse {n ; v} -> List.map (fun {index ; value} -> (index, value)) v
|
||||
| Sparse {n ; v} -> Array.map (fun {index ; value} -> (index, value)) v |> Array.to_list
|
||||
| Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v)
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ let dense_of_sparse = function
|
||||
| Dense _ -> invalid_arg "Expected a sparse vector"
|
||||
| Sparse {n ; v} ->
|
||||
let v' = Vec.make0 n in
|
||||
List.iter (fun {index ; value} -> v'.{index} <- value) v;
|
||||
Array.iter (fun {index ; value} -> v'.{index} <- value) v;
|
||||
Dense v'
|
||||
|
||||
|
||||
@ -92,6 +92,7 @@ let sparse_of_vec ?(threshold=epsilon) v =
|
||||
let sparse_of_assoc_list n v =
|
||||
Sparse { n ;
|
||||
v = List.map (fun (index, value) -> {index ; value}) v
|
||||
|> Array.of_list
|
||||
}
|
||||
|
||||
|
||||
@ -104,19 +105,13 @@ let rec to_vec = function
|
||||
let scale ?(threshold=epsilon) x = function
|
||||
| Dense v -> let v' = copy v in (scal x v'; Dense v')
|
||||
| Sparse {n ; v} ->
|
||||
Sparse {n ; v=List.map (fun {index ; value} ->
|
||||
let z = x *. value in
|
||||
if abs_float z > threshold then
|
||||
Some {index ; value=z}
|
||||
else
|
||||
None
|
||||
) v |> Util.list_some }
|
||||
Sparse {n ; v = Array.map (fun {index ; value} -> { index ; value=x *. value} ) v }
|
||||
|
||||
|
||||
let rec neg = function
|
||||
| Dense v -> Dense (Vec.neg v)
|
||||
| Sparse {n ; v} ->
|
||||
Sparse {n ; v=List.map (fun {index ; value} -> {index ; value = -. value}) v}
|
||||
Sparse {n ; v = Array.map (fun {index ; value} -> {index ; value = -. value}) v}
|
||||
|
||||
|
||||
|
||||
@ -129,22 +124,25 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
||||
| Sparse {n ; v}, Dense y ->
|
||||
begin
|
||||
let v' = copy y in
|
||||
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
|
||||
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
|
||||
sparse_of_vec ~threshold v'
|
||||
end
|
||||
| Dense x , Sparse {n ; v} ->
|
||||
begin
|
||||
let v' = copy x in
|
||||
scal alpha v';
|
||||
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
|
||||
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
|
||||
sparse_of_vec ~threshold v'
|
||||
end
|
||||
| Sparse {n ; v}, Sparse {n=n' ; v=v'} ->
|
||||
begin
|
||||
let rec aux accu v1 v2 =
|
||||
match v1, v2 with
|
||||
| ({index=i ; value=x}::r1), ({index=j ; value=y}::r2) ->
|
||||
let rec aux accu k l =
|
||||
match k < Array.length v, l < Array.length v' with
|
||||
| true, true ->
|
||||
begin
|
||||
let {index=i ; value=x} = v.(k)
|
||||
and {index=j ; value=y} = v'.(l)
|
||||
in
|
||||
match compare i j with
|
||||
| -1 ->
|
||||
let z = alpha *. x in
|
||||
@ -153,14 +151,14 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
||||
{index=i ; value=z} :: accu
|
||||
else
|
||||
accu
|
||||
in (aux [@tailcall]) new_accu r1 v2
|
||||
in (aux [@tailcall]) new_accu (k+1) l
|
||||
| 1 ->
|
||||
let new_accu =
|
||||
if abs_float y > threshold then
|
||||
{index=j ; value=y} :: accu
|
||||
else
|
||||
accu
|
||||
in (aux [@tailcall]) new_accu v1 r2
|
||||
in (aux [@tailcall]) new_accu k (l+1)
|
||||
| 0 ->
|
||||
let z = alpha *. x +. y in
|
||||
let new_accu =
|
||||
@ -168,14 +166,20 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
||||
{index=i ; value=z} :: accu
|
||||
else
|
||||
accu
|
||||
in (aux [@tailcall]) new_accu r1 r2
|
||||
in (aux [@tailcall]) new_accu (k+1) (l+1)
|
||||
| _ -> assert false
|
||||
end
|
||||
| ({index=i ; value=x}::r1), [] -> (aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) r1 []
|
||||
| [] , ({index=j ; value=y}::r2) -> (aux [@tailcall]) ({index=j ; value=y}::accu) [] r2
|
||||
| [] , [] -> {n ; v=List.rev accu}
|
||||
|
||||
| true, false ->
|
||||
let {index=i ; value=x} = v.(k) in
|
||||
(aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) (k+1) l
|
||||
|
||||
| false, true ->
|
||||
(aux [@tailcall]) (v'.(l)::accu) k (l+1)
|
||||
|
||||
| false, false -> {n ; v=List.rev accu |> Array.of_list}
|
||||
in
|
||||
Sparse (aux [] v v')
|
||||
Sparse (aux [] 0 0)
|
||||
end
|
||||
|
||||
let add = axpy ~alpha:1.
|
||||
@ -187,7 +191,7 @@ let pp_vector ppf = function
|
||||
| Sparse {n ; v} ->
|
||||
begin
|
||||
Format.fprintf ppf "@[[ %d | " n;
|
||||
List.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
|
||||
Array.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
|
||||
Format.fprintf ppf "]@]"
|
||||
end
|
||||
|
||||
@ -202,26 +206,29 @@ let dot v v' =
|
||||
let d_sp v' {n ; v} =
|
||||
if n <> Vec.dim v' then
|
||||
invalid_arg "Inconsistent dimensions";
|
||||
List.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
|
||||
Array.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
|
||||
in
|
||||
|
||||
let sp_sp {n ; v} {n=n' ; v=v'} =
|
||||
if n <> n' then
|
||||
invalid_arg "Inconsistent dimensions";
|
||||
|
||||
let rec aux accu = function
|
||||
| (({index=i ; value=v1} :: r1) as s1), (({index=j ; value=v2}::r2) as s2)->
|
||||
let rec aux accu k l =
|
||||
match Array.length v > k, Array.length v' > l with
|
||||
| true, true ->
|
||||
let {index=i ; value=x} = v.(k)
|
||||
and {index=j ; value=y} = v'.(l)
|
||||
in
|
||||
begin
|
||||
match compare i j with
|
||||
| -1 -> (aux [@tailcall]) accu (r1, s2)
|
||||
| 1 -> (aux [@tailcall]) accu (s1, r2)
|
||||
| 0 -> (aux [@tailcall]) (accu +. v1 *. v2) (r1, r2)
|
||||
| -1 -> (aux [@tailcall]) accu (k+1) l
|
||||
| 1 -> (aux [@tailcall]) accu k (l+1)
|
||||
| 0 -> (aux [@tailcall]) (accu +. x *. y) (k+1) (l+1)
|
||||
| _ -> assert false
|
||||
end
|
||||
| ([], _ )
|
||||
| (_ , []) -> accu
|
||||
| _ -> accu
|
||||
in
|
||||
aux 0. (v, v')
|
||||
aux 0. 0 0
|
||||
in
|
||||
|
||||
match v, v' with
|
||||
|
Loading…
Reference in New Issue
Block a user