mirror of
https://gitlab.com/scemama/QCaml.git
synced 2024-12-22 20:33:36 +01:00
Cholesky OK
This commit is contained in:
parent
5289d36b76
commit
ec19195a6f
@ -62,14 +62,14 @@ let pivoted_ldl threshold m_A =
|
|||||||
|
|
||||||
let v_D_inv = Vec.make0 n in
|
let v_D_inv = Vec.make0 n in
|
||||||
|
|
||||||
let compute_d j =
|
let compute_d i =
|
||||||
let l_jk =
|
let l_ik =
|
||||||
Mat.col m_Lt j
|
Mat.col m_Lt i
|
||||||
in
|
in
|
||||||
let l_jk__d_k =
|
let l_ik__d_k =
|
||||||
Vec.mul ~n:(j-1) l_jk v_D
|
Vec.mul ~n:(i-1) l_ik v_D
|
||||||
in
|
in
|
||||||
m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k
|
m_A.{pi.(i),pi.(i)} -. dot ~n:(i-1) l_ik l_ik__d_k
|
||||||
in
|
in
|
||||||
|
|
||||||
let compute_l i =
|
let compute_l i =
|
||||||
@ -85,7 +85,7 @@ let pivoted_ldl threshold m_A =
|
|||||||
v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk )
|
v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk )
|
||||||
in
|
in
|
||||||
|
|
||||||
let pivot i =
|
let maxloc v i =
|
||||||
let rec aux pos value i =
|
let rec aux pos value i =
|
||||||
if i > n then
|
if i > n then
|
||||||
pos
|
pos
|
||||||
@ -94,11 +94,15 @@ let pivoted_ldl threshold m_A =
|
|||||||
else
|
else
|
||||||
aux pos value (i+1)
|
aux pos value (i+1)
|
||||||
in
|
in
|
||||||
let j = aux i v_D.{i} (i+1) in
|
aux i v.{i} (i+1)
|
||||||
|
in
|
||||||
|
|
||||||
|
|
||||||
|
let pivot i =
|
||||||
|
let j = maxloc v_D i in
|
||||||
let p_i, p_j = pi.(i), pi.(j) in
|
let p_i, p_j = pi.(i), pi.(j) in
|
||||||
pi.(i) <- p_j;
|
pi.(i) <- p_j;
|
||||||
pi.(j) <- p_i;
|
pi.(j) <- p_i;
|
||||||
|
|
||||||
in
|
in
|
||||||
|
|
||||||
|
|
||||||
@ -120,17 +124,6 @@ let pivoted_ldl threshold m_A =
|
|||||||
m_Lt, v_D, pi
|
m_Lt, v_D, pi
|
||||||
|
|
||||||
|
|
||||||
(*
|
|
||||||
let rec run accu_l accu_d accu_p i =
|
|
||||||
let finish
|
|
||||||
if i > n then
|
|
||||||
finish accu_l accu_d accu_p
|
|
||||||
else
|
|
||||||
in
|
|
||||||
run [] [] [] 1
|
|
||||||
*)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -179,7 +172,7 @@ let test_case () =
|
|||||||
Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref);
|
Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref);
|
||||||
let m_D = Mat.of_diag v_D in
|
let m_D = Mat.of_diag v_D in
|
||||||
let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in
|
let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in
|
||||||
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
|
Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_A m_B);
|
||||||
()
|
()
|
||||||
in
|
in
|
||||||
|
|
||||||
@ -205,12 +198,12 @@ let test_case () =
|
|||||||
gemm m_D @@
|
gemm m_D @@
|
||||||
gemm m_Lt m_P
|
gemm m_Lt m_P
|
||||||
in
|
in
|
||||||
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
|
Alcotest.(check (float 1.e-14)) "pivoted D" 0. (matrix_diff m_A m_B);
|
||||||
()
|
()
|
||||||
in
|
in
|
||||||
|
|
||||||
let test_truncated () =
|
let test_truncated () =
|
||||||
let m_Lt, v_D, pi = pivoted_ldl 1. m_A in
|
let m_Lt, v_D, pi = pivoted_ldl 0.001 m_A in
|
||||||
let n = Mat.dim1 m_Lt in
|
let n = Mat.dim1 m_Lt in
|
||||||
let m_P = Mat.make0 n n in
|
let m_P = Mat.make0 n n in
|
||||||
for i=1 to n do
|
for i=1 to n do
|
||||||
@ -223,7 +216,7 @@ let test_case () =
|
|||||||
gemm m_D @@
|
gemm m_D @@
|
||||||
gemm m_Lt m_P
|
gemm m_Lt m_P
|
||||||
in
|
in
|
||||||
Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B);
|
Alcotest.(check (float 1.e-3)) "full D" 0. (matrix_diff m_A m_B);
|
||||||
()
|
()
|
||||||
in
|
in
|
||||||
[
|
[
|
||||||
|
@ -171,7 +171,7 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y =
|
|||||||
in aux new_accu r1 r2
|
in aux new_accu r1 r2
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
end
|
end
|
||||||
| ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=x}::accu) r1 []
|
| ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=alpha *. x}::accu) r1 []
|
||||||
| [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2
|
| [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2
|
||||||
| [] , [] -> {n ; v=List.rev accu}
|
| [] , [] -> {n ; v=List.rev accu}
|
||||||
in
|
in
|
||||||
@ -297,6 +297,7 @@ let test_case () =
|
|||||||
Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
|
Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6);
|
||||||
Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
|
Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s);
|
||||||
Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
|
Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s);
|
||||||
|
|
||||||
Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s);
|
Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s);
|
||||||
in
|
in
|
||||||
let test_dot () =
|
let test_dot () =
|
||||||
|
Loading…
Reference in New Issue
Block a user