mirror of
https://gitlab.com/scemama/QCaml.git
synced 2025-01-03 10:05:40 +01:00
479 lines
12 KiB
OCaml
479 lines
12 KiB
OCaml
(** All utilities which should be included in all source files are defined here *)
|
|
|
|
(** {1 Functions from libm} *)
|
|
|
|
open Constants
|
|
open Lacaml.D
|
|
|
|
|
|
|
|
(** [erf_float x] is the error function erf(x), bound to a C stub
    (["erf_float"] in native code, ["erf_float_bytecode"] in bytecode).
    [[@@unboxed]] passes the float argument and result unboxed, and
    [[@@noalloc]] declares that the stub does not allocate on the OCaml heap. *)
external erf_float : float -> float = "erf_float_bytecode" "erf_float"
[@@unboxed] [@@noalloc]

(** [erfc_float x] is the complementary error function erfc(x) = 1 - erf(x),
    bound to a C stub with the same unboxed / no-allocation convention. *)
external erfc_float : float -> float = "erfc_float_bytecode" "erfc_float"
[@@unboxed] [@@noalloc]

(** [gamma_float x] is the Gamma function, bound to a C stub with the same
    unboxed / no-allocation convention. It also accepts negative
    non-integer arguments (the unit tests expect Gamma(-0.5) ~ -3.5449). *)
external gamma_float : float -> float = "gamma_float_bytecode" "gamma_float"
[@@unboxed] [@@noalloc]
|
|
|
|
external popcnt : int64 -> int32 = "popcnt_bytecode" "popcnt"
[@@unboxed] [@@noalloc]

(** Population count (popcnt instruction): [popcnt i] is the number of set
    bits of the 64-bit integer [i], returned as a native [int]
    (e.g. 6 for [63L]). *)
let popcnt i =
  let count = popcnt i in
  Int32.to_int count
|
|
|
|
external trailz : int64 -> int32 = "trailz_bytecode" "trailz"
[@@unboxed] [@@noalloc]

(** Count of trailing zero bits of the 64-bit integer [i] (ctz instruction),
    returned as a native [int]. *)
let trailz i =
  let zeros = trailz i in
  Int32.to_int zeros
|
|
|
|
external leadz : int64 -> int32 = "leadz_bytecode" "leadz"
[@@unboxed] [@@noalloc]

(** Count of leading zero bits of the 64-bit integer [i] (clz / lzcnt, like
    Fortran's LEADZ), returned as a native [int].
    (The former note "bsf instruction" was misleading: bsf scans from the
    least significant bit, i.e. it counts trailing zeros.)
    NOTE(review): the C stub is not visible here; confirm its result for
    [i = 0L]. *)
let leadz i = leadz i |> Int32.to_int

(** Binding to a C stub named ["unix_vfork"] (same entry point for native
    and bytecode). NOTE(review): presumably it mirrors POSIX vfork (0 in the
    child, the child's pid in the parent); confirm against the stub. *)
external vfork : unit -> int = "unix_vfork" "unix_vfork"
|
|
|
|
|
|
(** Raised by the signal handler installed below when the process receives
    SIGINT (Ctrl-C), so that an interrupt unwinds as an ordinary exception. *)
exception SIGTERM

(* NOTE(review): the handler is attached to [Sys.sigint] although the
   exception is named [SIGTERM]; both are kept as-is because callers may
   match on this exception. *)
let () =
  let on_interrupt _signal = raise SIGTERM in
  Sys.set_signal Sys.sigint (Sys.Signal_handle on_interrupt)
;;
|
|
|
|
|
|
|
|
(** Largest argument tabulated by [fact_memo], which holds 0! .. factmax!. *)
let factmax = 150
|
|
|
|
(* Lower incomplete gamma function

     gamma(a, x) = Int_0^x exp(-t) t^(a-1) dt = Gamma(a) * P(a, x)

   in terms of the regularized functions

     P(a, x) = 1 / Gamma(a) * Int_0^x   exp(-t) t^(a-1) dt
     Q(a, x) = 1 / Gamma(a) * Int_x^inf exp(-t) t^(a-1) dt = 1 - P(a, x)

   [alpha] plays the role of [a]. When x < 1 + a, P is summed as a power
   series. Otherwise P = 1 - Q, and Q is evaluated by a continued fraction.
   Both expansions stop as soon as a new term leaves the floating-point sum
   unchanged.

   Gamma(a) comes from [gamma_float] and log Gamma(a) from its logarithm.
   NOTE(review): [a] therefore presumably needs Gamma(a) finite and positive;
   confirm the intended range.

   Raises [Failure] if an expansion has not converged after 1000 terms.

   Reference: Haruhiko Okumura, C-gengo niyoru saishin algorithm jiten
   (New Algorithm Handbook in C Language), Gijutsu-Hyoronsha, Tokyo, 1991,
   p. 227 [in Japanese]. *)
let incomplete_gamma ~alpha x =
  let a = alpha in
  let inv_a = 1. /. a in
  let gamma_a = gamma_float alpha in
  let log_gamma_a = log gamma_a in
  (* Regularized lower function P(a, x). *)
  let rec lower x =
    if x >= 1. +. a then 1. -. upper x
    else if x = 0. then 0.
    else begin
      (* Power series: [term] is the latest term and [k] its index. *)
      let rec series prev sum term k =
        if k > 1000. then failwith "p_gamma did not converge."
        else if prev = sum then sum
        else begin
          let term = term *. x /. (a +. k) in
          (series [@tailcall]) sum (sum +. term) term (k +. 1.)
        end
      in
      let first = exp (a *. log x -. x -. log_gamma_a) *. inv_a in
      series min_float first first 1.
    end
  (* Regularized upper function Q(a, x). *)
  and upper x =
    if x < 1. +. a then 1. -. lower x
    else begin
      (* Continued fraction. [l_prev] / [l_cur] carry the three-term
         recurrence (Okumura's la / lb); [w] is the running numerator. *)
      let rec fraction prev sum l_prev l_cur w k =
        if k > 1000. then failwith "q_gamma did not converge."
        else if prev = sum then sum
        else begin
          let inv_k = 1. /. k in
          let ratio = (k -. 1. -. a) *. inv_k in
          let l_next =
            ratio *. (l_cur -. l_prev) +. (k +. x) *. l_cur *. inv_k
          in
          let w = w *. ratio in
          let next_sum = sum +. w /. (l_cur *. l_next) in
          (fraction [@tailcall]) sum next_sum l_cur l_next w (k +. 1.)
        end
      in
      let w = exp (a *. log x -. x -. log_gamma_a) in
      let l0 = 1. +. x -. a in
      fraction min_float (w /. l0) 1. l0 w 2.0
    end
  in
  gamma_a *. lower x
|
|
|
|
|
|
|
|
|
|
|
|
(** Table of factorials as floats: [fact_memo.(i) = i!] for
    [0 <= i <= factmax]. It is built once, each entry being [i] times the
    previous one. *)
let fact_memo =
  let table = Array.make (factmax + 1) 1. in
  for i = 1 to factmax do
    table.(i) <- float_of_int i *. table.(i - 1)
  done;
  table
|
|
|
|
|
|
|
|
(** [fact i] is [i!] as a float, read from [fact_memo].
    @raise Invalid_argument if [i < 0] or [i > factmax]. *)
let fact i =
  if i < 0 then
    invalid_arg "Argument of factorial should be non-negative"
  else if i > factmax then
    (* Only 0! .. factmax! are tabulated; use the same bound that built
       [fact_memo] instead of a duplicated literal. *)
    invalid_arg "Result of factorial is infinite"
  else
    fact_memo.(i)
|
|
|
|
|
|
(** [binom n k] is the binomial coefficient C(n, k) = n! / (k! (n-k)!),
    computed exactly in native integers. It returns [0] when [k < 0].
    Requires [n >= k] (checked with [assert]).

    Uses the multiplicative recurrence C(n, i+1) = C(n, i) * (n-i) / (i+1),
    cancelling gcd(C(n, i), i+1) first. Every intermediate value is then
    at most the result, so overflow only happens when C(n, k) itself does
    not fit in an [int]. This replaces the former exponential-time Pascal
    recursion, which also never terminated for negative [k]. *)
let binom n k =
  assert (n >= k);
  if k < 0 then 0
  else begin
    let rec gcd a b = if b = 0 then a else gcd b (a mod b) in
    (* C(n, k) = C(n, n - k): iterate over the shorter side. *)
    let k = min k (n - k) in
    let rec aux acc i =
      if i >= k then acc
      else begin
        let d = i + 1 in
        let g = gcd acc d in
        (* [d] divides [acc * (n - i)] and gcd (acc / g) (d / g) = 1,
           hence [d / g] divides [n - i] exactly. *)
        (aux [@tailcall]) ((acc / g) * ((n - i) / (d / g))) (i + 1)
      end
    in
    aux 1 0
  end
|
|
|
|
|
|
(** [pow a n] is [a] raised to the integer power [n], by repeated squaring.
    Small exponents are special-cased, and a negative power is computed as
    [(1/a)^(-n)]. *)
let rec pow a n =
  match n with
  | 0 -> 1.
  | 1 -> a
  | 2 -> a *. a
  | 3 -> a *. a *. a
  | -1 -> 1. /. a
  | _ when n < 0 -> (pow [@tailcall]) (1. /. a) (-n)
  | _ ->
    let half = pow a (n / 2) in
    let odd_factor = if n land 1 = 0 then 1. else a in
    half *. half *. odd_factor
|
|
|
|
|
|
|
|
|
|
(** [chop f g] is [f *. g ()]. When [|f|] is below [Constants.epsilon],
    [g] is not called at all and [0.] is returned, which lets callers skip
    evaluating a factor whose prefactor is negligible. *)
let chop f g =
  match abs_float f < Constants.epsilon with
  | true -> 0.
  | false -> f *. g ()
|
|
|
|
|
|
|
|
(** Generalized Boys function.

    [boys_function ~maxm t] returns [[| F_0(t); ...; F_maxm(t) |]], where
    F_m(t) = Int_0^1 u^(2m) exp(-t u^2) du and [maxm] is the maximum total
    angular momentum.

    - [maxm = 0] uses the closed form with [erf_float].
    - Otherwise F_maxm comes from [incomplete_gamma], and lower orders from
      the downward recurrence F_m = (2t F_(m+1) + exp(-t)) / (2m+1).
    - At [t = 0] the values are F_m(0) = 1/(2m+1).

    @raise Invalid_argument ["zero_m overflow"] if t^-(maxm+1/2) is not
    finite. *)
let boys_function ~maxm t =
  if maxm = 0 then begin
    (* F_0(t) = (sq_pi_over_two / sqrt t) * erf(sqrt t); [sq_pi_over_two]
       is presumably sqrt(pi)/2 from Constants. *)
    if t = 0. then [| 1. |]
    else begin
      let root_t = sqrt t in
      [| (sq_pi_over_two /. root_t) *. erf_float root_t |]
    end
  end
  else begin
    (* Entries start as 1/(2m+1). They are the answer for t = 0, and
       otherwise serve as the divisors of the recurrence below. *)
    let values =
      Array.init (maxm + 1) (fun m -> 1. /. float_of_int (2 * m + 1))
    in
    if t <> 0. then begin
      (* Highest order: F_M(t) = gamma(M + 1/2, t) * 0.5 * t^-(M + 1/2). *)
      let top =
        let inv_sqrt_t = sqrt (1. /. t) in
        let order = 0.5 +. float_of_int maxm in
        let prefactor = pow inv_sqrt_t (maxm + maxm + 1) in
        match classify_float prefactor with
        | FP_infinite | FP_nan -> invalid_arg "zero_m overflow"
        | FP_zero | FP_subnormal | FP_normal ->
          incomplete_gamma order t *. 0.5 *. prefactor
      in
      let exp_minus_t = exp (-. t) in
      values.(maxm) <- top;
      for m = maxm - 1 downto 0 do
        values.(m) <- ((t +. t) *. values.(m + 1) +. exp_minus_t) *. values.(m)
      done
    end;
    values
  end
|
|
|
|
|
|
(** [of_some o] is the payload of [o].
    @raise Assert_failure if [o] is [None]. *)
let of_some o =
  match o with
  | None -> assert false
  | Some a -> a
|
|
|
|
|
|
(** {2 List functions} *)
|
|
|
|
(** [list_some l] keeps, in order, the payloads of the [Some _] elements of
    [l] and drops the [None]s. *)
let list_some l =
  let rec collect acc = function
    | [] -> List.rev acc
    | Some x :: rest -> (collect [@tailcall]) (x :: acc) rest
    | None :: rest -> (collect [@tailcall]) acc rest
  in
  collect [] l
|
|
|
|
|
|
(** [list_range first last] is [[first; first+1; ...; last]], or [[]] when
    [last < first]. The list is built back to front, tail-recursively. *)
let list_range first last =
  let rec build acc offset =
    let acc = (first + offset) :: acc in
    if offset = 0 then acc else (build [@tailcall]) acc (offset - 1)
  in
  if last < first then [] else build [] (last - first)
|
|
|
|
|
|
(** [list_pack n l] splits [l] into consecutive chunks of [n] elements,
    preserving order; the last chunk may be shorter.
    For example, [list_pack 2 [1;2;3;4;5] = [[1;2];[3;4];[5]]]. *)
let list_pack n l =
  (* [chunk] and [groups] are accumulated in reverse order. *)
  let close chunk groups = List.rev chunk :: groups in
  let rec go remaining chunk groups = function
    | [] ->
      let groups =
        match chunk with
        | [] -> groups
        | _ -> close chunk groups
      in
      List.rev groups
    | x :: xs when remaining = 0 ->
      (go [@tailcall]) (n - 1) [] (close (x :: chunk) groups) xs
    | x :: xs ->
      (go [@tailcall]) (remaining - 1) (x :: chunk) groups xs
  in
  go (n - 1) [] [] l
|
|
|
|
|
|
(** {2 Stream functions} *)
|
|
|
|
(** [stream_range first last] is a stream yielding [first], [first+1], ...,
    [last], and nothing at all when [last < first]. *)
let stream_range first last =
  (* [Stream.from] passes the number of elements produced so far. *)
  let next produced =
    let value = first + produced in
    if value > last then None else Some value
  in
  Stream.from next
|
|
|
|
(** [stream_to_list stream] drains [stream] and returns its elements in
    stream order. *)
let stream_to_list stream =
  let rec drain acc =
    match Stream.next stream with
    | x -> (drain [@tailcall]) (x :: acc)
    | exception Stream.Failure -> List.rev acc
  in
  drain []
|
|
|
|
|
|
(** [stream_fold f init stream] folds [f] over the elements of [stream] from
    left to right, starting from [init], until the stream is exhausted.
    As before, a [Stream.Failure] raised by [f] itself also ends the fold,
    and the accumulator reached so far is returned. *)
let stream_fold f init stream =
  let rec loop acc =
    match
      let element = Stream.next stream in
      f acc element
    with
    | next_acc -> (loop [@tailcall]) next_acc
    | exception Stream.Failure -> acc
  in
  loop init
|
|
|
|
(** {2 Array functions} *)
|
|
|
|
(** [array_range first last] is [[| first; ...; last |]], or [[||]] when
    [last < first]. *)
let array_range first last =
  if last >= first then Array.init (last - first + 1) (( + ) first)
  else [||]
|
|
|
|
|
|
(** {2 Linear algebra} *)
|
|
|
|
|
|
(** [array_sum a] is the sum of the elements of [a], accumulated left to
    right from [0.]. An empty array sums to [0.]. *)
let array_sum a =
  let total = ref 0. in
  Array.iter (fun x -> total := !total +. x) a;
  !total
|
|
|
|
(** [array_product a] is the product of the elements of [a], accumulated
    left to right. The empty product is [1.].
    (The fold previously started from [0.], so it always returned 0.) *)
let array_product a =
  Array.fold_left ( *. ) 1. a
|
|
|
|
|
|
(** [diagonalize_symm m_H] diagonalizes the symmetric matrix [m_H] with
    [syevd]. It returns [(vectors, values)]: [vectors] is a copy of [m_H]
    overwritten by the eigenvectors, and [values] is the vector of
    eigenvalues. [m_H] itself is not modified. *)
let diagonalize_symm m_H =
  let eigenvectors = lacpy m_H in
  let eigenvalues = syevd ~vectors:true eigenvectors in
  (eigenvectors, eigenvalues)
|
|
|
|
(** [xt_o_x ~o ~x] computes [x^T * o * x]. *)
let xt_o_x ~o ~x =
  let o_x = gemm o x in
  gemm ~transa:`T x o_x
|
|
|
|
(** [x_o_xt ~o ~x] computes [x * o * x^T]. *)
let x_o_xt ~o ~x =
  let o_xt = gemm ~transb:`T o x in
  gemm x o_xt
|
|
|
|
|
|
(** Canonical orthogonalization.

    [canonical_ortho ~overlap c] factorizes [overlap] with [gesvd]
    (overlap = U D V^T) and returns [c * U * D^(-1/2)].

    Directions whose [sqrt] singular value is below [thresh] get a scaling
    factor of [0.] instead of [1/sqrt(d)]. These near-linear dependencies
    become zero columns, and the result keeps its shape. A message is
    printed when at least one direction was dropped.

    @param thresh cut-off applied to [sqrt(d)] (default [1.e-6]). *)
let canonical_ortho ?thresh:(thresh=1.e-6) ~overlap c =
  let d, u, _ = gesvd (lacpy overlap) in
  let d_sqrt = Vec.sqrt d in
  (* D^(-1/2), with zeros for the dropped directions. This is written to a
     fresh vector: it used to be written into [d] itself ([~y:d]), so the
     diagnostic below read overwritten values. *)
  let d_inv_sq =
    Vec.map (fun x -> if x >= thresh then 1. /. x else 0.) d_sqrt
  in
  (* Number of retained directions, counted with the same test as above.
     The former count used [d > thresh], so the message could be printed
     when nothing had been removed, and [d.{n}] was out of bounds for
     n = 0. *)
  let n =
    Vec.fold (fun accu x -> if x >= thresh then accu + 1 else accu) 0 d_sqrt
  in
  if n < Vec.dim d_sqrt then
    Printf.printf "Removed linear dependencies below %f\n" thresh;
  Mat.scal_cols u d_inv_sq;
  gemm c u
|
|
|
|
|
|
(** [qr_ortho m] returns the explicit Q factor ([geqrf] followed by
    [orgqr]) of a copy of [m]. The factorization is applied twice, for
    precision. [m] is not modified. *)
let qr_ortho m =
  let q = lacpy m in
  (* Factorize [q] in place and overwrite it with its Q factor. *)
  let refactor () =
    let tau = geqrf q in
    orgqr ~tau q
  in
  refactor ();
  refactor ();
  q
|
|
|
|
|
|
(** [normalize v] returns a copy of [v] scaled by [1 / nrm2 v], i.e. of
    unit Euclidean norm. [v] is not modified. *)
let normalize v =
  let inv_norm = 1. /. nrm2 v in
  let w = copy v in
  scal inv_norm w;
  w
|
|
|
|
|
|
(** [normalize_mat m] returns a new matrix whose columns are those of [m],
    each normalized with [normalize]. *)
let normalize_mat m =
  let columns = Mat.to_col_vecs m in
  Mat.of_col_vecs (Array.map normalize columns)
|
|
|
|
|
|
(** {2 Bitstring functions} *)
|
|
(** [bit_permtutations m n] lists, in increasing order, the [binom n m]
    bit strings ([Z.t]) with exactly [m] bits set among [n]. The first one
    has the [m] lowest bits set, and each next value is produced with
    Gosper's hack (the next larger integer with the same number of set
    bits). *)
let bit_permtutations m n =
  let next u =
    let t = Z.(logor u (u - one)) in
    let t_plus_one = Z.(t + one) in
    let shift = Z.trailing_zeros u + 1 in
    let low = Z.shift_right Z.(logand (lognot t) t_plus_one - one) shift in
    Z.logor t_plus_one low
  in
  let rec collect remaining current acc =
    if remaining = 1 then List.rev (current :: acc)
    else (collect [@tailcall]) (remaining - 1) (next current) (current :: acc)
  in
  collect (binom n m) Z.(shift_left one m - one) []
|
|
|
|
|
|
(** {2 Printers} *)
|
|
|
|
(** Prints a float array preceded by its length, as [" n: x1 x2 ... ]"],
    each element formatted with [%10f]. *)
let pp_float_array_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[<2>" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  (* Three boxes are opened above, so close all three. Only two were closed
     before, which left a box open around whatever was printed next. *)
  Format.fprintf ppf "]@]@]@]"
|
|
|
|
(** Prints a float array as ["[ x1 x2 ... ]"], each element formatted with
    [%10f], inside a box indented by 2. *)
let pp_float_array ppf a =
  Format.fprintf ppf "@[<2>[@ ";
  for i = 0 to Array.length a - 1 do
    Format.fprintf ppf "@[%10f@]@ " a.(i)
  done;
  Format.fprintf ppf "]@]"
|
|
|
|
(** Prints an array of float arrays as ["[ row1 row2 ... ]"], each row
    printed with [pp_float_array]. *)
let pp_float_2darray ppf a =
  let pp_row ppf row = Format.fprintf ppf "@[%a@]@ " pp_float_array row in
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter (pp_row ppf) a;
  Format.fprintf ppf "]@]"
|
|
|
|
(** Prints an array of float arrays preceded by its length, each row
    printed with [pp_float_array_size]. *)
let pp_float_2darray_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array_size f) a;
  (* Close all three boxes opened above; only two were closed before. *)
  Format.fprintf ppf "]@]@]@]"
|
|
|
|
(** Pretty-prints a Lacaml matrix with [Lacaml.Io.pp_lfmat], in chunks of
    at most 5 columns. Each chunk carries 1-based row labels ["i "] and
    column labels ["-- j --"]. *)
let pp_matrix ppf m =
  (* The former [let open Lacaml.Io in] was unused (warning 33): the only
     Lacaml.Io name used, [pp_lfmat], is fully qualified. *)
  let rows = Mat.dim1 m
  and cols = Mat.dim2 m
  in
  (* Print columns [first .. first+4] (fewer for the last chunk), then
     continue with the next chunk. *)
  let rec aux first last =
    if first <= last then begin
      let width = min 5 (cols - first + 1) in
      Format.fprintf ppf "@[\n\n %a@]@ "
        (Lacaml.Io.pp_lfmat
           ~row_labels:
             (Array.init rows (fun i -> Printf.sprintf "%d " (i + 1)))
           ~col_labels:
             (Array.init width (fun i -> Printf.sprintf "-- %d --" (i + first)))
           ~print_right:false
           ~print_foot:false
           ())
        (lacpy ~ac:first ~n:width m);
      (aux [@tailcall]) (first + 5) last
    end
  in
  aux 1 cols
|
|
|
|
|
|
(** [pp_bitstring n ppf bs] prints the [n] lowest bits of [bs], least
    significant first, as ['+'] (set) or ['-'] (unset), in a horizontal box. *)
let pp_bitstring n ppf bs =
  let symbol i = if Z.testbit bs i then '+' else '-' in
  Format.fprintf ppf "@[<h>%s@]" (String.init n symbol)
|
|
|
|
|
|
|
|
|
|
(** Renders [m] with [pp_matrix] into a string. *)
let string_of_matrix m =
  let print ppf () = pp_matrix ppf m in
  Format.asprintf "%a" print ()
|
|
|
|
(** Prints ["name ="] followed by the matrix [a] (via [pp_matrix]) on the
    standard formatter, then flushes it. *)
let debug_matrix name a =
  Format.fprintf Format.std_formatter "@[%s =\n@[%a@]@]@." name pp_matrix a
|
|
|
|
|
|
(** [matrix_of_file filename] reads whitespace-separated triples
    ["i j value"] (1-based indices) until end of file. It returns a dense
    Lacaml matrix of size [max i] x [max j], with unlisted entries set to
    zero.

    The file is closed whether or not reading succeeds; previously a parse
    error left the Scanf channel open.
    @raise Scanf.Scan_failure or [Failure] on malformed input. *)
let matrix_of_file filename =
  let ic = Scanf.Scanning.open_in filename in
  let rec read_entries accu =
    match Scanf.bscanf ic " %d %d %f" (fun i j v -> (i, j, v)) with
    | entry -> (read_entries [@tailcall]) (entry :: accu)
    | exception End_of_file -> List.rev accu
  in
  let data =
    match read_entries [] with
    | data -> Scanf.Scanning.close_in ic; data
    | exception e -> Scanf.Scanning.close_in ic; raise e
  in
  let isize, jsize =
    List.fold_left (fun (accu_i, accu_j) (i, j, _) ->
        (max i accu_i, max j accu_j)) (0, 0) data
  in
  let result =
    Lacaml.D.Mat.of_array
      (Array.make_matrix isize jsize 0.)
  in
  List.iter (fun (i, j, v) -> result.{i, j} <- v) data;
  result
|
|
|
|
|
|
(** [sym_matrix_of_file filename] reads a matrix with [matrix_of_file] and
    then copies every entry [(i, j)] with [i <= j] onto [(j, i)]. The file
    is therefore expected to hold the upper triangle: values given only
    below the diagonal are overwritten. *)
let sym_matrix_of_file filename =
  let m = matrix_of_file filename in
  let n = Mat.dim1 m in
  for row = 1 to n do
    for col = row to n do
      m.{col, row} <- m.{row, col}
    done
  done;
  m
|
|
|
|
|
|
|
|
(** Alcotest cases for the C bindings of this module: erf, erfc, Gamma
    (against reference values) and popcnt. *)
let test_case () =
  let check_float eps name expected actual =
    Alcotest.(check (float eps)) name expected actual
  in
  let check_popcnt expected n =
    Alcotest.(check int) "popcnt" expected (popcnt (Int64.of_int n))
  in
  let test_external () =
    check_float 1.e-15 "erf" 0.842700792949715 (erf_float 1.0);
    check_float 1.e-15 "erf" 0.112462916018285 (erf_float 0.1);
    check_float 1.e-15 "erf" (-0.112462916018285) (erf_float (-0.1));
    check_float 1.e-15 "erfc" 0.157299207050285 (erfc_float 1.0);
    check_float 1.e-15 "erfc" 0.887537083981715 (erfc_float 0.1);
    check_float 1.e-15 "erfc" 1.112462916018285 (erfc_float (-0.1));
    check_float 1.e-14 "gamma" 1.77245385090552 (gamma_float 0.5);
    check_float 1.e-14 "gamma" 9.51350769866873 (gamma_float 0.1);
    check_float 1.e-14 "gamma" (-3.54490770181103) (gamma_float (-0.5));
    check_popcnt 6 63;
    check_popcnt 8 299605;
    check_popcnt 1 65536;
    check_popcnt 0 0
  in
  [ ("External", `Quick, test_external) ]
|
|
|