10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-12-22 20:33:36 +01:00

Cholesky OK

This commit is contained in:
Anthony Scemama 2019-05-13 15:54:02 +02:00
parent 5289d36b76
commit ec19195a6f
2 changed files with 20 additions and 26 deletions

View File

@ -62,14 +62,14 @@ let pivoted_ldl threshold m_A =
let v_D_inv = Vec.make0 n in let v_D_inv = Vec.make0 n in
let compute_d j = let compute_d i =
let l_jk = let l_ik =
Mat.col m_Lt j Mat.col m_Lt i
in in
let l_jk__d_k = let l_ik__d_k =
Vec.mul ~n:(j-1) l_jk v_D Vec.mul ~n:(i-1) l_ik v_D
in in
m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k m_A.{pi.(i),pi.(i)} -. dot ~n:(i-1) l_ik l_ik__d_k
in in
let compute_l i = let compute_l i =
@ -85,7 +85,7 @@ let pivoted_ldl threshold m_A =
v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk ) v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk )
in in
let pivot i = let maxloc v i =
let rec aux pos value i = let rec aux pos value i =
if i > n then if i > n then
pos pos
@ -94,11 +94,15 @@ let pivoted_ldl threshold m_A =
else else
aux pos value (i+1) aux pos value (i+1)
in in
let j = aux i v_D.{i} (i+1) in aux i v.{i} (i+1)
in
let pivot i =
let j = maxloc v_D i in
let p_i, p_j = pi.(i), pi.(j) in let p_i, p_j = pi.(i), pi.(j) in
pi.(i) <- p_j; pi.(i) <- p_j;
pi.(j) <- p_i; pi.(j) <- p_i;
in in
@ -120,17 +124,6 @@ let pivoted_ldl threshold m_A =
m_Lt, v_D, pi m_Lt, v_D, pi
(*
let rec run accu_l accu_d accu_p i =
let finish
if i > n then
finish accu_l accu_d accu_p
else
in
run [] [] [] 1
*)
@ -179,7 +172,7 @@ let test_case () =
Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref); Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref);
let m_D = Mat.of_diag v_D in let m_D = Mat.of_diag v_D in
let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_A m_B);
() ()
in in
@ -205,12 +198,12 @@ let test_case () =
gemm m_D @@ gemm m_D @@
gemm m_Lt m_P gemm m_Lt m_P
in in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); Alcotest.(check (float 1.e-14)) "pivoted D" 0. (matrix_diff m_A m_B);
() ()
in in
let test_truncated () = let test_truncated () =
let m_Lt, v_D, pi = pivoted_ldl 1. m_A in let m_Lt, v_D, pi = pivoted_ldl 0.001 m_A in
let n = Mat.dim1 m_Lt in let n = Mat.dim1 m_Lt in
let m_P = Mat.make0 n n in let m_P = Mat.make0 n n in
for i=1 to n do for i=1 to n do
@ -223,7 +216,7 @@ let test_case () =
gemm m_D @@ gemm m_D @@
gemm m_Lt m_P gemm m_Lt m_P
in in
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); Alcotest.(check (float 1.e-3)) "full D" 0. (matrix_diff m_A m_B);
() ()
in in
[ [

View File

@ -171,7 +171,7 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
in aux new_accu r1 r2 in aux new_accu r1 r2
| _ -> assert false | _ -> assert false
end end
| ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=x}::accu) r1 [] | ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=alpha *. x}::accu) r1 []
| [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2 | [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2
| [] , [] -> {n ; v=List.rev accu} | [] , [] -> {n ; v=List.rev accu}
in in
@ -297,6 +297,7 @@ let test_case () =
Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6); Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s); Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s); Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s); Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s);
in in
let test_dot () = let test_dot () =