From c517a6d6a59a612bbc085f1dd6148e448bcb8159 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 12 Apr 2019 18:48:23 +0200 Subject: [PATCH] Working on pivoted Cholesky --- Utils/Cholesky.ml | 114 +++++++++++++++++++++++++++++++++++++--------- run_tests.ml | 26 +++++------ 2 files changed, 105 insertions(+), 35 deletions(-) diff --git a/Utils/Cholesky.ml b/Utils/Cholesky.ml index 9d9c3bc..fa293b0 100644 --- a/Utils/Cholesky.ml +++ b/Utils/Cholesky.ml @@ -5,12 +5,14 @@ let full_ldl m_A = let n = Mat.dim1 m_A in assert (Mat.dim2 m_A = n); - let v_D = Vec.make0 n in - let m_L = Mat.identity n in + let v_D = Vec.make0 n in + let m_Lt = Mat.identity n in + + let v_D_inv = Vec.make0 n in let compute_d j = let l_jk = - Mat.copy_row m_L j + Mat.col m_Lt j in let l_jk__d_k = Vec.mul ~n:(j-1) l_jk v_D @@ -18,29 +20,99 @@ let full_ldl m_A = m_A.{j,j} -. dot ~n:(j-1) l_jk l_jk__d_k in - let compute_l i j = - assert (i > j); - let l_jk__d_k = - Mat.copy_row m_L j - |> Vec.mul ~n:(j-1) v_D - and l_ik = - Mat.copy_row m_L i + let compute_l i = + let l_ik__d_k = + Mat.col m_Lt i + |> Vec.mul ~n:(i-1) v_D in - 1. /. v_D.{j} *. ( m_A.{i,j} -. dot l_jk__d_k l_ik ) + fun j -> + assert (i > j); + let l_jk = + Mat.col m_Lt j + in + v_D_inv.{j} *. ( m_A.{j,i} -. dot ~n:(j-1) l_ik__d_k l_jk ) in - v_D.{1} <- m_A.{1,1}; - for i=2 to n do + for i=1 to n do for j=1 to (i-1) do - m_L.{i,j} <- compute_l i j + m_Lt.{j,i} <- compute_l i j; done; - v_D.{i} <- compute_d i + let d_i = compute_d i in + v_D.{i} <- d_i; + v_D_inv.{i} <- 1. /. d_i; done; - m_L, v_D + m_Lt, v_D -let make_ldl ?target_rank m_A = - full_ldl m_A + + +let pivoted_ldl threshold m_A = + + let n = Mat.dim1 m_A in + assert (Mat.dim2 m_A = n); + + let pi = Array.init (n+1) (fun i -> i) in + + let v_D = Vec.init n (fun i -> abs_float m_A.{i,i}) in + let m_Lt = Mat.identity n in + + let v_D_inv = Vec.make0 n in + + let compute_d j = + let l_jk = + Mat.col m_Lt j + in + let l_jk__d_k = + Vec.mul ~n:(j-1) l_jk v_D + in + m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k + in + + let compute_l i = + let l_ik__d_k = + Mat.col m_Lt i + |> Vec.mul ~n:(i-1) v_D + in + fun j -> + assert (i > j); + let l_jk = + Mat.col m_Lt j + in + v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk ) + in + + let pivot i = + let rec aux (pos,value) i = + if i > n then + pos + else if m_D.{i} > value then + aux i m_D.{i} (i+1) + else + aux pos value (i+1) + in + let j = aux i m_D.{i} (i+1) in + let p_i, p_j = pi.(i), pi.(j) in + pi.(i) <- pj; + pi.(j) <- pi; + in + + for i=1 to n do + pivot i; + for j=1 to (i-1) do + m_Lt.{j,i} <- compute_l i j; + done; + let d_i = compute_d i in + v_D.{i} <- d_i; + v_D_inv.{i} <- 1. /. d_i; + done; + m_Lt, v_D + + + + + +let make_ldl ?(threshold=Constants.epsilon) m_A = + pivoted_ldl m_A @@ -68,19 +140,17 @@ let test_case () = [| -4. ; 5. ; 1. |] |] in -(* let m_Lt_ref = Mat.transpose_copy m_L_ref in - *) let v_D_ref = Vec.of_array [| 4. ; 1. ; 9. |] in - let m_L, v_D = full_ldl m_A in + let m_Lt, v_D = full_ldl m_A in - Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_L m_L_ref); + Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_Lt m_Lt_ref); Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref) in [ diff --git a/run_tests.ml b/run_tests.ml index 5bdeebb..ffa525a 100644 --- a/run_tests.ml +++ b/run_tests.ml @@ -1,5 +1,17 @@ -let test_water_dz () = +let () = + + Alcotest.run "Unit tests" [ + "Util", Util.test_case (); + "Bitstring", Bitstring.test_case (); + "Spindeterminant", Spindeterminant.test_case (); + "Determinant", Determinant.test_case (); + "Excitation", Excitation.test_case (); + "Sparse vectors", Vector.test_case (); + "Sparse matrices", Matrix.test_case (); + "Cholesky", Cholesky.test_case (); + ]; + let basis_file = "test_files/cc-pvdz" and nuclei_file = "test_files/h2o.xyz" @@ -12,21 +24,9 @@ let test_water_dz () = let ao_basis = Simulation.ao_basis simulation_closed_shell in - Alcotest.run "Unit tests" [ - "Util", Util.test_case (); - "Bitstring", Bitstring.test_case (); - "Spindeterminant", Spindeterminant.test_case (); - "Determinant", Determinant.test_case (); - "Excitation", Excitation.test_case (); - "Sparse vectors", Vector.test_case (); - "Sparse matrices", Matrix.test_case (); - ]; - Alcotest.run "Water, cc-pVDZ" [ "AO_Basis", AOBasis.test_case ao_basis; "Guess", Guess.test_case ao_basis; ] -let () = - test_water_dz ()