diff --git a/Utils/Cholesky.ml b/Utils/Cholesky.ml index fa293b0..bfb68f2 100644 --- a/Utils/Cholesky.ml +++ b/Utils/Cholesky.ml @@ -47,6 +47,10 @@ let full_ldl m_A = let pivoted_ldl threshold m_A = +(** {% $P A P^\dagger = L D L^\dagger$. %} + Input : Matrix $A$ + Output : Matrices $L, D, P$. +*) let n = Mat.dim1 m_A in assert (Mat.dim2 m_A = n); @@ -82,37 +86,56 @@ let pivoted_ldl threshold m_A = in let pivot i = - let rec aux (pos,value) i = + let rec aux pos value i = if i > n then pos - else if m_D.{i} > value then - aux i m_D.{i} (i+1) + else if v_D.{i} > value then + aux i v_D.{i} (i+1) else aux pos value (i+1) in - let j = aux i m_D.{i} (i+1) in + let j = aux i v_D.{i} (i+1) in let p_i, p_j = pi.(i), pi.(j) in - pi.(i) <- pj; - pi.(j) <- pi; + pi.(i) <- p_j; + pi.(j) <- p_i; + in - for i=1 to n do - pivot i; - for j=1 to (i-1) do - m_Lt.{j,i} <- compute_l i j; - done; - let d_i = compute_d i in - v_D.{i} <- d_i; - v_D_inv.{i} <- 1. /. d_i; - done; - m_Lt, v_D + + let () = + try + for i=1 to n do + pivot i; + for j=1 to (i-1) do + m_Lt.{j,i} <- compute_l i j; + done; + let d_i = compute_d i in + if abs_float d_i < threshold then + raise Exit; + v_D.{i} <- d_i; + v_D_inv.{i} <- 1. /. d_i; + done + with Exit -> () + in + m_Lt, v_D, pi + + + (* + let rec run accu_l accu_d accu_p i = + let finish + if i > n then + finish accu_l accu_d accu_p + else + in + run [] [] [] 1 + *) + + -let make_ldl ?(threshold=Constants.epsilon) m_A = - pivoted_ldl m_A @@ -127,7 +150,10 @@ let test_case () = dot v v in - let test_full () = + let m_A = Mat.random 1000 1000 in + let m_A = Mat.add m_A (Mat.transpose_copy m_A) in + + let test_full_small () = let m_A = Mat.of_array [| [| 4. ; 12. ; -16. |] ; [| 12. ; 37. ; -43. |] ; @@ -149,11 +175,61 @@ let test_case () = in let m_Lt, v_D = full_ldl m_A in - Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_Lt m_Lt_ref); - Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref) + Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref); + let m_D = Mat.of_diag v_D in + let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in + Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + () + in + + let test_full () = + let m_Lt, v_D = full_ldl m_A in + let m_D = Mat.of_diag v_D in + let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in + Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + () + in + + let test_pivoted () = + let m_Lt, v_D, pi = pivoted_ldl 0. m_A in + let n = Mat.dim1 m_A in + let m_P = Mat.make0 n n in + for i=1 to n do + m_P.{i,pi.(i)} <- 1. + done; + let m_D = Mat.of_diag v_D in + let m_B = + gemm ~transa:`T m_P @@ + gemm ~transa:`T m_Lt @@ + gemm m_D @@ + gemm m_Lt m_P + in + Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + () + in + + let test_truncated () = + let m_Lt, v_D, pi = pivoted_ldl 1. m_A in + let n = Mat.dim1 m_Lt in + let m_P = Mat.make0 n n in + for i=1 to n do + m_P.{i,pi.(i)} <- 1. + done; + let m_D = Mat.of_diag v_D in + let m_B = + gemm ~transa:`T m_P @@ + gemm ~transa:`T m_Lt @@ + gemm m_D @@ + gemm m_Lt m_P + in + Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + () in [ + "Small", `Quick, test_full_small; "Full", `Quick, test_full; + "Pivoted", `Quick, test_pivoted ; + "Truncated", `Quick, test_truncated ; ]