10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-12-22 20:33:36 +01:00

Added tests for vector

This commit is contained in:
Anthony Scemama 2020-10-02 15:49:09 +02:00
parent 13b791e321
commit 6122bf79d8
10 changed files with 87 additions and 24 deletions

View File

@ -58,7 +58,8 @@ let make_lowdin ~thresh ~overlap =
let u_val = Vector.reci (Vector.sqrt u_val) in
let u_vec' =
Matrix.init_cols (Matrix.dim1 u_vec) (Matrix.dim2 u_vec) (fun i j -> u_vec_x.{i,j} *. u_val.{j})
Matrix.init_cols (Matrix.dim1 u_vec) (Matrix.dim2 u_vec)
(fun i j -> u_vec_x.{i,j} *. (Vector.at u_val j))
in
Matrix.gemm u_vec' ~transb:`T u_vec

View File

@ -638,7 +638,8 @@ let contracted_class_shell_pairs ?operator ~zero_m ?schwartz_p ?schwartz_q shell
raise NullQuartet;
let expo_p_inv, expo_q_inv =
expo_p_inv.{i}, expo_q_inv.{j}
(Vector.at expo_p_inv i),
(Vector.at expo_q_inv j)
in
let center_pq =

View File

@ -77,10 +77,12 @@ let detri t =
let as_vec_inplace t =
Mat.as_vec t
|> Vector.of_bigarray_inplace
let as_vec t =
Mat.as_vec t
|> Vector.copy
lacpy t
|> Mat.as_vec
|> Vector.of_bigarray_inplace
let random ?rnd_state ?(from= -. 1.0) ?(range=2.0) m n =
Mat.random ?rnd_state ~from ~range m n
@ -134,6 +136,7 @@ let diagonalize_symm m_H =
let m_V = lacpy m_H in
let result =
syevd ~vectors:true m_V
|> Vector.of_bigarray
in
m_V, result
@ -231,10 +234,14 @@ let copy_inplace ?m ?n ?br ?bc ~b ?ar ?ac a =
ignore @@ lacpy ?m ?n ?br ?bc ~b ?ar ?ac a
let scale_cols_inplace a v =
Mat.scal_cols a v
Vector.to_bigarray v
|> Mat.scal_cols a
let scale_cols a v =
out_of_place (fun a -> Mat.scal_cols a v) a
let a' = copy a in
Vector.to_bigarray v
|> Mat.scal_cols a' ;
a'
let svd a =

View File

@ -9,8 +9,9 @@ let canonical_ortho ?thresh:(thresh=1.e-6) ~overlap c =
if x >= thresh then 1. /. x
else 0. ) d_sqrt
in
let dx = Vector.to_bigarray d in
if n < Vector.dim d_sqrt then
Printf.printf "Removed linear dependencies below %f\n" (1. /. d.{n})
Printf.printf "Removed linear dependencies below %f\n" (1. /. dx.{n})
;
Matrix.scale_cols_inplace u d_inv_sq ;
Matrix.gemm c u

View File

@ -27,9 +27,9 @@ let iteri f t = Vec.iteri f t
let fold f a t = Vec.fold f a t
let add t1 t2 = Vec.add t1 t2
let sub t1 t2 = Vec.add t1 t2
let sub t1 t2 = Vec.sub t1 t2
let mul t1 t2 = Vec.mul t1 t2
let div t1 t2 = Vec.mul t1 t2
let div t1 t2 = Vec.div t1 t2
let dot t1 t2 = dot t1 t2
let create n = Vec.create n
@ -60,3 +60,6 @@ let normalize v =
scal (1. /. (nrm2 v)) result;
result
let at t i = t.{i}

View File

@ -8,7 +8,7 @@ The indexing of vectors is 1-based.
open Lacaml.D
type 'a t = Vec.t
type 'a t
(* Parameter ['a] defines the basis on which the vector is expanded. *)
val dim : 'a t -> int
@ -74,6 +74,9 @@ val init : int -> (int -> float) -> 'a t
val sum : 'a t -> float
(** Returns the sum of the elements of the vector *)
val at : 'a t -> int -> float
(** Returns t.{i} *)
val copy : ?n:int -> ?ofsy:int -> ?incy:int -> ?y:vec -> ?ofsx:int -> ?incx:int -> 'a t -> 'a t
(** Returns a copy of the vector X into Y. [ofs] controls the offset and [inc]
the increment. *)

View File

@ -1,14 +0,0 @@
(*
Tests for Sub1.B
*)
let test_string () =
Alcotest.(check (neg string)) "foo is not bar" "foo" "bar"
let test_string_hasty () =
assert ("foo" <> "bar")
let tests = [
"string", `Quick, test_string;
"string, hasty", `Quick, test_string_hasty;
]

View File

@ -0,0 +1,59 @@
open Qcaml_linear_algebra
open Alcotest
open Lacaml.D
let test_all () =
let n = 100 in
let a1 = Array.init n (fun _ -> Random.float 10. -. 5.) in
let a2 = Array.init n (fun _ -> Random.float 10. -. 5.) in
let u1 = Vec.of_array a1 in
let u2 = Vec.of_array a2 in
let v1 = Vector.of_array a1 in
let v2 = Vector.of_array a2 in
let check_dot1 label f1 f2 =
check (float 1.e-14) label (dot u1 (f1 u2)) (Vector.dot v1 (f2 v2))
in
let check_dot2 label f1 f2 =
check (float 1.e-14) label (dot u1 (f1 u1 u2)) (Vector.dot v1 (f2 v1 v2))
in
check int "dim" (Array.length a1) (Vector.dim v1);
check (float 1.e-14) "dot" (dot u1 u2) (Vector.dot v1 v2);
check_dot2 "add" (fun x y -> Vec.add x y) Vector.add ;
check_dot2 "sub" (fun x y -> Vec.sub x y) Vector.sub ;
check_dot2 "mul" (fun x y -> Vec.mul x y) Vector.mul ;
check_dot2 "div" (fun x y -> Vec.div x y) Vector.div ;
check_dot1 "sqr" (fun x -> Vec.sqr x) Vector.sqr ;
check_dot1 "sin" (fun x -> Vec.sin x) Vector.sin ;
check_dot1 "cos" (fun x -> Vec.cos x) Vector.cos ;
check_dot1 "tan" (fun x -> Vec.tan x) Vector.tan ;
check_dot1 "abs" (fun x -> Vec.abs x) Vector.abs ;
check_dot1 "neg" (fun x -> Vec.neg x) Vector.neg ;
check_dot1 "asin" (fun x -> Vec.asin x) Vector.asin ;
check_dot1 "acos" (fun x -> Vec.acos x) Vector.acos ;
check_dot1 "atan" (fun x -> Vec.atan x) Vector.atan ;
check_dot1 "sqrt" (fun x -> Vec.sqrt x) Vector.sqrt ;
check_dot1 "reci" (fun x -> Vec.reci x) Vector.reci ;
check_dot1 "map" (fun x -> Vec.map (fun y -> y+. 3.) x) (Vector.map (fun y -> y+. 3.)) ;
check (float 1.e-14) "norm" (sqrt (dot u1 u1)) (Vector.norm v1);
check (float 1.e-14) "norm" (sqrt (dot u2 u2)) (Vector.norm v2);
check (float 1.e-14) "sum" (Vec.sum u1) (Vector.sum v1);
check (float 1.e-14) "sum" (Vec.sum u2) (Vector.sum v2);
check (float 1.e-14) "at" (u1.{n/2}) (Vector.at v1 (n/2));
check (float 1.e-14) "at" (u2.{n/2}) (Vector.at v2 (n/2));
check (bool) "of_list" true (v1 = Vector.of_list @@ Array.to_list a1);
check (bool) "of_list" true (v2 = Vector.of_list @@ Array.to_list a2);
check (bool) "to_list" true (Vector.to_list v1 = Array.to_list a1);
check (bool) "to_list" true (Vector.to_list v2 = Array.to_list a2);
check (bool) "to_array" true (Vector.to_array v1 = a1);
check (bool) "to_array" true (Vector.to_array v2 = a2);
check (bool) "make0" true (Vector.make0 n = Vector.make n 0.);
check (float 1.e-14) "fold" (Vector.sum v1) (Vector.fold (fun a x -> a +. x) 0. v1);
check (float 1.e-14) "fold" (Vector.sum v2) (Vector.fold (fun a x -> a +. x) 0. v2);
()
let tests = [
"string", `Quick, test_all;
]

View File

@ -3,6 +3,7 @@
(libraries
alcotest
test_common
test_linear_algebra
test_particles
test_gaussian_basis
))

View File

@ -5,6 +5,7 @@
let test_suites: unit Alcotest.test list = [
"Common.Bitstring", Test_common.Bitstring.tests;
"Common.Util", Test_common.Math_functions.tests;
"Linear_algebra.Vector", Test_linear_algebra.Vector.tests;
"Particles.Nuclei", Test_particles.Nuclei.tests;
"Particles.Electrons", Test_particles.Electrons.tests;
"Gaussian_basis.General_basis", Test_gaussian_basis.General_basis.tests;