mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 01:55:40 +01:00
Working on pivoted Cholesky
This commit is contained in:
parent
41fde2d11a
commit
c517a6d6a5
@ -5,12 +5,14 @@ let full_ldl m_A =
|
||||
let n = Mat.dim1 m_A in
|
||||
assert (Mat.dim2 m_A = n);
|
||||
|
||||
let v_D = Vec.make0 n in
|
||||
let m_L = Mat.identity n in
|
||||
let v_D = Vec.make0 n in
|
||||
let m_Lt = Mat.identity n in
|
||||
|
||||
let v_D_inv = Vec.make0 n in
|
||||
|
||||
let compute_d j =
|
||||
let l_jk =
|
||||
Mat.copy_row m_L j
|
||||
Mat.col m_Lt j
|
||||
in
|
||||
let l_jk__d_k =
|
||||
Vec.mul ~n:(j-1) l_jk v_D
|
||||
@ -18,29 +20,99 @@ let full_ldl m_A =
|
||||
m_A.{j,j} -. dot ~n:(j-1) l_jk l_jk__d_k
|
||||
in
|
||||
|
||||
let compute_l i j =
|
||||
assert (i > j);
|
||||
let l_jk__d_k =
|
||||
Mat.copy_row m_L j
|
||||
|> Vec.mul ~n:(j-1) v_D
|
||||
and l_ik =
|
||||
Mat.copy_row m_L i
|
||||
let compute_l i =
|
||||
let l_ik__d_k =
|
||||
Mat.col m_Lt i
|
||||
|> Vec.mul ~n:(i-1) v_D
|
||||
in
|
||||
1. /. v_D.{j} *. ( m_A.{i,j} -. dot l_jk__d_k l_ik )
|
||||
fun j ->
|
||||
assert (i > j);
|
||||
let l_jk =
|
||||
Mat.col m_Lt j
|
||||
in
|
||||
v_D_inv.{j} *. ( m_A.{j,i} -. dot ~n:(j-1) l_ik__d_k l_jk )
|
||||
in
|
||||
|
||||
v_D.{1} <- m_A.{1,1};
|
||||
for i=2 to n do
|
||||
for i=1 to n do
|
||||
for j=1 to (i-1) do
|
||||
m_L.{i,j} <- compute_l i j
|
||||
m_Lt.{j,i} <- compute_l i j;
|
||||
done;
|
||||
v_D.{i} <- compute_d i
|
||||
let d_i = compute_d i in
|
||||
v_D.{i} <- d_i;
|
||||
v_D_inv.{i} <- 1. /. d_i;
|
||||
done;
|
||||
m_L, v_D
|
||||
m_Lt, v_D
|
||||
|
||||
|
||||
let make_ldl ?target_rank m_A =
|
||||
full_ldl m_A
|
||||
|
||||
|
||||
let pivoted_ldl threshold m_A =
|
||||
|
||||
let n = Mat.dim1 m_A in
|
||||
assert (Mat.dim2 m_A = n);
|
||||
|
||||
let pi = Array.init (n+1) (fun i -> i) in
|
||||
|
||||
let v_D = Vec.init n (fun i -> abs_float m_A.{i,i}) in
|
||||
let m_Lt = Mat.identity n in
|
||||
|
||||
let v_D_inv = Vec.make0 n in
|
||||
|
||||
let compute_d j =
|
||||
let l_jk =
|
||||
Mat.col m_Lt j
|
||||
in
|
||||
let l_jk__d_k =
|
||||
Vec.mul ~n:(j-1) l_jk v_D
|
||||
in
|
||||
m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k
|
||||
in
|
||||
|
||||
let compute_l i =
|
||||
let l_ik__d_k =
|
||||
Mat.col m_Lt i
|
||||
|> Vec.mul ~n:(i-1) v_D
|
||||
in
|
||||
fun j ->
|
||||
assert (i > j);
|
||||
let l_jk =
|
||||
Mat.col m_Lt j
|
||||
in
|
||||
v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk )
|
||||
in
|
||||
|
||||
let pivot i =
|
||||
let rec aux (pos,value) i =
|
||||
if i > n then
|
||||
pos
|
||||
else if m_D.{i} > value then
|
||||
aux i m_D.{i} (i+1)
|
||||
else
|
||||
aux pos value (i+1)
|
||||
in
|
||||
let j = aux i m_D.{i} (i+1) in
|
||||
let p_i, p_j = pi.(i), pi.(j) in
|
||||
pi.(i) <- pj;
|
||||
pi.(j) <- pi;
|
||||
in
|
||||
|
||||
for i=1 to n do
|
||||
pivot i;
|
||||
for j=1 to (i-1) do
|
||||
m_Lt.{j,i} <- compute_l i j;
|
||||
done;
|
||||
let d_i = compute_d i in
|
||||
v_D.{i} <- d_i;
|
||||
v_D_inv.{i} <- 1. /. d_i;
|
||||
done;
|
||||
m_Lt, v_D
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
let make_ldl ?(threshold=Constants.epsilon) m_A =
|
||||
pivoted_ldl m_A
|
||||
|
||||
|
||||
|
||||
@ -68,19 +140,17 @@ let test_case () =
|
||||
[| -4. ; 5. ; 1. |] |]
|
||||
in
|
||||
|
||||
(*
|
||||
let m_Lt_ref =
|
||||
Mat.transpose_copy m_L_ref
|
||||
in
|
||||
*)
|
||||
|
||||
let v_D_ref =
|
||||
Vec.of_array [| 4. ; 1. ; 9. |]
|
||||
in
|
||||
|
||||
let m_L, v_D = full_ldl m_A in
|
||||
let m_Lt, v_D = full_ldl m_A in
|
||||
|
||||
Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_L m_L_ref);
|
||||
Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_Lt m_Lt_ref);
|
||||
Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref)
|
||||
in
|
||||
[
|
||||
|
26
run_tests.ml
26
run_tests.ml
@ -1,5 +1,17 @@
|
||||
|
||||
let test_water_dz () =
|
||||
let () =
|
||||
|
||||
Alcotest.run "Unit tests" [
|
||||
"Util", Util.test_case ();
|
||||
"Bitstring", Bitstring.test_case ();
|
||||
"Spindeterminant", Spindeterminant.test_case ();
|
||||
"Determinant", Determinant.test_case ();
|
||||
"Excitation", Excitation.test_case ();
|
||||
"Sparse vectors", Vector.test_case ();
|
||||
"Sparse matrices", Matrix.test_case ();
|
||||
"Cholesky", Cholesky.test_case ();
|
||||
];
|
||||
|
||||
|
||||
let basis_file = "test_files/cc-pvdz"
|
||||
and nuclei_file = "test_files/h2o.xyz"
|
||||
@ -12,21 +24,9 @@ let test_water_dz () =
|
||||
let ao_basis =
|
||||
Simulation.ao_basis simulation_closed_shell
|
||||
in
|
||||
Alcotest.run "Unit tests" [
|
||||
"Util", Util.test_case ();
|
||||
"Bitstring", Bitstring.test_case ();
|
||||
"Spindeterminant", Spindeterminant.test_case ();
|
||||
"Determinant", Determinant.test_case ();
|
||||
"Excitation", Excitation.test_case ();
|
||||
"Sparse vectors", Vector.test_case ();
|
||||
"Sparse matrices", Matrix.test_case ();
|
||||
];
|
||||
|
||||
Alcotest.run "Water, cc-pVDZ" [
|
||||
"AO_Basis", AOBasis.test_case ao_basis;
|
||||
"Guess", Guess.test_case ao_basis;
|
||||
]
|
||||
|
||||
|
||||
let () =
|
||||
test_water_dz ()
|
||||
|
Loading…
Reference in New Issue
Block a user