10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2025-01-05 02:48:37 +01:00
QCaml/Utils/Util.ml

276 lines
6.6 KiB
OCaml
Raw Normal View History

2018-02-24 23:57:38 +01:00
(** All utilities which should be included in all source files are defined here *)
(** {1 Functions from libm} *)
2018-02-01 22:39:23 +01:00
2018-02-02 01:25:10 +01:00
open Constants
2018-02-23 15:49:27 +01:00
open Lacaml.D
2018-01-18 23:42:48 +01:00
2018-01-17 15:56:57 +01:00
(** Largest argument whose factorial is precomputed in [fact_memo]
    (the table holds [0!] .. [factmax!]). *)
let factmax : int = 150
2018-02-24 23:57:38 +01:00
(* Bindings to libm through C stubs: the first symbol is used by the
   bytecode compiler, the second (unboxed, non-allocating) by the native
   compiler. *)

(** Error function [erf]. *)
external erf_float : float -> float
  = "erf_float_bytecode" "erf_float" [@@unboxed] [@@noalloc]

(** Complementary error function [erfc]. *)
external erfc_float : float -> float
  = "erfc_float_bytecode" "erfc_float" [@@unboxed] [@@noalloc]

(** Gamma function. NOTE(review): [incomplete_gamma] takes [log] of this
    value, so the stub is expected to return Gamma(x) itself (tgamma), not
    log Gamma(x) — confirm against the C stub. *)
external gamma_float : float -> float
  = "gamma_float_bytecode" "gamma_float" [@@unboxed] [@@noalloc]
2018-01-17 15:56:57 +01:00
2018-02-01 22:53:00 +01:00
(** Lower incomplete gamma function
    [gamma(alpha, x) = Int_0^x exp(-t) t^(alpha-1) dt],
    computed as [Gamma(alpha) *. P(alpha, x)] where the regularized functions are
    - [P(a,x) = 1 / Gamma(a) * Int_0^x exp(-t) t^(a-1) dt]
      (power series, used when [x < 1 + a]);
    - [Q(a,x) = 1 / Gamma(a) * Int_x^inf exp(-t) t^(a-1) dt]
      (Laguerre continued fraction, used when [x >= 1 + a]), with [P = 1 - Q].

    Reference: Haruhiko Okumura, "C-gengo niyoru saishin algorithm jiten"
    (New Algorithm Handbook in C Language), Gijutsu Hyoronsha, Tokyo, 1991,
    p. 227 [in Japanese].

    @raise Failure if the series or the continued fraction has not converged
    after 1000 terms. *)
let incomplete_gamma ~alpha x =
  let a = alpha in
  let gamma_a = gamma_float a in
  let log_gamma_a = log gamma_a in
  let a_inv = 1. /. a in
  (* x^a e^{-x} / Gamma(a), the prefactor shared by both expansions. *)
  let prefactor x = exp (a *. log x -. x -. log_gamma_a) in
  let rec lower x =
    if x >= 1. +. a then
      1. -. upper x
    else if x = 0. then
      0.
    else
      (* Sum terms until adding one no longer changes the float result. *)
      let rec sum ~last ~acc ~term ~k =
        if k > 1000. then failwith "p_gamma did not converge."
        else if last = acc then acc
        else
          let term = term *. x /. (a +. k) in
          sum ~last:acc ~acc:(acc +. term) ~term ~k:(k +. 1.)
      in
      let first = prefactor x *. a_inv in
      sum ~last:min_float ~acc:first ~term:first ~k:1.
  and upper x =
    if x < 1. +. a then
      1. -. lower x
    else
      (* l_prev / l_cur are consecutive Laguerre polynomial values. *)
      let rec fraction ~last ~acc ~l_prev ~l_cur ~w ~k =
        if k > 1000. then failwith "q_gamma did not converge."
        else if last = acc then acc
        else
          let k_inv = 1. /. k in
          let ratio = (k -. 1. -. a) *. k_inv in
          let l_next =
            ratio *. (l_cur -. l_prev) +. (k +. x) *. l_cur *. k_inv
          in
          let w = w *. ratio in
          let next_acc = acc +. w /. (l_cur *. l_next) in
          fraction ~last:acc ~acc:next_acc ~l_prev:l_cur ~l_cur:l_next ~w
            ~k:(k +. 1.)
      in
      let w0 = prefactor x in
      let l1 = 1. +. x -. a in
      fraction ~last:min_float ~acc:(w0 /. l1) ~l_prev:1. ~l_cur:l1 ~w:w0
        ~k:2.
  in
  gamma_a *. lower x
2018-01-17 15:56:57 +01:00
(** Table of factorials: [fact_memo.(i) = i!] for [0 <= i <= factmax].
    Each entry is the previous one multiplied by [float_of_int i]. *)
let fact_memo =
  let table = Array.make (factmax + 1) 1. in
  for i = 1 to factmax do
    table.(i) <- float_of_int i *. table.(i - 1)
  done;
  table
(** [fact i] returns [i!] as a float, read from the precomputed table
    [fact_memo].
    @raise Invalid_argument if [i < 0] or [i > factmax] (outside the
    table). *)
let fact = function
  | i when i < 0 ->
    invalid_arg "Argument of factorial should be non-negative"
  (* Bound tied to [factmax] so it always matches the size of [fact_memo]. *)
  | i when i > factmax ->
    invalid_arg "Result of factorial is infinite"
  | i -> fact_memo.(i)
(** [pow a n] computes [a] raised to the integer power [n] by repeated
    squaring. Small exponents (0..3 and -1) are special-cased.
    A negative [n] is handled as [(1 /. a)] raised to [-n]. *)
let rec pow a = function
  | 0 -> 1.
  | 1 -> a
  | 2 -> a *. a
  | 3 -> a *. a *. a
  | -1 -> 1. /. a
  | n when n > 0 ->
    let b = pow a (n / 2) in
    b *. b *. (if n mod 2 = 0 then 1. else a)
  | n when n = min_int ->
    (* [-min_int] overflows back to [min_int], which would recurse forever:
       peel off one factor so the remaining exponent is [max_int]. *)
    let a_inv = 1. /. a in
    a_inv *. pow a_inv max_int
  | n -> pow (1. /. a) (-n)
2018-01-17 15:56:57 +01:00
2018-01-22 23:19:24 +01:00
(** [chop f g] returns [0.] when [|f| < Constants.epsilon], and
    [f *. g ()] otherwise. The thunk [g] is only evaluated in the second
    case, which skips computing a factor that would be multiplied by a
    negligible [f]. *)
let chop f g =
  match abs_float f < Constants.epsilon with
  | true -> 0.
  | false -> f *. g ()
2018-02-01 22:19:23 +01:00
(** Generalized Boys function: returns the array [[| F_0(t); ...; F_maxm(t) |]]
    with [F_m(t) = Int_0^1 u^(2m) exp(-t u^2) du].

    - [maxm = 0]: closed form [sqrt(pi)/2 * erf(sqrt t) / sqrt t],
      with [F_0(0) = 1].
    - [maxm > 0]: [F_maxm] is obtained from the lower incomplete gamma
      function, and lower orders by the downward recursion
      [F_m = (2t F_(m+1) + exp(-t)) / (2m+1)]. At [t = 0] the values
      are [1 / (2m+1)].

    @param maxm maximum total angular momentum (highest order computed)
    @raise Invalid_argument ["zero_m overflow"] if [t^-(maxm+1/2)] is not
      finite (e.g. for very small non-zero [t])
    @raise Failure if [incomplete_gamma] does not converge *)
let boys_function ~maxm t =
  if maxm = 0 then
    if t = 0. then [| 1. |]
    else
      let sq_t = sqrt t in
      [| (sq_pi_over_two /. sq_t) *. erf_float sq_t |]
  else begin
    (* 1/(2m+1): the t = 0 values, reused as the recursion denominators. *)
    let values =
      Array.init (maxm + 1) (fun m -> 1. /. float_of_int (2 * m + 1))
    in
    if t <> 0. then begin
      let top =
        let a = 0.5 +. float_of_int maxm in
        let scale = pow (sqrt (1. /. t)) (2 * maxm + 1) in
        match classify_float scale with
        | FP_zero | FP_subnormal | FP_normal ->
          incomplete_gamma a t *. 0.5 *. scale
        | FP_infinite | FP_nan -> invalid_arg "zero_m overflow"
      in
      let exp_minus_t = exp (-. t) in
      values.(maxm) <- top;
      for m = maxm - 1 downto 0 do
        values.(m) <- ((t +. t) *. values.(m + 1) +. exp_minus_t) *. values.(m)
      done
    end;
    values
  end
2018-02-20 23:54:48 +01:00
2018-03-22 00:29:14 +01:00
(** {2 List functions} *)
(** [list_some l] keeps the payloads of the [Some _] elements of [l], in
    their original order, and drops every [None]. *)
let list_some l =
  let keep acc = function
    | Some x -> x :: acc
    | None -> acc
  in
  List.rev (List.fold_left keep [] l)
2018-06-28 14:43:24 +02:00
(** {2 Stream functions} *)
(** [range ?start n] is the stream [start, start+1, ..., n-1]
    ([start] defaults to [0]). It is empty when [start >= n]. *)
let range ?(start=0) n =
  let next count =
    let value = start + count in
    if value >= n then None else Some value
  in
  Stream.from next
2018-03-22 00:29:14 +01:00
(** {2 Linear algebra} *)
2018-02-20 23:54:48 +01:00
(** [array_sum a] is the sum of the elements of [a], accumulated left to
    right starting from [0.]. *)
let array_sum a =
  let total = ref 0. in
  Array.iter (fun x -> total := !total +. x) a;
  !total
(** [array_product a] is the product of the elements of [a]
    ([1.] for an empty array). The fold starts from the multiplicative
    identity [1.]; starting from [0.] would make every result [0.]. *)
let array_product a =
  Array.fold_left ( *. ) 1. a
2018-02-21 17:06:24 +01:00
2018-05-30 09:19:49 +02:00
(** [diagonalize_symm m_H] diagonalizes the symmetric matrix [m_H] with
    Lacaml's [syevd ~vectors:true]. It returns [(eigenvectors, eigenvalues)].
    [syevd] overwrites its matrix argument with the eigenvectors, so it works
    on a copy made with [lacpy] and [m_H] is left untouched. *)
let diagonalize_symm m_H =
  let eigenvectors = lacpy m_H in
  let eigenvalues = syevd ~vectors:true eigenvectors in
  (eigenvectors, eigenvalues)
2018-02-22 18:20:45 +01:00
2018-02-24 23:57:38 +01:00
(** [xt_o_x ~o ~x] computes the matrix product [x^T o x]: first [o x],
    then [x^T] times that result. *)
let xt_o_x ~o ~x =
  let o_x = gemm o x in
  gemm ~transa:`T x o_x
(** Canonical orthogonalization of [c] with respect to the metric [overlap].

    [overlap] is decomposed by SVD as [U D V^T] (on a copy). The result is
    [c U D^{-1/2}], except that each column of [U] whose [sqrt(d_i)] is
    below [thresh] is scaled by [0.] instead of [1 /. sqrt(d_i)]. Such
    columns are zeroed, not removed, so the result has as many columns as
    [U]. When at least one column is zeroed, a message is printed on stdout.

    @param thresh cut-off compared against the square roots of the singular
      values of [overlap] (default [1.e-6]) *)
let canonical_ortho ?thresh:(thresh=1.e-6) ~overlap c =
  let d, u, _ = gesvd (lacpy overlap) in
  let d_sqrt = Vec.sqrt d in
  (* D^{-1/2}, written into [d]'s storage: [d] must not be read after this. *)
  let d_inv_sq =
    Vec.map (fun x -> if x >= thresh then 1. /. x else 0.) ~y:d d_sqrt
  in
  (* Count exactly the columns zeroed above, using the same criterion
     (a NaN counts as removed, like in the map). *)
  let n_removed =
    Vec.fold (fun accu x -> if x >= thresh then accu else accu + 1) 0 d_sqrt
  in
  if n_removed > 0 then
    Printf.printf "Removed linear dependencies below %f\n" thresh;
  Mat.scal_cols u d_inv_sq;
  gemm c u
2018-03-15 15:25:49 +01:00
(** {2 Printers} *)
(** Prints a float array prefixed by its length, as [" n:[ x1 x2 ... ]"].
    Each element uses [%10f]. Three boxes are opened and all three are
    closed, so the formatter state is balanced after the call. The opening
    ["["] matches the closing ["]"], as in [pp_float_array]. *)
let pp_float_array_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[<2>[@ " (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  Format.fprintf ppf "]@]@]@]"
2018-03-15 15:25:49 +01:00
(** Prints a float array as ["[ x1 x2 ... ]"] inside a box indented by 2.
    Each element uses [%10f] and is followed by a break hint. *)
let pp_float_array ppf a =
  let print_elt = Format.fprintf ppf "@[%10f@]@ " in
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter print_elt a;
  Format.fprintf ppf "]@]"
2018-03-20 14:11:31 +01:00
(** Prints an array of float arrays as ["[ row1 row2 ... ]"], with each
    row printed by [pp_float_array] in its own box. *)
let pp_float_2darray ppf a =
  let print_row row = Format.fprintf ppf "@[%a@]@ " pp_float_array row in
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter print_row a;
  Format.fprintf ppf "]@]"
(** Prints an array of float arrays prefixed by its length, as
    [" n:[ row1 row2 ... ]"], with each row printed by
    [pp_float_array_size]. Every box opened here is closed, and the opening
    ["["] matches the closing ["]"]. *)
let pp_float_2darray_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[<2>[@ " (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array_size f) a;
  Format.fprintf ppf "]@]@]@]"
2018-06-27 13:13:59 +02:00
(** Pretty-prints matrix [m] in vertical slices of at most 5 columns.
    Each slice is rendered by [Lacaml.Io.pp_lfmat] with 1-based row labels
    and column headers ["-- j --"] that carry the column index [j] in [m].
    Right-hand labels and the footer are not printed. *)
let pp_matrix ppf m =
  let n_rows = Mat.dim1 m in
  let n_cols = Mat.dim2 m in
  let row_labels =
    Array.init n_rows (fun i -> Printf.sprintf "%d " (i + 1))
  in
  let rec print_slice col =
    if col <= n_cols then begin
      let width = min 5 (n_cols - col + 1) in
      let col_labels =
        Array.init width (fun j -> Printf.sprintf "-- %d --" (j + col))
      in
      let slice = lacpy ~ac:col ~n:width m in
      Format.fprintf ppf "@[\n\n %a@]@ "
        (Lacaml.Io.pp_lfmat ~row_labels ~col_labels
           ~print_right:false ~print_foot:false ())
        slice;
      print_slice (col + 5)
    end
  in
  print_slice 1
(** Renders [m] to a string using [pp_matrix]. *)
let string_of_matrix m =
  m |> Format.asprintf "%a" pp_matrix
(** Prints ["name ="] followed by matrix [a] (via [pp_matrix]) on the
    standard formatter. *)
let debug_matrix name a =
  let ppf = Format.std_formatter in
  Format.fprintf ppf "@[%s =\n@[%a@]@]@ " name pp_matrix a