From 6122bf79d8d7c63ce3f22020a27a392ea99f1dbd Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Fri, 2 Oct 2020 15:49:09 +0200 Subject: [PATCH] Added tests for vector --- gaussian_integrals/lib/orthonormalization.ml | 3 +- .../lib/two_electron_rr_vectorized.ml | 3 +- linear_algebra/lib/matrix.ml | 15 +++-- linear_algebra/lib/orthonormalization.ml | 3 +- linear_algebra/lib/vector.ml | 7 ++- linear_algebra/lib/vector.mli | 5 +- linear_algebra/test/b.ml | 14 ----- linear_algebra/test/vector.ml | 59 +++++++++++++++++++ test/dune | 1 + test/run_tests.ml | 1 + 10 files changed, 87 insertions(+), 24 deletions(-) delete mode 100644 linear_algebra/test/b.ml create mode 100644 linear_algebra/test/vector.ml diff --git a/gaussian_integrals/lib/orthonormalization.ml b/gaussian_integrals/lib/orthonormalization.ml index 0396ca3..4fd4c89 100644 --- a/gaussian_integrals/lib/orthonormalization.ml +++ b/gaussian_integrals/lib/orthonormalization.ml @@ -58,7 +58,8 @@ let make_lowdin ~thresh ~overlap = let u_val = Vector.reci (Vector.sqrt u_val) in let u_vec' = - Matrix.init_cols (Matrix.dim1 u_vec) (Matrix.dim2 u_vec) (fun i j -> u_vec_x.{i,j} *. u_val.{j}) + Matrix.init_cols (Matrix.dim1 u_vec) (Matrix.dim2 u_vec) + (fun i j -> u_vec_x.{i,j} *. (Vector.at u_val j)) in Matrix.gemm u_vec' ~transb:`T u_vec diff --git a/gaussian_integrals/lib/two_electron_rr_vectorized.ml b/gaussian_integrals/lib/two_electron_rr_vectorized.ml index dc01a62..e1f8174 100644 --- a/gaussian_integrals/lib/two_electron_rr_vectorized.ml +++ b/gaussian_integrals/lib/two_electron_rr_vectorized.ml @@ -638,7 +638,8 @@ let contracted_class_shell_pairs ?operator ~zero_m ?schwartz_p ?schwartz_q shell raise NullQuartet; let expo_p_inv, expo_q_inv = - expo_p_inv.{i}, expo_q_inv.{j} + (Vector.at expo_p_inv i), + (Vector.at expo_q_inv j) in let center_pq = diff --git a/linear_algebra/lib/matrix.ml b/linear_algebra/lib/matrix.ml index d42c11b..df13691 100644 --- a/linear_algebra/lib/matrix.ml +++ b/linear_algebra/lib/matrix.ml @@ -77,10 +77,12 @@ let detri t = let as_vec_inplace t = Mat.as_vec t + |> Vector.of_bigarray_inplace let as_vec t = - Mat.as_vec t - |> Vector.copy + lacpy t + |> Mat.as_vec + |> Vector.of_bigarray_inplace let random ?rnd_state ?(from= -. 1.0) ?(range=2.0) m n = Mat.random ?rnd_state ~from ~range m n @@ -134,6 +136,7 @@ let diagonalize_symm m_H = let m_V = lacpy m_H in let result = syevd ~vectors:true m_V + |> Vector.of_bigarray in m_V, result @@ -231,10 +234,14 @@ let copy_inplace ?m ?n ?br ?bc ~b ?ar ?ac a = ignore @@ lacpy ?m ?n ?br ?bc ~b ?ar ?ac a let scale_cols_inplace a v = - Mat.scal_cols a v + Vector.to_bigarray v + |> Mat.scal_cols a let scale_cols a v = - out_of_place (fun a -> Mat.scal_cols a v) a + let a' = copy a in + Vector.to_bigarray v + |> Mat.scal_cols a' ; + a' let svd a = diff --git a/linear_algebra/lib/orthonormalization.ml b/linear_algebra/lib/orthonormalization.ml index 905806b..6fbe7a1 100644 --- a/linear_algebra/lib/orthonormalization.ml +++ b/linear_algebra/lib/orthonormalization.ml @@ -9,8 +9,9 @@ let canonical_ortho ?thresh:(thresh=1.e-6) ~overlap c = if x >= thresh then 1. /. x else 0. ) d_sqrt in + let dx = Vector.to_bigarray d in if n < Vector.dim d_sqrt then - Printf.printf "Removed linear dependencies below %f\n" (1. /. d.{n}) + Printf.printf "Removed linear dependencies below %f\n" (1. /. dx.{n}) ; Matrix.scale_cols_inplace u d_inv_sq ; Matrix.gemm c u diff --git a/linear_algebra/lib/vector.ml b/linear_algebra/lib/vector.ml index f1e9872..9bd1387 100644 --- a/linear_algebra/lib/vector.ml +++ b/linear_algebra/lib/vector.ml @@ -27,9 +27,9 @@ let iteri f t = Vec.iteri f t let fold f a t = Vec.fold f a t let add t1 t2 = Vec.add t1 t2 -let sub t1 t2 = Vec.add t1 t2 +let sub t1 t2 = Vec.sub t1 t2 let mul t1 t2 = Vec.mul t1 t2 -let div t1 t2 = Vec.mul t1 t2 +let div t1 t2 = Vec.div t1 t2 let dot t1 t2 = dot t1 t2 let create n = Vec.create n @@ -60,3 +60,6 @@ let normalize v = scal (1. /. (nrm2 v)) result; result + +let at t i = t.{i} + diff --git a/linear_algebra/lib/vector.mli b/linear_algebra/lib/vector.mli index ab081b7..e9117aa 100644 --- a/linear_algebra/lib/vector.mli +++ b/linear_algebra/lib/vector.mli @@ -8,7 +8,7 @@ The indexing of vectors is 1-based. open Lacaml.D -type 'a t = Vec.t +type 'a t (* Parameter ['a] defines the basis on which the vector is expanded. *) val dim : 'a t -> int @@ -74,6 +74,9 @@ val init : int -> (int -> float) -> 'a t val sum : 'a t -> float (** Returns the sum of the elements of the vector *) +val at : 'a t -> int -> float +(** Returns t.{i} *) + val copy : ?n:int -> ?ofsy:int -> ?incy:int -> ?y:vec -> ?ofsx:int -> ?incx:int -> 'a t -> 'a t (** Returns a copy of the vector X into Y. [ofs] controls the offset and [inc] the increment. *) diff --git a/linear_algebra/test/b.ml b/linear_algebra/test/b.ml deleted file mode 100644 index 7dc33ba..0000000 --- a/linear_algebra/test/b.ml +++ /dev/null @@ -1,14 +0,0 @@ -(* - Tests for Sub1.B -*) - -let test_string () = - Alcotest.(check (neg string)) "foo is not bar" "foo" "bar" - -let test_string_hasty () = - assert ("foo" <> "bar") - -let tests = [ - "string", `Quick, test_string; - "string, hasty", `Quick, test_string_hasty; -] diff --git a/linear_algebra/test/vector.ml b/linear_algebra/test/vector.ml new file mode 100644 index 0000000..fca07ca --- /dev/null +++ b/linear_algebra/test/vector.ml @@ -0,0 +1,59 @@ +open Qcaml_linear_algebra +open Alcotest +open Lacaml.D + +let test_all () = + let n = 100 in + let a1 = Array.init n (fun _ -> Random.float 10. -. 5.) in + let a2 = Array.init n (fun _ -> Random.float 10. -. 5.) in + let u1 = Vec.of_array a1 in + let u2 = Vec.of_array a2 in + let v1 = Vector.of_array a1 in + let v2 = Vector.of_array a2 in + + let check_dot1 label f1 f2 = + check (float 1.e-14) label (dot u1 (f1 u2)) (Vector.dot v1 (f2 v2)) + in + + let check_dot2 label f1 f2 = + check (float 1.e-14) label (dot u1 (f1 u1 u2)) (Vector.dot v1 (f2 v1 v2)) + in + + check int "dim" (Array.length a1) (Vector.dim v1); + check (float 1.e-14) "dot" (dot u1 u2) (Vector.dot v1 v2); + check_dot2 "add" (fun x y -> Vec.add x y) Vector.add ; + check_dot2 "sub" (fun x y -> Vec.sub x y) Vector.sub ; + check_dot2 "mul" (fun x y -> Vec.mul x y) Vector.mul ; + check_dot2 "div" (fun x y -> Vec.div x y) Vector.div ; + check_dot1 "sqr" (fun x -> Vec.sqr x) Vector.sqr ; + check_dot1 "sin" (fun x -> Vec.sin x) Vector.sin ; + check_dot1 "cos" (fun x -> Vec.cos x) Vector.cos ; + check_dot1 "tan" (fun x -> Vec.tan x) Vector.tan ; + check_dot1 "abs" (fun x -> Vec.abs x) Vector.abs ; + check_dot1 "neg" (fun x -> Vec.neg x) Vector.neg ; + check_dot1 "asin" (fun x -> Vec.asin x) Vector.asin ; + check_dot1 "acos" (fun x -> Vec.acos x) Vector.acos ; + check_dot1 "atan" (fun x -> Vec.atan x) Vector.atan ; + check_dot1 "sqrt" (fun x -> Vec.sqrt x) Vector.sqrt ; + check_dot1 "reci" (fun x -> Vec.reci x) Vector.reci ; + check_dot1 "map" (fun x -> Vec.map (fun y -> y+. 3.) x) (Vector.map (fun y -> y+. 3.)) ; + check (float 1.e-14) "norm" (sqrt (dot u1 u1)) (Vector.norm v1); + check (float 1.e-14) "norm" (sqrt (dot u2 u2)) (Vector.norm v2); + check (float 1.e-14) "sum" (Vec.sum u1) (Vector.sum v1); + check (float 1.e-14) "sum" (Vec.sum u2) (Vector.sum v2); + check (float 1.e-14) "at" (u1.{n/2}) (Vector.at v1 (n/2)); + check (float 1.e-14) "at" (u2.{n/2}) (Vector.at v2 (n/2)); + check (bool) "of_list" true (v1 = Vector.of_list @@ Array.to_list a1); + check (bool) "of_list" true (v2 = Vector.of_list @@ Array.to_list a2); + check (bool) "to_list" true (Vector.to_list v1 = Array.to_list a1); + check (bool) "to_list" true (Vector.to_list v2 = Array.to_list a2); + check (bool) "to_array" true (Vector.to_array v1 = a1); + check (bool) "to_array" true (Vector.to_array v2 = a2); + check (bool) "make0" true (Vector.make0 n = Vector.make n 0.); + check (float 1.e-14) "fold" (Vector.sum v1) (Vector.fold (fun a x -> a +. x) 0. v1); + check (float 1.e-14) "fold" (Vector.sum v2) (Vector.fold (fun a x -> a +. x) 0. v2); + () + +let tests = [ + "string", `Quick, test_all; +] diff --git a/test/dune b/test/dune index f13591a..ac1a612 100644 --- a/test/dune +++ b/test/dune @@ -3,6 +3,7 @@ (libraries alcotest test_common + test_linear_algebra test_particles test_gaussian_basis )) diff --git a/test/run_tests.ml b/test/run_tests.ml index bd3ad61..33c6572 100644 --- a/test/run_tests.ml +++ b/test/run_tests.ml @@ -5,6 +5,7 @@ let test_suites: unit Alcotest.test list = [ "Common.Bitstring", Test_common.Bitstring.tests; "Common.Util", Test_common.Math_functions.tests; + "Linear_algebra.Vector", Test_linear_algebra.Vector.tests; "Particles.Nuclei", Test_particles.Nuclei.tests; "Particles.Electrons", Test_particles.Electrons.tests; "Gaussian_basis.General_basis", Test_gaussian_basis.General_basis.tests;