mirror of
https://gitlab.com/scemama/QCaml.git
synced 2024-10-04 23:36:08 +02:00
Changed data structure of sparse_vector
This commit is contained in:
parent
0462e19117
commit
a1349ea69e
@ -270,6 +270,12 @@ let stream_fold f init stream =
|
|||||||
in
|
in
|
||||||
aux init
|
aux init
|
||||||
|
|
||||||
|
(** {2 Array functions} *)
|
||||||
|
|
||||||
|
let array_range first last =
|
||||||
|
if last < first then [| |] else
|
||||||
|
Array.init (last-first+1) (fun i -> i+first)
|
||||||
|
|
||||||
|
|
||||||
(** {2 Linear algebra} *)
|
(** {2 Linear algebra} *)
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ type index_value =
|
|||||||
type sparse_vector =
|
type sparse_vector =
|
||||||
{
|
{
|
||||||
n: int;
|
n: int;
|
||||||
v: index_value list
|
v: index_value array
|
||||||
}
|
}
|
||||||
|
|
||||||
type t =
|
type t =
|
||||||
@ -36,7 +36,7 @@ let get = function
|
|||||||
| Sparse { n ; v } -> (fun i ->
|
| Sparse { n ; v } -> (fun i ->
|
||||||
if i < 1 || i > n then invalid_arg "index out of bounds";
|
if i < 1 || i > n then invalid_arg "index out of bounds";
|
||||||
try
|
try
|
||||||
List.iter (fun {index ; value} ->
|
Array.iter (fun {index ; value} ->
|
||||||
if index=i then
|
if index=i then
|
||||||
raise (Found value)) v;
|
raise (Found value)) v;
|
||||||
raise Not_found
|
raise Not_found
|
||||||
@ -56,7 +56,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
|
|||||||
| Sparse _ -> invalid_arg "Expected a dense vector"
|
| Sparse _ -> invalid_arg "Expected a dense vector"
|
||||||
| Dense v ->
|
| Dense v ->
|
||||||
let rec aux accu = function
|
let rec aux accu = function
|
||||||
| 0 -> accu
|
| 0 -> accu |> Array.of_list
|
||||||
| i ->
|
| i ->
|
||||||
let x = v.{i} in
|
let x = v.{i} in
|
||||||
if abs_float x < threshold then
|
if abs_float x < threshold then
|
||||||
@ -69,7 +69,7 @@ let sparse_of_dense ?(threshold=epsilon) = function
|
|||||||
|
|
||||||
|
|
||||||
let rec to_assoc_list ?(threshold=epsilon) = function
|
let rec to_assoc_list ?(threshold=epsilon) = function
|
||||||
| Sparse {n ; v} -> List.map (fun {index ; value} -> (index, value)) v
|
| Sparse {n ; v} -> Array.map (fun {index ; value} -> (index, value)) v |> Array.to_list
|
||||||
| Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v)
|
| Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v)
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ let dense_of_sparse = function
|
|||||||
| Dense _ -> invalid_arg "Expected a sparse vector"
|
| Dense _ -> invalid_arg "Expected a sparse vector"
|
||||||
| Sparse {n ; v} ->
|
| Sparse {n ; v} ->
|
||||||
let v' = Vec.make0 n in
|
let v' = Vec.make0 n in
|
||||||
List.iter (fun {index ; value} -> v'.{index} <- value) v;
|
Array.iter (fun {index ; value} -> v'.{index} <- value) v;
|
||||||
Dense v'
|
Dense v'
|
||||||
|
|
||||||
|
|
||||||
@ -92,6 +92,7 @@ let sparse_of_vec ?(threshold=epsilon) v =
|
|||||||
let sparse_of_assoc_list n v =
|
let sparse_of_assoc_list n v =
|
||||||
Sparse { n ;
|
Sparse { n ;
|
||||||
v = List.map (fun (index, value) -> {index ; value}) v
|
v = List.map (fun (index, value) -> {index ; value}) v
|
||||||
|
|> Array.of_list
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -104,19 +105,13 @@ let rec to_vec = function
|
|||||||
let scale ?(threshold=epsilon) x = function
|
let scale ?(threshold=epsilon) x = function
|
||||||
| Dense v -> let v' = copy v in (scal x v'; Dense v')
|
| Dense v -> let v' = copy v in (scal x v'; Dense v')
|
||||||
| Sparse {n ; v} ->
|
| Sparse {n ; v} ->
|
||||||
Sparse {n ; v=List.map (fun {index ; value} ->
|
Sparse {n ; v = Array.map (fun {index ; value} -> { index ; value=x *. value} ) v }
|
||||||
let z = x *. value in
|
|
||||||
if abs_float z > threshold then
|
|
||||||
Some {index ; value=z}
|
|
||||||
else
|
|
||||||
None
|
|
||||||
) v |> Util.list_some }
|
|
||||||
|
|
||||||
|
|
||||||
let rec neg = function
|
let rec neg = function
|
||||||
| Dense v -> Dense (Vec.neg v)
|
| Dense v -> Dense (Vec.neg v)
|
||||||
| Sparse {n ; v} ->
|
| Sparse {n ; v} ->
|
||||||
Sparse {n ; v=List.map (fun {index ; value} -> {index ; value = -. value}) v}
|
Sparse {n ; v = Array.map (fun {index ; value} -> {index ; value = -. value}) v}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -129,22 +124,25 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
|||||||
| Sparse {n ; v}, Dense y ->
|
| Sparse {n ; v}, Dense y ->
|
||||||
begin
|
begin
|
||||||
let v' = copy y in
|
let v' = copy y in
|
||||||
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
|
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v;
|
||||||
sparse_of_vec ~threshold v'
|
sparse_of_vec ~threshold v'
|
||||||
end
|
end
|
||||||
| Dense x , Sparse {n ; v} ->
|
| Dense x , Sparse {n ; v} ->
|
||||||
begin
|
begin
|
||||||
let v' = copy x in
|
let v' = copy x in
|
||||||
scal alpha v';
|
scal alpha v';
|
||||||
List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
|
Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v;
|
||||||
sparse_of_vec ~threshold v'
|
sparse_of_vec ~threshold v'
|
||||||
end
|
end
|
||||||
| Sparse {n ; v}, Sparse {n=n' ; v=v'} ->
|
| Sparse {n ; v}, Sparse {n=n' ; v=v'} ->
|
||||||
begin
|
begin
|
||||||
let rec aux accu v1 v2 =
|
let rec aux accu k l =
|
||||||
match v1, v2 with
|
match k < Array.length v, l < Array.length v' with
|
||||||
| ({index=i ; value=x}::r1), ({index=j ; value=y}::r2) ->
|
| true, true ->
|
||||||
begin
|
begin
|
||||||
|
let {index=i ; value=x} = v.(k)
|
||||||
|
and {index=j ; value=y} = v'.(l)
|
||||||
|
in
|
||||||
match compare i j with
|
match compare i j with
|
||||||
| -1 ->
|
| -1 ->
|
||||||
let z = alpha *. x in
|
let z = alpha *. x in
|
||||||
@ -153,14 +151,14 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
|||||||
{index=i ; value=z} :: accu
|
{index=i ; value=z} :: accu
|
||||||
else
|
else
|
||||||
accu
|
accu
|
||||||
in (aux [@tailcall]) new_accu r1 v2
|
in (aux [@tailcall]) new_accu (k+1) l
|
||||||
| 1 ->
|
| 1 ->
|
||||||
let new_accu =
|
let new_accu =
|
||||||
if abs_float y > threshold then
|
if abs_float y > threshold then
|
||||||
{index=j ; value=y} :: accu
|
{index=j ; value=y} :: accu
|
||||||
else
|
else
|
||||||
accu
|
accu
|
||||||
in (aux [@tailcall]) new_accu v1 r2
|
in (aux [@tailcall]) new_accu k (l+1)
|
||||||
| 0 ->
|
| 0 ->
|
||||||
let z = alpha *. x +. y in
|
let z = alpha *. x +. y in
|
||||||
let new_accu =
|
let new_accu =
|
||||||
@ -168,14 +166,20 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
|||||||
{index=i ; value=z} :: accu
|
{index=i ; value=z} :: accu
|
||||||
else
|
else
|
||||||
accu
|
accu
|
||||||
in (aux [@tailcall]) new_accu r1 r2
|
in (aux [@tailcall]) new_accu (k+1) (l+1)
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
end
|
end
|
||||||
| ({index=i ; value=x}::r1), [] -> (aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) r1 []
|
|
||||||
| [] , ({index=j ; value=y}::r2) -> (aux [@tailcall]) ({index=j ; value=y}::accu) [] r2
|
| true, false ->
|
||||||
| [] , [] -> {n ; v=List.rev accu}
|
let {index=i ; value=x} = v.(k) in
|
||||||
|
(aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) (k+1) l
|
||||||
|
|
||||||
|
| false, true ->
|
||||||
|
(aux [@tailcall]) (v'.(l)::accu) k (l+1)
|
||||||
|
|
||||||
|
| false, false -> {n ; v=List.rev accu |> Array.of_list}
|
||||||
in
|
in
|
||||||
Sparse (aux [] v v')
|
Sparse (aux [] 0 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
let add = axpy ~alpha:1.
|
let add = axpy ~alpha:1.
|
||||||
@ -187,7 +191,7 @@ let pp_vector ppf = function
|
|||||||
| Sparse {n ; v} ->
|
| Sparse {n ; v} ->
|
||||||
begin
|
begin
|
||||||
Format.fprintf ppf "@[[ %d | " n;
|
Format.fprintf ppf "@[[ %d | " n;
|
||||||
List.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
|
Array.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v;
|
||||||
Format.fprintf ppf "]@]"
|
Format.fprintf ppf "]@]"
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -202,26 +206,29 @@ let dot v v' =
|
|||||||
let d_sp v' {n ; v} =
|
let d_sp v' {n ; v} =
|
||||||
if n <> Vec.dim v' then
|
if n <> Vec.dim v' then
|
||||||
invalid_arg "Inconsistent dimensions";
|
invalid_arg "Inconsistent dimensions";
|
||||||
List.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
|
Array.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v
|
||||||
in
|
in
|
||||||
|
|
||||||
let sp_sp {n ; v} {n=n' ; v=v'} =
|
let sp_sp {n ; v} {n=n' ; v=v'} =
|
||||||
if n <> n' then
|
if n <> n' then
|
||||||
invalid_arg "Inconsistent dimensions";
|
invalid_arg "Inconsistent dimensions";
|
||||||
|
|
||||||
let rec aux accu = function
|
let rec aux accu k l =
|
||||||
| (({index=i ; value=v1} :: r1) as s1), (({index=j ; value=v2}::r2) as s2)->
|
match Array.length v > k, Array.length v' > l with
|
||||||
|
| true, true ->
|
||||||
|
let {index=i ; value=x} = v.(k)
|
||||||
|
and {index=j ; value=y} = v'.(l)
|
||||||
|
in
|
||||||
begin
|
begin
|
||||||
match compare i j with
|
match compare i j with
|
||||||
| -1 -> (aux [@tailcall]) accu (r1, s2)
|
| -1 -> (aux [@tailcall]) accu (k+1) l
|
||||||
| 1 -> (aux [@tailcall]) accu (s1, r2)
|
| 1 -> (aux [@tailcall]) accu k (l+1)
|
||||||
| 0 -> (aux [@tailcall]) (accu +. v1 *. v2) (r1, r2)
|
| 0 -> (aux [@tailcall]) (accu +. x *. y) (k+1) (l+1)
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
end
|
end
|
||||||
| ([], _ )
|
| _ -> accu
|
||||||
| (_ , []) -> accu
|
|
||||||
in
|
in
|
||||||
aux 0. (v, v')
|
aux 0. 0 0
|
||||||
in
|
in
|
||||||
|
|
||||||
match v, v' with
|
match v, v' with
|
||||||
|
Loading…
Reference in New Issue
Block a user