From a1349ea69e3a30ef8048b1cc942a85c81b906738 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Thu, 24 Oct 2019 23:08:24 +0200 Subject: [PATCH] Changed data structure of sparse_vector --- Utils/Util.ml | 6 ++++ Utils/Vector.ml | 77 +++++++++++++++++++++++++++---------------------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/Utils/Util.ml b/Utils/Util.ml index 58e00e9..03203d7 100644 --- a/Utils/Util.ml +++ b/Utils/Util.ml @@ -270,6 +270,12 @@ let stream_fold f init stream = in aux init +(** {2 Array functions} *) + +let array_range first last = + if last < first then [| |] else + Array.init (last-first+1) (fun i -> i+first) + (** {2 Linear algebra} *) diff --git a/Utils/Vector.ml b/Utils/Vector.ml index 6833397..b881ed8 100644 --- a/Utils/Vector.ml +++ b/Utils/Vector.ml @@ -11,7 +11,7 @@ type index_value = type sparse_vector = { n: int; - v: index_value list + v: index_value array } type t = @@ -36,7 +36,7 @@ let get = function | Sparse { n ; v } -> (fun i -> if i < 1 || i > n then invalid_arg "index out of bounds"; try - List.iter (fun {index ; value} -> + Array.iter (fun {index ; value} -> if index=i then raise (Found value)) v; raise Not_found @@ -56,7 +56,7 @@ let sparse_of_dense ?(threshold=epsilon) = function | Sparse _ -> invalid_arg "Expected a dense vector" | Dense v -> let rec aux accu = function - | 0 -> accu + | 0 -> accu |> Array.of_list | i -> let x = v.{i} in if abs_float x < threshold then @@ -69,7 +69,7 @@ let sparse_of_dense ?(threshold=epsilon) = function let rec to_assoc_list ?(threshold=epsilon) = function - | Sparse {n ; v} -> List.map (fun {index ; value} -> (index, value)) v + | Sparse {n ; v} -> Array.map (fun {index ; value} -> (index, value)) v |> Array.to_list | Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v) @@ -77,7 +77,7 @@ let dense_of_sparse = function | Dense _ -> invalid_arg "Expected a sparse vector" | Sparse {n ; v} -> let v' = Vec.make0 n in - List.iter (fun {index ; value} -> v'.{index} <- value) v; + Array.iter (fun {index ; value} -> v'.{index} <- value) v; Dense v' @@ -92,6 +92,7 @@ let sparse_of_vec ?(threshold=epsilon) v = let sparse_of_assoc_list n v = Sparse { n ; v = List.map (fun (index, value) -> {index ; value}) v + |> Array.of_list } @@ -104,19 +105,13 @@ let rec to_vec = function let scale ?(threshold=epsilon) x = function | Dense v -> let v' = copy v in (scal x v'; Dense v') | Sparse {n ; v} -> - Sparse {n ; v=List.map (fun {index ; value} -> - let z = x *. value in - if abs_float z > threshold then - Some {index ; value=z} - else - None - ) v |> Util.list_some } + Sparse {n ; v = Array.map (fun {index ; value} -> { index ; value=x *. value} ) v } let rec neg = function | Dense v -> Dense (Vec.neg v) | Sparse {n ; v} -> - Sparse {n ; v=List.map (fun {index ; value} -> {index ; value = -. value}) v} + Sparse {n ; v = Array.map (fun {index ; value} -> {index ; value = -. value}) v} @@ -129,22 +124,25 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y = | Sparse {n ; v}, Dense y -> begin let v' = copy y in - List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v; + Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. alpha *. value) v; sparse_of_vec ~threshold v' end | Dense x , Sparse {n ; v} -> begin let v' = copy x in scal alpha v'; - List.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v; + Array.iter (fun {index ; value} -> v'.{index} <- v'.{index} +. value) v; sparse_of_vec ~threshold v' end | Sparse {n ; v}, Sparse {n=n' ; v=v'} -> begin - let rec aux accu v1 v2 = - match v1, v2 with - | ({index=i ; value=x}::r1), ({index=j ; value=y}::r2) -> + let rec aux accu k l = + match k < Array.length v, l < Array.length v' with + | true, true -> begin + let {index=i ; value=x} = v.(k) + and {index=j ; value=y} = v'.(l) + in match compare i j with | -1 -> let z = alpha *. x in @@ -153,14 +151,14 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y = {index=i ; value=z} :: accu else accu - in (aux [@tailcall]) new_accu r1 v2 + in (aux [@tailcall]) new_accu (k+1) l | 1 -> let new_accu = if abs_float y > threshold then {index=j ; value=y} :: accu else accu - in (aux [@tailcall]) new_accu v1 r2 + in (aux [@tailcall]) new_accu k (l+1) | 0 -> let z = alpha *. x +. y in let new_accu = @@ -168,14 +166,20 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y = {index=i ; value=z} :: accu else accu - in (aux [@tailcall]) new_accu r1 r2 + in (aux [@tailcall]) new_accu (k+1) (l+1) | _ -> assert false end - | ({index=i ; value=x}::r1), [] -> (aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) r1 [] - | [] , ({index=j ; value=y}::r2) -> (aux [@tailcall]) ({index=j ; value=y}::accu) [] r2 - | [] , [] -> {n ; v=List.rev accu} + + | true, false -> + let {index=i ; value=x} = v.(k) in + (aux [@tailcall]) ({index=i ; value=alpha *. x}::accu) (k+1) l + + | false, true -> + (aux [@tailcall]) (v'.(l)::accu) k (l+1) + + | false, false -> {n ; v=List.rev accu |> Array.of_list} in - Sparse (aux [] v v') + Sparse (aux [] 0 0) end let add = axpy ~alpha:1. @@ -187,7 +191,7 @@ let pp_vector ppf = function | Sparse {n ; v} -> begin Format.fprintf ppf "@[[ %d | " n; - List.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v; + Array.iter (fun {index ; value} -> Format.fprintf ppf "@[(%d, %f); @]" index value) v; Format.fprintf ppf "]@]" end @@ -202,26 +206,29 @@ let dot v v' = let d_sp v' {n ; v} = if n <> Vec.dim v' then invalid_arg "Inconsistent dimensions"; - List.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v + Array.fold_left (fun accu {index ; value} -> accu +. value *. v'.{index}) 0. v in let sp_sp {n ; v} {n=n' ; v=v'} = if n <> n' then invalid_arg "Inconsistent dimensions"; - let rec aux accu = function - | (({index=i ; value=v1} :: r1) as s1), (({index=j ; value=v2}::r2) as s2)-> + let rec aux accu k l = + match Array.length v > k, Array.length v' > l with + | true, true -> + let {index=i ; value=x} = v.(k) + and {index=j ; value=y} = v'.(l) + in begin match compare i j with - | -1 -> (aux [@tailcall]) accu (r1, s2) - | 1 -> (aux [@tailcall]) accu (s1, r2) - | 0 -> (aux [@tailcall]) (accu +. v1 *. v2) (r1, r2) + | -1 -> (aux [@tailcall]) accu (k+1) l + | 1 -> (aux [@tailcall]) accu k (l+1) + | 0 -> (aux [@tailcall]) (accu +. x *. y) (k+1) (l+1) | _ -> assert false end - | ([], _ ) - | (_ , []) -> accu + | _ -> accu in - aux 0. (v, v') + aux 0. 0 0 in match v, v' with