(* QCaml/Utils/Cholesky.ml
   Source: mirror of https://gitlab.com/scemama/QCaml.git
   (synced 2024-11-07 14:43:41 +01:00). *)
open Lacaml.D
(** [full_ldl m_A] factorizes the square matrix [m_A] as
    {% $A = L D L^\dagger$ %} without pivoting.

    Returns [(m_Lt, v_D)]:
    - [m_Lt] is the transpose of the unit lower triangular factor, so
      [m_Lt.{j,i}] holds {% $L_{ij}$ %} for [j < i];
    - [v_D] holds the diagonal of {% $D$ %}.

    Only the diagonal and the upper triangle of [m_A] ([m_A.{j,i}] with
    [j <= i]) are read, so [m_A] is treated as symmetric.
    Pivots are not checked: a zero entry of [v_D] gives an infinite
    reciprocal, and non-finite values may then appear in later entries.
    @raise Assert_failure if [m_A] is not square. *)
let full_ldl m_A =
  let n = Mat.dim1 m_A in
  assert (Mat.dim2 m_A = n);
  let v_D = Vec.make0 n in
  let m_Lt = Mat.identity n in
  let v_D_inv = Vec.make0 n in
  (* L_{ij} for j < i.  The weighted copy of column i is rebuilt on every
     call so that it includes the entries m_Lt.{k,i}, k < j, written
     earlier in the same row sweep. *)
  let lower_entry i j =
    let weighted_col_i = Vec.mul ~n:(i-1) v_D (Mat.col m_Lt i) in
    assert (i > j);
    let col_j = Mat.col m_Lt j in
    let correction = dot ~n:(j-1) weighted_col_i col_j in
    v_D_inv.{j} *. (m_A.{j,i} -. correction)
  in
  (* D_i = A_ii - sum_{k<i} L_{ik}^2 D_k *)
  let diagonal_entry i =
    let col_i = Mat.col m_Lt i in
    let weighted = Vec.mul ~n:(i-1) col_i v_D in
    m_A.{i,i} -. dot ~n:(i-1) col_i weighted
  in
  (* Fill row i of L (column i of m_Lt), then its pivot. *)
  let factor_row i =
    for j = 1 to i - 1 do
      m_Lt.{j,i} <- lower_entry i j
    done;
    let d = diagonal_entry i in
    v_D.{i} <- d;
    v_D_inv.{i} <- 1. /. d
  in
  for i = 1 to n do
    factor_row i
  done;
  (m_Lt, v_D)
(** [pivoted_ldl threshold m_A] computes a symmetrically pivoted
    factorization {% $P A P^\dagger = L D L^\dagger$. %}

    Input: the square matrix [m_A] (treated as symmetric) and [threshold].
    Output: [(m_Lt, v_D, pi)] where
    - [m_Lt] is the transpose of the unit lower triangular factor
      ([m_Lt.{k,i}] holds {% $L_{ik}$ %} for [k < i]);
    - [v_D] is the diagonal of {% $D$ %};
    - [pi] has length [n+1] ([pi.(0)] is unused) and
      {% $(P A P^\dagger)_{ij}$ %} is [m_A.{pi.(i), pi.(j)}].

    At step [i], the not-yet-factorized position whose Schur-complement
    diagonal has the largest absolute value is moved to position [i]
    (diagonal pivoting; no 2x2 pivots).  If that value is below [threshold]
    in absolute value, the factorization stops there.  The remaining
    entries of [v_D] stay [0.], so [L D L^\dagger] is the truncated
    approximation of [P A P^\dagger] built from the pivots already taken.
    A zero pivot that passes the threshold test (e.g. with
    [threshold = 0.]) causes a division by zero, and non-finite values
    then propagate.
    @raise Assert_failure if [m_A] is not square. *)
let pivoted_ldl threshold m_A =
  let n = Mat.dim1 m_A in
  assert (Mat.dim2 m_A = n);
  let pi = Array.init (n+1) (fun i -> i) in
  let m_Lt = Mat.identity n in
  let v_D = Vec.make0 n in
  (* v_R.{j}, for positions j not yet factorized: diagonal of the current
     Schur complement, A_{pi(j),pi(j)} - sum_k L_{jk}^2 D_k.  Kept
     separate from v_D and aligned with pi, so pivot selection compares
     the values of the positions it will actually swap. *)
  let v_R = Vec.init n (fun i -> m_A.{i,i}) in
  (* Position in [i..n] with the largest |v_R|; the first one wins ties. *)
  let maxloc i =
    let best = ref i in
    for j = i+1 to n do
      if abs_float v_R.{j} > abs_float v_R.{!best} then best := j
    done;
    !best
  in
  (* Exchange positions i < j: the permutation entry, the residual
     diagonal, and the already computed rows of L (stored in the rows of
     m_Lt above row i). *)
  let swap i j =
    let p = pi.(i) in
    pi.(i) <- pi.(j);
    pi.(j) <- p;
    let r = v_R.{i} in
    v_R.{i} <- v_R.{j};
    v_R.{j} <- r;
    for k = 1 to i-1 do
      let l = m_Lt.{k,i} in
      m_Lt.{k,i} <- m_Lt.{k,j};
      m_Lt.{k,j} <- l
    done
  in
  (* Right-looking step: choose the pivot, compute column i of L
     (row i of m_Lt), and update the residual diagonal. *)
  let rec factorize i =
    if i <= n then begin
      let best = maxloc i in
      if best <> i then swap i best;
      let d_i = v_R.{i} in
      if not (abs_float d_i < threshold) then begin
        v_D.{i} <- d_i;
        (* L_{ik} D_k for k < i.  The loop below only writes row i of
           m_Lt, so column i above row i is stable while it is in use. *)
        let l_ik__d_k = Vec.mul ~n:(i-1) v_D (Mat.col m_Lt i) in
        for j = i+1 to n do
          let a_ij = m_A.{pi.(i), pi.(j)} in
          let l_ji =
            (a_ij -. dot ~n:(i-1) l_ik__d_k (Mat.col m_Lt j)) /. d_i
          in
          m_Lt.{i,j} <- l_ji;
          v_R.{j} <- v_R.{j} -. l_ji *. l_ji *. d_i
        done;
        factorize (i+1)
      end
    end
  in
  factorize 1;
  m_Lt, v_D, pi
(** [make_ldl ?threshold m_A] is [pivoted_ldl threshold m_A]: the pivoted
    {% $L D L^\dagger$ %} factorization of [m_A], with [threshold]
    (default [Constants.epsilon]) forwarded as the truncation threshold.
    Returns [(m_Lt, v_D, pi)] exactly as [pivoted_ldl] does. *)
let make_ldl ?(threshold=Constants.epsilon) m_A =
  (* The threshold must be passed explicitly.  [pivoted_ldl m_A] would
     use the matrix as the threshold and return a partial application. *)
  pivoted_ldl threshold m_A
(** Alcotest cases for the factorizations of this module.

    Matrix differences are measured with [Mat.syrk_trace] of the
    difference (the trace of its Gram matrix, i.e. the sum of its squared
    entries). Vector differences are the dot product of the difference with
    itself.
    - ["Small"]: [full_ldl] on a 3x3 matrix with known factors.
    - ["Full"]: [full_ldl] on a random symmetric 1000x1000 matrix.
    - ["Pivoted"]: [pivoted_ldl] with threshold [0.] on the same matrix.
    - ["Truncated"]: [pivoted_ldl] with threshold [0.001], checked with a
      looser tolerance. *)
let test_case () =
  let matrix_diff m_A m_B =
    Mat.syrk_trace (Mat.sub m_A m_B)
  in
  let vector_diff v_A v_B =
    let v = Vec.sub v_A v_B in
    dot v v
  in
  (* L D L^T, from the transposed factor m_Lt and the diagonal v_D. *)
  let ldlt m_Lt v_D =
    let m_D = Mat.of_diag v_D in
    gemm ~transa:`T m_Lt @@ gemm m_D m_Lt
  in
  (* P^T L D L^T P, where row i of P has its single 1 in column pi.(i). *)
  let unpermuted_ldlt m_Lt v_D pi =
    let n = Mat.dim1 m_Lt in
    let m_P = Mat.make0 n n in
    for i=1 to n do
      m_P.{i,pi.(i)} <- 1.
    done;
    let m_D = Mat.of_diag v_D in
    gemm ~transa:`T m_P @@
    gemm ~transa:`T m_Lt @@
    gemm m_D @@
    gemm m_Lt m_P
  in
  let m_A = Mat.random 1000 1000 in
  let m_A = Mat.add m_A (Mat.transpose_copy m_A) in
  let test_full_small () =
    let m_A =
      Mat.of_array [| [| 4. ; 12. ; -16. |] ;
                      [| 12. ; 37. ; -43. |] ;
                      [| -16. ; -43. ; 98. |] |]
    in
    let m_L_ref =
      Mat.of_array [| [| 1. ; 0. ; 0. |] ;
                      [| 3. ; 1. ; 0. |] ;
                      [| -4. ; 5. ; 1. |] |]
    in
    let m_Lt_ref =
      Mat.transpose_copy m_L_ref
    in
    let v_D_ref =
      Vec.of_array [| 4. ; 1. ; 9. |]
    in
    let m_Lt, v_D = full_ldl m_A in
    Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_Lt m_Lt_ref);
    Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref);
    Alcotest.(check (float 1.e-15)) "full LDL" 0.
      (matrix_diff m_A (ldlt m_Lt v_D))
  in
  let test_full () =
    let m_Lt, v_D = full_ldl m_A in
    Alcotest.(check (float 1.e-15)) "full LDL" 0.
      (matrix_diff m_A (ldlt m_Lt v_D))
  in
  let test_pivoted () =
    let m_Lt, v_D, pi = pivoted_ldl 0. m_A in
    Alcotest.(check (float 1.e-14)) "pivoted LDL" 0.
      (matrix_diff m_A (unpermuted_ldlt m_Lt v_D pi))
  in
  let test_truncated () =
    let m_Lt, v_D, pi = pivoted_ldl 0.001 m_A in
    Alcotest.(check (float 1.e-3)) "truncated LDL" 0.
      (matrix_diff m_A (unpermuted_ldlt m_Lt v_D pi))
  in
  [
    "Small", `Quick, test_full_small;
    "Full", `Quick, test_full;
    "Pivoted", `Quick, test_pivoted ;
    "Truncated", `Quick, test_truncated ;
  ]