mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 01:55:40 +01:00
Added conjugate-gradient
This commit is contained in:
parent
8e7107cfba
commit
7633f01746
@ -21,6 +21,9 @@ default: $(ALL_EXE) doc
|
||||
|
||||
tests: run_tests.native
|
||||
|
||||
bytelib:
|
||||
ocamlmklib -o bytelib ./_build/Utils/Util.cmo
|
||||
|
||||
QCaml.odocl: $(MLIFILES)
|
||||
ls $(MLIFILES) | sed "s/\.mli//" > QCaml.odocl
|
||||
|
||||
|
137
Utils/Matrix.ml
137
Utils/Matrix.ml
@ -60,11 +60,17 @@ let dense_of_sparse = function
|
||||
|
||||
let dense_of_mat m = Dense m
|
||||
|
||||
|
||||
let rec to_vector_array ?(threshold=epsilon) = function
|
||||
| Sparse {m ; n ; v} -> v
|
||||
| Dense m -> to_vector_array (sparse_of_dense ~threshold (Dense m))
|
||||
|
||||
|
||||
let identity n =
|
||||
Sparse { n ; m=n ;
|
||||
v = Array.init n (fun i -> Vector.sparse_of_assoc_list n [(i+1,1.0)])
|
||||
}
|
||||
|
||||
let sparse_of_mat ?(threshold=epsilon) m =
|
||||
dense_of_mat m
|
||||
|> sparse_of_dense ~threshold
|
||||
@ -307,18 +313,87 @@ let mv ?(sparse=false) ?(trans=`N) ?(threshold=epsilon) a b =
|
||||
else
|
||||
Vector.dense_of_vec dense_result
|
||||
|
||||
let iterative_ax_eq_b ~trans a b =
|
||||
failwith "Not implemented"
|
||||
|
||||
let rec op2 dense_op sparse_op a b =
|
||||
if dim1 a <> dim1 b || dim2 a <> dim2 b then
|
||||
failwith "Inconsistent dimensions";
|
||||
|
||||
match a, b with
|
||||
| (Dense a), (Dense b) -> Dense (dense_op a b)
|
||||
| (Dense _), (Sparse _) -> op2 dense_op sparse_op (sparse_of_dense a) b
|
||||
| (Sparse _), (Dense _) -> op2 dense_op sparse_op a (sparse_of_dense b)
|
||||
| (Sparse a), (Sparse b) -> Sparse
|
||||
{ m=a.m ; n=a.n ;
|
||||
v = Array.map2 sparse_op a.v b.v
|
||||
}
|
||||
|
||||
let add = op2 (fun a b -> Mat.add a b) (fun a b -> Vector.add a b)
|
||||
let sub = op2 (fun a b -> Mat.sub a b) (fun a b -> Vector.sub a b)
|
||||
|
||||
let scale f = function
|
||||
| Dense a -> let b = lacpy a in (Mat.scal f b ; Dense b)
|
||||
| Sparse a -> Sparse
|
||||
{ a with
|
||||
v = if f = 1.0 then a.v
|
||||
else Array.map (fun v -> Vector.scale f v) a.v }
|
||||
|
||||
let frobenius_norm = function
|
||||
| Dense a -> lange ~norm:`F a
|
||||
| Sparse a ->
|
||||
Array.fold_left (fun accu v -> accu +. Vector.dot v v) 0. a.v
|
||||
|> sqrt
|
||||
|
||||
|
||||
|
||||
let ax_eq_b_conj_grad ?x a b =
|
||||
(* /!\ : A needs to be positive definite and symmetric *)
|
||||
let x =
|
||||
match x with
|
||||
| Some x0 -> x0
|
||||
| None -> b
|
||||
in
|
||||
let r = Vector.sub b (mv a x) in
|
||||
let p = r in
|
||||
let rsold = Vector.dot r r in
|
||||
let rec aux rsold r p x = function
|
||||
| 0 -> x
|
||||
| i ->
|
||||
let ap = mv a p in
|
||||
let alpha = rsold /. (Vector.dot p ap) in
|
||||
let x = Vector.add x (Vector.scale alpha p) in
|
||||
let r = Vector.sub r (Vector.scale alpha ap) in
|
||||
let rsnew = Vector.dot r r in
|
||||
if rsnew < Constants.epsilon then
|
||||
x
|
||||
else
|
||||
let p =
|
||||
Vector.add r (Vector.scale (rsnew /. (rsold +. 1.e-12) ) p)
|
||||
in
|
||||
aux rsnew r p x (i-1)
|
||||
in
|
||||
aux rsold r p x (Vector.dim b *2)
|
||||
|
||||
|
||||
|
||||
let rec ax_eq_b ?(trans=`N) a b =
|
||||
match a, b with
|
||||
| (Dense a), (Dense b) ->
|
||||
let x = lacpy a in
|
||||
(getrs ~trans x b; Dense x)
|
||||
let a = lacpy a in
|
||||
let x = lacpy b in
|
||||
(getrs ~trans a x; Dense x)
|
||||
| (Dense _), (Sparse _) ->
|
||||
let b = dense_of_sparse b in
|
||||
ax_eq_b ~trans a b
|
||||
| _ -> iterative_ax_eq_b ~trans a b
|
||||
| _ ->
|
||||
let ata, atb =
|
||||
if trans = `N then
|
||||
mm ~transa:`T a a, mm ~transa:`T a b
|
||||
else
|
||||
mm ~transa:`N a a, mm ~transa:`N a b
|
||||
in
|
||||
Sparse { m=dim1 b ; n=dim2 b ;
|
||||
v=Array.map (fun v -> ax_eq_b_conj_grad ata v) (to_vector_array atb)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -409,6 +484,25 @@ let test_case () =
|
||||
Alcotest.(check (float 1.e-10)) "sparse sparse" 0. (norm_diff m1_s (outer_product v1_s v2_s));
|
||||
in
|
||||
|
||||
let test_add_sub () =
|
||||
let x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random d1 d2) in
|
||||
let m2 = dense_of_mat x2 in
|
||||
let m3 = Mat.add x1 x2 |> dense_of_mat in
|
||||
let m4 = Mat.sub x1 x2 |> dense_of_mat in
|
||||
let m2_s = sparse_of_mat x2 in
|
||||
let m3_s = Mat.add x1 x2 |> sparse_of_mat in
|
||||
let m4_s = Mat.sub x1 x2 |> sparse_of_mat in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (add m1 m2) m3);
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (sub m1 m2) m4);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 3" 0. (norm_diff (add m1 m2_s) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 4" 0. (norm_diff (sub m1 m2_s) m4_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse dense 5" 0. (norm_diff (add m1_s m2) m3);
|
||||
Alcotest.(check (float 1.e-10)) "sparse dense 6" 0. (norm_diff (sub m1_s m2) m4);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 7" 0. (norm_diff (add m1_s m2_s) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 8" 0. (norm_diff (sub m1_s m2_s) m4_s);
|
||||
Alcotest.(check (float 1.e-10)) "dense sparse 9" (frobenius_norm m1_s) (frobenius_norm m1);
|
||||
in
|
||||
|
||||
let test_mv () =
|
||||
let y = Vec.random d2 in
|
||||
let z = Vec.random d1 in
|
||||
@ -466,12 +560,45 @@ let test_case () =
|
||||
Alcotest.(check (float 1.e-10)) "sparse sparse 15" 0. (norm_diff (mm ~transb:`T m1_s m5_s) m3_s);
|
||||
Alcotest.(check (float 1.e-10)) "sparse sparse 16" 0. (norm_diff (transpose (mm m2_s m1_s ~transa:`T ~transb:`T)) m3_s);
|
||||
in
|
||||
|
||||
let test_solve () =
|
||||
let x1 = Mat.map (fun x -> if abs_float x < 0.6 then 0. else x) (Mat.random 30 30)
|
||||
and x2 = Mat.map (fun x -> if abs_float x < 0.3 then 0. else x) (Mat.random 30 5)
|
||||
in
|
||||
|
||||
let m1 = dense_of_mat x1
|
||||
and m2 = dense_of_mat x2
|
||||
in
|
||||
|
||||
let m1_s = sparse_of_mat x1
|
||||
and m2_s = sparse_of_mat x2
|
||||
in
|
||||
|
||||
let a = m1 and b = m2 in
|
||||
let x = ax_eq_b a b in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 1" 0. (norm_diff (mm a x) b);
|
||||
|
||||
let a = m1 and b = m2_s in
|
||||
let x = ax_eq_b a b in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (mm a x) b);
|
||||
|
||||
let a = m1_s and b = m2 in
|
||||
let x = ax_eq_b a b in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (mm a x) b);
|
||||
|
||||
let a = m1_s and b = m2_s in
|
||||
let x = ax_eq_b a b in
|
||||
Alcotest.(check (float 1.e-10)) "dense dense 2" 0. (norm_diff (mm a x) b);
|
||||
in
|
||||
|
||||
[
|
||||
"Conversion", `Quick, test_conversion;
|
||||
"Dimensions", `Quick, test_dimensions;
|
||||
"Transposition", `Quick, test_transpose;
|
||||
"Outer product", `Quick, test_outer;
|
||||
"Add sub", `Quick, test_add_sub;
|
||||
"Matrix Vector", `Quick, test_mv;
|
||||
"Matrix Matrix", `Quick, test_mm;
|
||||
"Linear solve", `Quick, test_solve;
|
||||
]
|
||||
|
||||
|
@ -15,7 +15,7 @@ let expand_range r =
|
||||
| i -> i::(do_work (i+1))
|
||||
in do_work start
|
||||
end
|
||||
| r :: [] -> int_of_string r
|
||||
| r :: [] -> [int_of_string r]
|
||||
| [] -> []
|
||||
| _ -> invalid_arg "Only one range expected"
|
||||
|
||||
@ -27,7 +27,7 @@ let of_string s =
|
||||
assert (s.[0] = '[') ;
|
||||
assert (s.[(String.length s)-1] = ']') ;
|
||||
let s = String.sub s 1 ((String.length s) - 2) in
|
||||
let l = String_ext.split ~on:',' s in
|
||||
let l = String.split_on_char ',' s in
|
||||
let l = List.map expand_range l in
|
||||
List.concat l
|
||||
|> List.sort_uniq compare
|
||||
@ -41,6 +41,6 @@ let to_string l =
|
||||
|
||||
|
||||
let pp_range ppf t =
|
||||
Format.fprintf "@[%s@]" ppf (to_string t)
|
||||
Format.fprintf ppf "@[%s@]" (to_string t)
|
||||
|
||||
|
||||
|
@ -385,7 +385,7 @@ let string_of_matrix m =
|
||||
Format.asprintf "%a" pp_matrix m
|
||||
|
||||
let debug_matrix name a =
|
||||
Format.printf "@[%s =\n@[%a@]@]@ " name pp_matrix a
|
||||
Format.printf "@[%s =\n@[%a@]@]@." name pp_matrix a
|
||||
|
||||
|
||||
let matrix_of_file filename =
|
||||
|
@ -231,6 +231,7 @@ let dot v v' =
|
||||
| (Sparse v), (Dense v') -> d_sp v' v
|
||||
|
||||
|
||||
let norm v = sqrt @@ dot v v
|
||||
|
||||
let test_case () =
|
||||
|
||||
|
@ -65,6 +65,8 @@ val axpy : ?threshold:float -> ?alpha:float -> t -> t -> t
|
||||
val dot : t -> t -> float
|
||||
(** Dot product. *)
|
||||
|
||||
val norm : t -> float
|
||||
(** l2-norm of the vector : {% $\sqrt{\sum_i x_i^2}$ %} *)
|
||||
|
||||
(** {1 Printers } *)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user