mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-08 20:33:03 +01:00
Added exponential of matrix
This commit is contained in:
parent
3c27ec0c10
commit
4146264c2e
@ -1,5 +1,5 @@
|
|||||||
let integrals_cutoff = 1.e-15
|
let epsilon = 2.e-15
|
||||||
let epsilon = 1.e-20
|
let integrals_cutoff = epsilon
|
||||||
|
|
||||||
(** Constants *)
|
(** Constants *)
|
||||||
let pi = acos (-1.)
|
let pi = acos (-1.)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
open Lacaml.D
|
open Lacaml.D
|
||||||
|
open Common
|
||||||
|
|
||||||
type ('a, 'b) t = Mat.t
|
type ('a, 'b) t = Mat.t
|
||||||
|
|
||||||
@ -352,6 +353,16 @@ let scale_cols a v =
|
|||||||
a'
|
a'
|
||||||
|
|
||||||
|
|
||||||
|
let scale_rows_inplace v a =
|
||||||
|
let v' = Vector.to_bigarray_inplace v in
|
||||||
|
Mat.scal_rows v' a
|
||||||
|
|
||||||
|
let scale_rows v a =
|
||||||
|
let a' = copy a in
|
||||||
|
let v' = Vector.to_bigarray_inplace v in
|
||||||
|
Mat.scal_rows v' a' ;
|
||||||
|
a'
|
||||||
|
|
||||||
let svd a =
|
let svd a =
|
||||||
let d, u, vt =
|
let d, u, vt =
|
||||||
gesvd (lacpy a)
|
gesvd (lacpy a)
|
||||||
@ -382,24 +393,48 @@ let qr a =
|
|||||||
q, r
|
q, r
|
||||||
|
|
||||||
|
|
||||||
let exponential a =
|
let exponential_iterative a =
|
||||||
assert (dim1 a = dim2 a);
|
assert (dim1 a = dim2 a);
|
||||||
let a = to_bigarray_inplace a in
|
let rec loop result accu n =
|
||||||
let n = Mat.dim1 a in
|
let b = scale (1./.n) a in
|
||||||
let a2 = Lacaml.D.gemm a a in
|
let new_accu = gemm accu b in
|
||||||
let (lv, wr, _wi, rv) = Lacaml.D.geev a2 in
|
let residual =
|
||||||
let tau = Vec.map (fun x -> -. sqrt x) wr in
|
sub new_accu accu
|
||||||
let cos_tau =
|
|> amax
|
||||||
Mat.init_cols n n (fun i j ->
|
|> abs_float
|
||||||
if i<>j then 0. else cos tau.{i})
|
in
|
||||||
|
let result = add result new_accu in
|
||||||
|
if residual > Constants.epsilon then
|
||||||
|
loop result new_accu (n+.1.)
|
||||||
|
else
|
||||||
|
result
|
||||||
in
|
in
|
||||||
let sin_tau =
|
let id = identity (dim1 a) in
|
||||||
Mat.init_cols n n (fun i j ->
|
loop id id 1.
|
||||||
if i<>j then 0. else (sin tau.{i}) /. tau.{i})
|
|
||||||
|
|
||||||
|
let exponential a =
|
||||||
|
|
||||||
|
let n = dim1 a in
|
||||||
|
assert (n = dim2 a);
|
||||||
|
let a2 = gemm a a in
|
||||||
|
let (u, w, vt) = svd a2 in
|
||||||
|
let tau = Vector.map (fun x -> -. sqrt x) w in
|
||||||
|
let cos_tau = Vector.cos tau in
|
||||||
|
let sin_tau_tau = Vector.mul (Vector.sin tau) (Vector.reci tau) in
|
||||||
|
let result =
|
||||||
|
add (gemm (scale_cols u cos_tau) vt) (gemm (scale_cols u sin_tau_tau) @@ gemm vt a)
|
||||||
in
|
in
|
||||||
let g = Lacaml.D.gemm in
|
|
||||||
Mat.add (g lv @@ g cos_tau rv) (g lv @@ g sin_tau @@ g rv a)
|
(* Post-condition: Check if exp(-A) * exp(A) = I *)
|
||||||
|> of_bigarray_inplace
|
let id = identity n in
|
||||||
|
let test =
|
||||||
|
gemm_tn ~beta:(-.1.0) ~c:id result result
|
||||||
|
|> amax
|
||||||
|
|> abs_float
|
||||||
|
in
|
||||||
|
assert (test < Constants.epsilon);
|
||||||
|
result
|
||||||
|
|
||||||
|
|
||||||
let to_file ~filename ?(sym=false) ?(cutoff=0.) t =
|
let to_file ~filename ?(sym=false) ?(cutoff=0.) t =
|
||||||
|
@ -171,6 +171,12 @@ val scale_cols: ('a,'b) t -> 'b Vector.t -> ('a,'b) t
|
|||||||
val scale_cols_inplace: ('a,'b) t -> 'b Vector.t -> unit
|
val scale_cols_inplace: ('a,'b) t -> 'b Vector.t -> unit
|
||||||
(** Multiplies the matrix by a constant *)
|
(** Multiplies the matrix by a constant *)
|
||||||
|
|
||||||
|
val scale_rows: 'a Vector.t -> ('a,'b) t -> ('a,'b) t
|
||||||
|
(** Multiplies the matrix by a constant *)
|
||||||
|
|
||||||
|
val scale_rows_inplace: 'a Vector.t -> ('a,'b) t -> unit
|
||||||
|
(** Multiplies the matrix by a constant *)
|
||||||
|
|
||||||
val sycon: ('a,'b) t -> float
|
val sycon: ('a,'b) t -> float
|
||||||
(** Returns the condition number of a matrix *)
|
(** Returns the condition number of a matrix *)
|
||||||
|
|
||||||
@ -300,6 +306,9 @@ val diagonalize_symm : ('a,'a) t -> ('a,'a) t * 'a Vector.t
|
|||||||
val exponential : ('a,'a) t -> ('a,'a) t
|
val exponential : ('a,'a) t -> ('a,'a) t
|
||||||
(** Computes the exponential of a square matrix. *)
|
(** Computes the exponential of a square matrix. *)
|
||||||
|
|
||||||
|
val exponential_iterative : ('a,'a) t -> ('a,'a) t
|
||||||
|
(** Computes the exponential of a square matrix with an iteratve algorithm. *)
|
||||||
|
|
||||||
val xt_o_x : o:('a,'a) t -> x:('a,'b) t -> ('b,'b) t
|
val xt_o_x : o:('a,'a) t -> x:('a,'b) t -> ('b,'b) t
|
||||||
(** Computes {% $\mathbf{X^\dag\, O\, X}$ %} *)
|
(** Computes {% $\mathbf{X^\dag\, O\, X}$ %} *)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user