From 4146264c2ea1490a778e58da6caf0d441e179546 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Thu, 17 Dec 2020 19:52:09 +0100 Subject: [PATCH] Added exponential of matrix --- common/lib/constants.ml | 4 +-- linear_algebra/lib/matrix.ml | 65 +++++++++++++++++++++++++++-------- linear_algebra/lib/matrix.mli | 9 +++++ 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/common/lib/constants.ml b/common/lib/constants.ml index a25723d..f96a4c9 100644 --- a/common/lib/constants.ml +++ b/common/lib/constants.ml @@ -1,5 +1,5 @@ -let integrals_cutoff = 1.e-15 -let epsilon = 1.e-20 +let epsilon = 2.e-15 +let integrals_cutoff = epsilon (** Constants *) let pi = acos (-1.) diff --git a/linear_algebra/lib/matrix.ml b/linear_algebra/lib/matrix.ml index ee77ed0..568e968 100644 --- a/linear_algebra/lib/matrix.ml +++ b/linear_algebra/lib/matrix.ml @@ -1,4 +1,5 @@ open Lacaml.D +open Common type ('a, 'b) t = Mat.t @@ -352,6 +353,16 @@ let scale_cols a v = a' +let scale_rows_inplace v a = + let v' = Vector.to_bigarray_inplace v in + Mat.scal_rows v' a + +let scale_rows v a = + let a' = copy a in + let v' = Vector.to_bigarray_inplace v in + Mat.scal_rows v' a' ; + a' + let svd a = let d, u, vt = gesvd (lacpy a) @@ -382,24 +393,48 @@ let qr a = q, r -let exponential a = +let exponential_iterative a = assert (dim1 a = dim2 a); - let a = to_bigarray_inplace a in - let n = Mat.dim1 a in - let a2 = Lacaml.D.gemm a a in - let (lv, wr, _wi, rv) = Lacaml.D.geev a2 in - let tau = Vec.map (fun x -> -. sqrt x) wr in - let cos_tau = - Mat.init_cols n n (fun i j -> - if i<>j then 0. else cos tau.{i}) + let rec loop result accu n = + let b = scale (1./.n) a in + let new_accu = gemm accu b in + let residual = + sub new_accu accu + |> amax + |> abs_float + in + let result = add result new_accu in + if residual > Constants.epsilon then + loop result new_accu (n+.1.) + else + result in - let sin_tau = - Mat.init_cols n n (fun i j -> - if i<>j then 0. else (sin tau.{i}) /. tau.{i}) + let id = identity (dim1 a) in + loop id id 1. + + +let exponential a = + + let n = dim1 a in + assert (n = dim2 a); + let a2 = gemm a a in + let (u, w, vt) = svd a2 in + let tau = Vector.map (fun x -> -. sqrt x) w in + let cos_tau = Vector.cos tau in + let sin_tau_tau = Vector.mul (Vector.sin tau) (Vector.reci tau) in + let result = + add (gemm (scale_cols u cos_tau) vt) (gemm (scale_cols u sin_tau_tau) @@ gemm vt a) in - let g = Lacaml.D.gemm in - Mat.add (g lv @@ g cos_tau rv) (g lv @@ g sin_tau @@ g rv a) - |> of_bigarray_inplace + + (* Post-condition: Check if exp(-A) * exp(A) = I *) + let id = identity n in + let test = + gemm_tn ~beta:(-.1.0) ~c:id result result + |> amax + |> abs_float + in + assert (test < Constants.epsilon); + result let to_file ~filename ?(sym=false) ?(cutoff=0.) t = diff --git a/linear_algebra/lib/matrix.mli b/linear_algebra/lib/matrix.mli index 55466b8..f9e4b44 100644 --- a/linear_algebra/lib/matrix.mli +++ b/linear_algebra/lib/matrix.mli @@ -171,6 +171,12 @@ val scale_cols: ('a,'b) t -> 'b Vector.t -> ('a,'b) t val scale_cols_inplace: ('a,'b) t -> 'b Vector.t -> unit (** Multiplies the matrix by a constant *) +val scale_rows: 'a Vector.t -> ('a,'b) t -> ('a,'b) t +(** Multiplies the matrix by a constant *) + +val scale_rows_inplace: 'a Vector.t -> ('a,'b) t -> unit +(** Multiplies the matrix by a constant *) + val sycon: ('a,'b) t -> float (** Returns the condition number of a matrix *) @@ -300,6 +306,9 @@ val diagonalize_symm : ('a,'a) t -> ('a,'a) t * 'a Vector.t val exponential : ('a,'a) t -> ('a,'a) t (** Computes the exponential of a square matrix. *) +val exponential_iterative : ('a,'a) t -> ('a,'a) t +(** Computes the exponential of a square matrix with an iteratve algorithm. *) + val xt_o_x : o:('a,'a) t -> x:('a,'b) t -> ('b,'b) t (** Computes {% $\mathbf{X^\dag\, O\, X}$ %} *)