diff --git a/Utils/Cholesky.ml b/Utils/Cholesky.ml index bfb68f2..b290444 100644 --- a/Utils/Cholesky.ml +++ b/Utils/Cholesky.ml @@ -62,14 +62,14 @@ let pivoted_ldl threshold m_A = let v_D_inv = Vec.make0 n in - let compute_d j = - let l_jk = - Mat.col m_Lt j + let compute_d i = + let l_ik = + Mat.col m_Lt i in - let l_jk__d_k = - Vec.mul ~n:(j-1) l_jk v_D + let l_ik__d_k = + Vec.mul ~n:(i-1) l_ik v_D in - m_A.{pi.(j),pi.(j)} -. dot ~n:(j-1) l_jk l_jk__d_k + m_A.{pi.(i),pi.(i)} -. dot ~n:(i-1) l_ik l_ik__d_k in let compute_l i = @@ -85,8 +85,8 @@ let pivoted_ldl threshold m_A = v_D_inv.{j} *. ( m_A.{pi.(j),pi.(i)} -. dot ~n:(j-1) l_ik__d_k l_jk ) in - let pivot i = - let rec aux pos value i = + let maxloc v i = + let rec aux pos value i = if i > n then pos else if v_D.{i} > value then @@ -94,11 +94,15 @@ let pivoted_ldl threshold m_A = else aux pos value (i+1) in - let j = aux i v_D.{i} (i+1) in + aux i v.{i} (i+1) + in + + + let pivot i = + let j = maxloc v_D i in let p_i, p_j = pi.(i), pi.(j) in pi.(i) <- p_j; pi.(j) <- p_i; - in @@ -120,16 +124,6 @@ let pivoted_ldl threshold m_A = m_Lt, v_D, pi - (* - let rec run accu_l accu_d accu_p i = - let finish - if i > n then - finish accu_l accu_d accu_p - else - in - run [] [] [] 1 - *) - @@ -138,7 +132,6 @@ let pivoted_ldl threshold m_A = - let test_case () = let matrix_diff m_A m_B = @@ -179,7 +172,7 @@ let test_case () = Alcotest.(check (float 1.e-15)) "full D" 0. (vector_diff v_D v_D_ref); let m_D = Mat.of_diag v_D in let m_B = gemm ~transa:`T m_Lt @@ gemm m_D m_Lt in - Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + Alcotest.(check (float 1.e-15)) "full L" 0. (matrix_diff m_A m_B); () in @@ -205,12 +198,12 @@ let test_case () = gemm m_D @@ gemm m_Lt m_P in - Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + Alcotest.(check (float 1.e-14)) "pivoted D" 0. (matrix_diff m_A m_B); () in let test_truncated () = - let m_Lt, v_D, pi = pivoted_ldl 1. m_A in + let m_Lt, v_D, pi = pivoted_ldl 0.001 m_A in let n = Mat.dim1 m_Lt in let m_P = Mat.make0 n n in for i=1 to n do @@ -223,7 +216,7 @@ let test_case () = gemm m_D @@ gemm m_Lt m_P in - Alcotest.(check (float 1.e-15)) "full D" 0. (matrix_diff m_A m_B); + Alcotest.(check (float 1.e-3)) "full D" 0. (matrix_diff m_A m_B); () in [ diff --git a/Utils/Vector.ml b/Utils/Vector.ml index fe1fd55..8905680 100644 --- a/Utils/Vector.ml +++ b/Utils/Vector.ml @@ -171,7 +171,7 @@ let axpy ?(threshold=epsilon) ?(alpha=1.) x y = in aux new_accu r1 r2 | _ -> assert false end - | ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=x}::accu) r1 [] + | ({index=i ; value=x}::r1), [] -> aux ({index=i ; value=alpha *. x}::accu) r1 [] | [] , ({index=j ; value=y}::r2) -> aux ({index=j ; value=y}::accu) [] r2 | [] , [] -> {n ; v=List.rev accu} in @@ -297,6 +297,7 @@ let test_case () = Alcotest.(check bool) "dense dense axpy" true (axpy ~alpha:3. v1 v2 = v6); Alcotest.(check bool) "dense sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1 v2_s) v6_s = zero_s); Alcotest.(check bool) "sparse dense axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2) v6_s = zero_s); + Alcotest.(check bool) "sparse sparse axpy" true (sub ~threshold:1.e-12 (axpy ~alpha:3. v1_s v2_s) v6_s = zero_s); in let test_dot () =