open Lacaml.D

(** Default magnitude threshold used when dropping coefficients from sparse
    vectors (taken from the project-wide [Constants.epsilon]). *)
let epsilon = Constants.epsilon

(** Sparse storage: dimension [n] and a list [v] of [(index, value)] pairs
    with 1-based indices. The merging operations ([axpy], [add], [sub])
    assume [v] is sorted by increasing index; [sparse_of_dense] and
    [sparse_of_assoc_list] both guarantee this. *)
type sparse_vector = { n: int; v: (int*float) list }

(** A vector stored either densely (as a Lacaml vector) or sparsely. *)
type t =
  | Dense of Vec.t
  | Sparse of sparse_vector

let is_sparse = function
  | Sparse _ -> true
  | Dense _ -> false

let is_dense = function
  | Sparse _ -> false
  | Dense _ -> true

(** [get x] returns a function mapping a 1-based index to the coefficient
    of [x] at that index (absent sparse entries read as [0.]).
    @raise Invalid_argument on a sparse vector when the index is outside
    [1..n]. *)
let get = function
  | Dense v -> (fun i -> v.{i})
  | Sparse { n ; v } ->
    (fun i ->
       if i < 1 || i > n then invalid_arg "index out of bounds";
       match List.assoc_opt i v with
       | Some x -> x
       | None -> 0.)

(** Dimension of the vector. *)
let dim = function
  | Dense v -> Vec.dim v
  | Sparse { n ; _ } -> n

(** [sparse_of_dense ?threshold x] converts dense [x] to sparse form, keeping
    only coefficients with [abs_float x_i >= threshold].
    @raise Invalid_argument if [x] is already sparse. *)
let sparse_of_dense ?(threshold=epsilon) = function
  | Sparse _ -> invalid_arg "Expected a dense vector"
  | Dense v ->
    (* Walk from the last index down so the accumulator ends up sorted by
       increasing index without a final reversal. *)
    let rec aux accu = function
      | 0 -> accu
      | i ->
        let x = v.{i} in
        if abs_float x < threshold then aux accu (i-1)
        else aux ((i, x)::accu) (i-1)
    in
    let n = Vec.dim v in
    Sparse { n ; v = aux [] n }

(** Returns the [(index, value)] pairs of a vector. [threshold] is only
    used when converting a dense vector; a sparse vector's list is returned
    unchanged. *)
let rec to_assoc_list ?(threshold=epsilon) = function
  | Sparse { v ; _ } -> v
  | Dense v -> to_assoc_list @@ sparse_of_dense ~threshold (Dense v)

(** Expands a sparse vector into a dense one.
    @raise Invalid_argument if the argument is already dense. *)
let dense_of_sparse = function
  | Dense _ -> invalid_arg "Expected a sparse vector"
  | Sparse { n ; v } ->
    let v' = Vec.make0 n in
    List.iter (fun (i, x) -> v'.{i} <- x) v;
    Dense v'

let dense_of_vec v = Dense v

let sparse_of_vec ?(threshold=epsilon) v =
  dense_of_vec v |> sparse_of_dense ~threshold

(** Builds a sparse vector of dimension [n] from [(index, value)] pairs.
    The pairs are stably sorted by index because the merging operations
    ([axpy], [add], [sub]) rely on that order. Stability keeps the first
    binding of a repeated index first, as [List.assoc_opt] sees it. *)
let sparse_of_assoc_list n v =
  Sparse { n ; v = List.stable_sort (fun (i, _) (j, _) -> compare (i : int) j) v }

(** Returns the underlying dense Lacaml vector, expanding sparse input. *)
let rec to_vec = function
  | Dense v -> v
  | Sparse v -> dense_of_sparse (Sparse v) |> to_vec

(** [scale ?threshold x v] returns [x * v] without modifying [v]. For a
    sparse vector, products with magnitude not exceeding [threshold] are
    dropped. *)
let scale ?(threshold=epsilon) x = function
  | Dense v ->
    let v' = copy v in
    scal x v';
    Dense v'
  | Sparse { n ; v } ->
    let v =
      List.map (fun (i, y) ->
          let z = x *. y in
          if abs_float z > threshold then Some (i, z) else None) v
      |> Util.list_some
    in
    Sparse { n ; v }

(** Element-wise negation; preserves the storage kind. *)
let neg = function
  | Dense v -> Dense (Vec.neg v)
  | Sparse { n ; v } -> Sparse { n ; v = List.map (fun (i, y) -> (i, -. y)) v }

(** [axpy ?threshold ?alpha x y] returns [alpha * x + y] without modifying
    its arguments. Dense/dense input yields a dense result. Any sparse
    operand yields a sparse result with near-zero coefficients (per
    [threshold]) removed.
    @raise Invalid_argument if the dimensions differ. *)
let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
  if dim x <> dim y then invalid_arg "Inconsistent dimensions";
  match x, y with
  | Dense x, Dense y ->
    (* [axpy] here is Lacaml's in-place BLAS routine (this one is not rec). *)
    let y = copy y in
    axpy ~alpha x y;
    Dense y
  | Sparse { v ; _ }, Dense y ->
    let v' = copy y in
    List.iter (fun (i, x) -> v'.{i} <- v'.{i} +. alpha *. x) v;
    sparse_of_vec ~threshold v'
  | Dense x, Sparse { v ; _ } ->
    let v' = copy x in
    scal alpha v';
    List.iter (fun (i, y) -> v'.{i} <- v'.{i} +. y) v;
    sparse_of_vec ~threshold v'
  | Sparse { n ; v }, Sparse { v = v' ; _ } ->
    (* Prepend [(i, z)] only if it survives the threshold. *)
    let push i z accu =
      if abs_float z > threshold then (i, z) :: accu else accu
    in
    (* Merge of two index-sorted lists. Every entry coming from [x] is
       scaled by [alpha], including those left over once [y] runs out. *)
    let rec aux accu v1 v2 =
      match v1, v2 with
      | [], [] -> List.rev accu
      | (i, x) :: v1', [] -> aux (push i (alpha *. x) accu) v1' []
      | [], (j, y) :: v2' -> aux (push j y accu) [] v2'
      | (i, x) :: v1', (j, y) :: v2' ->
        if i = j then aux (push i (alpha *. x +. y) accu) v1' v2'
        else if i < j then aux (push i (alpha *. x) accu) v1' v2
        else aux (push j y accu) v1 v2'
    in
    Sparse { n ; v = aux [] v v' }

(** [add ?threshold x y] is [axpy ~alpha:1. ?threshold x y]. *)
let add = axpy ~alpha:1.

(** [sub ?threshold x y] returns [x - y]. *)
let sub ?(threshold=epsilon) x y = add ~threshold x @@ neg y

(** Pretty-printer: dense vectors as float arrays, sparse vectors as
    [[ n | (i, x); ... ]]. *)
let pp_vector ppf = function
  | Dense m -> Util.pp_float_array ppf @@ Vec.to_array m
  | Sparse { n ; v } ->
    Format.fprintf ppf "@[[ %d | " n;
    List.iter (fun (i, x) -> Format.fprintf ppf "@[(%d, %f); @]" i x) v;
    Format.fprintf ppf "]@]"

(** Dot product of two vectors of any storage kind.
    @raise Invalid_argument if a sparse operand's dimension differs from
    the other operand's. *)
let dot v v' =
  (* Lacaml's [dot]: this definition is not recursive. *)
  let d_d v v' = dot v v' in
  let d_sp v' { n ; v } =
    if n <> Vec.dim v' then invalid_arg "Inconsistent dimensions";
    List.fold_left (fun accu (i, v_i) -> accu +. v_i *. v'.{i}) 0. v
  in
  let sp_sp { n ; v } { n = n' ; v = v' } =
    if n <> n' then invalid_arg "Inconsistent dimensions";
    (* Index [v'] once instead of a [List.assoc_opt] scan per entry of [v]
       (O(|v| + |v'|) instead of O(|v| * |v'|)). Only the first binding of
       a repeated index is kept, matching [List.assoc_opt]. *)
    let lookup = Hashtbl.create (List.length v') in
    List.iter
      (fun (i, w_i) -> if not (Hashtbl.mem lookup i) then Hashtbl.add lookup i w_i)
      v';
    List.fold_left (fun accu (i, v_i) ->
        match Hashtbl.find_opt lookup i with
        | Some w_i -> accu +. v_i *. w_i
        | None -> accu) 0. v
  in
  match v, v' with
  | Dense v, Dense v' -> d_d v v'
  | Sparse v, Sparse v' -> sp_sp v v'
  | Dense v, Sparse v' -> d_sp v v'
  | Sparse v, Dense v' -> d_sp v' v

(** Alcotest cases comparing sparse and dense results on random vectors. *)
let test_case () =
  let x1 = Vec.map (fun x -> if abs_float x < 0.6 then 0. else x) (Vec.random 100)
  and x2 = Vec.map (fun x -> if abs_float x < 0.3 then 0. else x) (Vec.random 100)
  in
  let x3 = Vec.map (fun x -> 2. *. x) x1
  and x4 = Vec.add x1 x2
  and x5 = Vec.sub x1 x2
  and x6 =
    (let v = copy x2 in
     Lacaml.D.axpy ~alpha:3. x1 v;
     v)
  in
  let v1 = dense_of_vec x1
  and v2 = dense_of_vec x2
  and v3 = dense_of_vec x3
  and v4 = dense_of_vec x4
  and v5 = dense_of_vec x5
  and v6 = dense_of_vec x6
  in
  let v1_s = sparse_of_vec x1
  and v2_s = sparse_of_vec x2
  and v3_s = sparse_of_vec x3
  and v4_s = sparse_of_vec x4
  and v5_s = sparse_of_vec x5
  and v6_s = sparse_of_vec x6
  in
  let zero = dense_of_vec (Vec.make0 100)
  and zero_s = sparse_of_vec (Vec.make0 100)
  in
  let test_conversion () =
    Alcotest.(check bool) "sparse -> dense 1" true (dense_of_sparse v1_s = v1 );
    Alcotest.(check bool) "sparse -> dense 2" true (dense_of_sparse v2_s = v2 );
    Alcotest.(check bool) "dense -> sparse 1" true (sparse_of_dense v1 = v1_s);
    Alcotest.(check bool) "dense -> sparse 2" true (sparse_of_dense v2 = v2_s)
  in
  let test_operations () =
    Alcotest.(check bool) "dense scale" true (scale 2. v1 = v3);
    Alcotest.(check bool) "sparse scale" true (scale 2. v1_s = v3_s);

    Alcotest.(check bool) "dense dense add" true (add v1 v2 = v4);
    Alcotest.(check bool) "dense sparse add" true (add v1 v2_s = v4_s);
    Alcotest.(check bool) "sparse dense add" true (add v1_s v2 = v4_s);
    (* Commuted operands: x2 + x1 is bitwise equal to x1 + x2. *)
    Alcotest.(check bool) "sparse dense add 2" true (add v2_s v1 = v4_s);
    Alcotest.(check bool) "sparse sparse add" true (add v1_s v2_s = v4_s);

    Alcotest.(check bool) "dense dense sub" true (sub v1 v2 = v5);
    Alcotest.(check bool) "dense sparse sub" true (sub v1 v2_s = v5_s);
    Alcotest.(check bool) "sparse dense sub" true (sub v1_s v2 = v5_s);
    (* x2 - x1 is exactly -(x1 - x2) in IEEE arithmetic. *)
    Alcotest.(check bool) "sparse dense sub 2" true (sub v2_s v1 = neg v5_s);
    Alcotest.(check bool) "sparse sparse sub" true (sub v1_s v2_s = v5_s);

    Alcotest.(check bool) "dense dense sub" true (sub v1 v1 = zero);
    Alcotest.(check bool) "dense sparse sub" true (sub v1 v1_s = zero_s);
    Alcotest.(check bool) "sparse dense sub" true (sub v1_s v1 = zero_s);
    Alcotest.(check bool) "sparse sparse sub" true (sub v1_s v1_s = zero_s);

    Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
    Alcotest.(check bool) "dense sparse axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
    Alcotest.(check bool) "sparse dense axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
    Alcotest.(check bool) "sparse sparse axpy" true
      (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s);

    (* Regression: entries of [x] beyond the last entry of [y] must still
       be scaled by alpha. *)
    let a = sparse_of_assoc_list 3 [ (3, 1.) ]
    and b = sparse_of_assoc_list 3 [ (1, 1.) ] in
    Alcotest.(check bool) "sparse sparse axpy tail" true
      (axpy ~alpha:2. a b = sparse_of_assoc_list 3 [ (1, 1.); (3, 2.) ]);
    Alcotest.(check bool) "sparse sparse axpy tail zero" true
      (axpy ~alpha:0. a b = sparse_of_assoc_list 3 [ (1, 1.) ])
  in
  let test_dot () =
    let d1d2 = Lacaml.D.dot x1 x2
    and d1d1 = Lacaml.D.dot x1 x1
    and d2d2 = Lacaml.D.dot x2 x2
    in
    Alcotest.(check (float 1.e-10)) "sparse x dense 1" (dot v1_s v2 ) d1d2;
    Alcotest.(check (float 1.e-10)) "sparse x dense 2" (dot v1_s v1 ) d1d1;
    Alcotest.(check (float 1.e-10)) "sparse x dense 3" (dot v2_s v2 ) d2d2;
    Alcotest.(check (float 1.e-10)) "dense x sparse 1" (dot v1 v2_s) d1d2;
    Alcotest.(check (float 1.e-10)) "dense x sparse 2" (dot v1 v1_s) d1d1;
    Alcotest.(check (float 1.e-10)) "dense x sparse 3" (dot v2 v2_s) d2d2;
    Alcotest.(check (float 1.e-10)) "sparse x sparse 1" (dot v1_s v2_s) d1d2;
    Alcotest.(check (float 1.e-10)) "sparse x sparse 2" (dot v1_s v1_s) d1d1;
    Alcotest.(check (float 1.e-10)) "sparse x sparse 3" (dot v2_s v2_s) d2d2
  in
  [
    "Conversion", `Quick, test_conversion;
    "Operations", `Quick, test_operations;
    "Dot product", `Quick, test_dot;
  ]