10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-06-26 15:12:05 +02:00

Cholesky OK

This commit is contained in:
Anthony Scemama 2019-05-13 15:54:02 +02:00
parent 5289d36b76
commit ec19195a6f
2 changed files with 20 additions and 26 deletions

View File

@ -62,14 +62,14 @@ let pivoted_ldl threshold m_A =
let v_D_inv = Vec.make0 n in
let compute_d j =
let l_jk =
Mat.col m_Lt j
let compute_d i =
let l_ik =
Mat.col m_Lt i
in
let l_jk__d_k =
Vec.mul ~n:(j-1) l_jk v_D
let l_ik__d_k =
Vec.mul ~n:(i-1) l_ik v_D
in
m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k
m_A.{pi.(i),pi.(i)} -. dot ~n:(i-1) l_ik l_ik__d_k
in
let compute_l i =
@ -85,8 +85,8 @@ let pivoted_ldl threshold m_A =
v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk )
in
let pivot i =
let rec aux pos value i =
let maxloc v i =
let rec aux pos value i =
if i > n then
pos
else if v_D.{i} > value then
@ -94,11 +94,15 @@ let pivoted_ldl threshold m_A =
else
aux pos value (i+1)
in
let j = aux i v_D.{i} (i+1) in
aux i v.{i} (i+1)
in
let pivot i =
let j = maxloc v_D i in
let p_i, p_j = pi.(i), pi.(j) in
pi.(i) <- p_j;
pi.(j) <- p_i;
in
@ -120,16 +124,6 @@ let pivoted_ldl threshold m_A =
m_Lt, v_D, pi
(*
let rec run accu_l accu_d accu_p i =
let finish
if i > n then
finish accu_l accu_d accu_p
else
in
run [] [] [] 1
*)
@ -138,7 +132,6 @@ let pivoted_ldl threshold m_A =
let test_case () =
let matrix_diff m_A m_B =
@ -179,7 +172,7 @@ let test_case () =
Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref);
let m_D = Mat.of_diag v_D in
let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_A m_B);
()
in
@ -205,12 +198,12 @@ let test_case () =
gemm m_D @@
gemm m_Lt m_P
in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
Alcotest.(check (float 1.e-14)) "pivoted D" 0. (matrix_diff m_A m_B);
()
in
let test_truncated () =
let m_Lt, v_D, pi = pivoted_ldl 1. m_A in
let m_Lt, v_D, pi = pivoted_ldl 0.001 m_A in
let n = Mat.dim1 m_Lt in
let m_P = Mat.make0 n n in
for i=1 to n do
@ -223,7 +216,7 @@ let test_case () =
gemm m_D @@
gemm m_Lt m_P
in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
Alcotest.(check (float 1.e-3)) "full D" 0. (matrix_diff m_A m_B);
()
in
[

View File

@ -171,7 +171,7 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
in aux new_accu r1 r2
| _ -> assert false
end
| ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=x}::accu) r1 []
| ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=alpha *. x}::accu) r1 []
| [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2
| [] , [] -> {n ; v=List.rev accu}
in
@ -297,6 +297,7 @@ let test_case () =
Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s);
in
let test_dot () =