(** Utility functions *)

(* Bindings to C stubs.  In each [external] the first string names the
   bytecode stub and the second the native one.  The declarations are kept
   exactly as they were, since they must agree with the C side. *)

(** C stub; presumably the error function erf (it is used that way in
    {!boys_function}). *)
external erf_float : float -> float
  = "erf_float_bytecode" "erf_float" [@@unboxed] [@@noalloc]

(** C stub; presumably the complementary error function -- TODO confirm
    against the C implementation. *)
external erfc_float : float -> float
  = "erfc_float_bytecode" "erfc_float" [@@unboxed] [@@noalloc]

(** C stub; presumably the Gamma function (it is used that way in
    {!incomplete_gamma}). *)
external gamma_float : float -> float
  = "gamma_float_bytecode" "gamma_float" [@@unboxed] [@@noalloc]

external popcnt : int64 -> int32
  = "popcnt_bytecode" "popcnt" [@@unboxed] [@@noalloc]

(** [popcnt i] calls the C stub [popcnt] on [i] and converts the [int32]
    result to a native [int].  Presumably the number of set bits. *)
let[@inline] popcnt i = Int32.to_int (popcnt i)

external trailz : int64 -> int32
  = "trailz_bytecode" "trailz" "int" [@@unboxed] [@@noalloc]

(** [trailz i] calls the C stub [trailz] on [i] and converts the [int32]
    result to a native [int].  Presumably the number of trailing zero bits. *)
let[@inline] trailz i = Int32.to_int (trailz i)

external leadz : int64 -> int32
  = "leadz_bytecode" "leadz" "int" [@@unboxed] [@@noalloc]

(** [leadz i] calls the C stub [leadz] on [i] and converts the [int32]
    result to a native [int].  Presumably the number of leading zero bits. *)
let[@inline] leadz i = Int32.to_int (leadz i)

(** Precomputed [float_of_int i] for [i] in 0..63. *)
let memo_float_of_int = Array.init 64 float_of_int

(** [float_of_int_fast i] is [float_of_int i]; values in 0..63 are read from
    {!memo_float_of_int}. *)
let float_of_int_fast i =
  (* [Int.logand i 63 = i] holds exactly when 0 <= i <= 63. *)
  if Int.logand i 63 = i then memo_float_of_int.(i) else float_of_int i

(** Largest argument accepted by {!fact}. *)
let factmax = 150

(** Table of factorials: [fact_memo.(i)] is [i!] as a float, for [i] in
    0..[factmax]. *)
let fact_memo =
  let memo = Array.make (factmax + 1) 1. in
  for i = 1 to factmax do
    memo.(i) <- float_of_int i *. memo.(i - 1)
  done;
  memo

(** [fact i] is the factorial of [i] as a float, read from {!fact_memo}.
    @raise Invalid_argument if [i < 0] or [i > factmax]. *)
let fact i =
  if i < 0 then invalid_arg "Argument of factorial should be non-negative"
  else if i > factmax then invalid_arg "Result of factorial is infinite"
  else fact_memo.(i)

(** [binom n k] is the binomial coefficient for [0 <= k <= n] (both
    conditions are checked with [assert]).  For [n < 64] the value comes from
    a precomputed Pascal triangle; for larger [n] it is obtained by
    non-memoized recursion on Pascal's rule, using native [int] arithmetic
    without any overflow check. *)
let binom =
  let memo =
    let m = Array.make_matrix 64 64 0 in
    for n = 0 to Array.length m - 1 do
      m.(n).(0) <- 1;
      m.(n).(n) <- 1;
      for k = 1 to n - 1 do
        m.(n).(k) <- m.(n - 1).(k - 1) + m.(n - 1).(k)
      done
    done;
    m
  in
  let rec f n k =
    assert (k >= 0);
    assert (n >= k);
    if k = 0 || k = n then 1
    else if n < 64 then memo.(n).(k)
    else f (n - 1) (k - 1) + f (n - 1) k
  in
  f

(** [binom_float n k] is [binom n k] converted to a float. *)
let binom_float n k = binom n k |> float_of_int_fast

(** [pow a n] is [a] raised to the integer power [n].  Small exponents are
    special-cased, positive exponents use repeated squaring, and negative
    exponents are handled as [pow (1. /. a) (-n)]. *)
let rec pow a = function
  | 0 -> 1.
  | 1 -> a
  | 2 -> a *. a
  | 3 -> a *. a *. a
  | -1 -> 1. /. a
  | n when n > 0 ->
    let b = pow a (n / 2) in
    b *. b *. (if n mod 2 = 0 then 1. else a)
  | n when n < 0 -> (pow [@tailcall]) (1. /. a) (-n)
  | _ -> assert false

(** [chop f g] is [0.] when [abs_float f < Constants.epsilon], and
    [f *. g ()] otherwise.  The thunk [g] is only evaluated in the second
    case. *)
let chop f g =
  if abs_float f < Constants.epsilon then 0. else f *. g ()

exception Not_implemented of string

(** [not_implemented s] always raises [Not_implemented s]. *)
let not_implemented string = raise (Not_implemented string)

(** [of_some o] extracts the payload of [Some _]; [None] triggers
    [assert false]. *)
let of_some = function
  | Some a -> a
  | None -> assert false

(** [incomplete_gamma ~alpha x] returns [gamma_float alpha *. p], where [p] is
    computed by a power series when [x < 1 + alpha] and as one minus a
    continued-fraction style iteration otherwise.  This is presumably the
    (non-regularized) lower incomplete gamma function.

    Both [alpha >= 0.] and [x >= 0.] are checked with [assert].
    @raise Failure if either iteration exceeds 1000 steps. *)
let incomplete_gamma ~alpha x =
  assert (alpha >= 0.);
  assert (x >= 0.);
  let a = alpha in
  let a_inv = 1. /. a in
  let gf = gamma_float alpha in
  let loggamma_a = log gf in
  let rec p_gamma x =
    if x >= 1. +. a then 1. -. q_gamma x
    else if x = 0. then 0.
    else
      (* Sum the series until adding a term no longer changes the result. *)
      let rec pg_loop prev res term k =
        if k > 1000. then failwith "p_gamma did not converge."
        else if prev = res then res
        else
          let term = term *. x /. (a +. k) in
          (pg_loop [@tailcall]) res (res +. term) term (k +. 1.)
      in
      let r0 = exp (a *. log x -. x -. loggamma_a) *. a_inv in
      pg_loop min_float r0 r0 1.
  and q_gamma x =
    if x < 1. +. a then 1. -. p_gamma x
    else
      (* Iterate until two successive results are equal. *)
      let rec qg_loop prev res la lb w k =
        if k > 1000. then failwith "q_gamma did not converge."
        else if prev = res then res
        else
          let k_inv = 1. /. k in
          let kma = (k -. 1. -. a) *. k_inv in
          let la, lb = lb, kma *. (lb -. la) +. (k +. x) *. lb *. k_inv in
          let w = w *. kma in
          let prev, res = res, res +. w /. (la *. lb) in
          (qg_loop [@tailcall]) prev res la lb w (k +. 1.)
      in
      let w = exp (a *. log x -. x -. loggamma_a) in
      let lb = 1. +. x -. a in
      qg_loop min_float (w /. lb) 1. lb w 2.0
  in
  gf *. p_gamma x

(** [boys_function ~maxm t] returns an array of [maxm + 1] floats, presumably
    the Boys function values of orders 0..[maxm] at [t].

    [t >= 0.] and [maxm >= 0] are checked with [assert].  For [maxm > 0] the
    highest order is computed through {!incomplete_gamma} and the lower
    orders by downward recursion.  When the required power of [1/sqrt t] is
    infinite or NaN (for instance at [t = 0.]) the array of [1 / (2m + 1)]
    is returned unchanged. *)
let boys_function ~maxm t =
  assert (t >= 0.);
  match maxm with
  | 0 ->
    if t = 0. then [| 1. |]
    else
      let sq_t = sqrt t in
      [| (Constants.sq_pi_over_two /. sq_t) *. erf_float sq_t |]
  | _ ->
    begin
      assert (maxm > 0);
      (* Initial content doubles as the 1/(2n+1) factors of the recursion. *)
      let result =
        Array.init (maxm + 1) (fun m -> 1. /. float_of_int (2 * m + 1))
      in
      let power_t_inv = maxm + maxm + 1 in
      let fmax =
        let t_inv = sqrt (1. /. t) in
        let n = float_of_int maxm in
        let dm = 0.5 +. n in
        let f = pow t_inv power_t_inv in
        match classify_float f with
        | FP_normal -> Some ((incomplete_gamma ~alpha:dm t) *. 0.5 *. f)
        | FP_zero | FP_subnormal -> Some 0.
        | FP_infinite | FP_nan -> None
      in
      match fmax with
      | None -> result
      | Some fmax ->
        let emt = exp (-. t) in
        result.(maxm) <- fmax;
        for n = maxm - 1 downto 0 do
          result.(n) <- ((t +. t) *. result.(n + 1) +. emt) *. result.(n)
        done;
        result
    end

(** [list_some l] keeps the payloads of the [Some _] elements of [l], in
    their original order, and drops the [None] elements. *)
let list_some l = List.filter_map Fun.id l

(** [list_range first last] is the list [first; first+1; ...; last], both
    bounds included, or the empty list when [last < first]. *)
let list_range first last =
  if last < first then []
  else
    let rec aux accu = function
      | 0 -> first :: accu
      | i -> (aux [@tailcall]) ((first + i) :: accu) (i - 1)
    in
    aux [] (last - first)

(** [list_pack n l] splits [l] into consecutive sub-lists of [n] elements;
    the last one may be shorter.  [n >= 0] is checked with [assert].  With
    [n = 0] the countdown never reaches zero, so all elements end up in a
    single sub-list (and the empty list gives the empty list). *)
let list_pack n l =
  assert (n >= 0);
  (* [i]: elements still missing from the current pack, minus one;
     [accu1]: current pack, reversed; [accu2]: finished packs, reversed. *)
  let rec aux i accu1 accu2 = function
    | [] ->
      begin
        match accu1 with
        | [] -> List.rev accu2
        | _ :: _ -> List.rev (List.rev accu1 :: accu2)
      end
    | a :: rest ->
      begin
        match i with
        | 0 -> (aux [@tailcall]) (n - 1) [] (List.rev (a :: accu1) :: accu2) rest
        | _ -> (aux [@tailcall]) (i - 1) (a :: accu1) accu2 rest
      end
  in
  aux (n - 1) [] [] l

(** [array_range first last] is the array of the integers from [first] to
    [last], both included, or the empty array when [last < first]. *)
let array_range first last =
  if last < first then [| |]
  else Array.init (last - first + 1) (fun i -> i + first)

(** [array_sum a] is the sum of the elements of [a] ([0.] when empty). *)
let array_sum a = Array.fold_left ( +. ) 0. a

(** [array_product a] is the product of the elements of [a] ([1.] when
    empty). *)
let array_product a = Array.fold_left ( *. ) 1. a

(** [seq_range first last] is the sequence of the integers from [first] up to
    [last - 1]: the upper bound is excluded, unlike {!list_range} and
    {!array_range}.  The sequence is empty when [last <= first]. *)
let seq_range first last =
  (* Clamp the length: [Seq.init] raises on a negative length, whereas the
     list and array variants return an empty result for a reversed range. *)
  Seq.init (Int.max 0 (last - first)) (fun i -> i + first)

(** [seq_to_list seq] is the list of the elements of [seq], in order. *)
let seq_to_list seq = List.of_seq seq

(** [seq_fold f init seq] is [Seq.fold_left f init seq]. *)
let seq_fold f init seq = Seq.fold_left f init seq

(** Prints a float array between brackets, each element with [%10f]. *)
let pp_float_array ppf a =
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  Format.fprintf ppf "]@]"

(** Prints the length of a float array followed by its elements, each with
    [%10f]. *)
let pp_float_array_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[<2>" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  (* The header opens three boxes, so three are closed here; leaving one open
     would leak it into whatever the caller prints next. *)
  Format.fprintf ppf "]@]@]@]"

(** Prints an array of float arrays, each row with {!pp_float_array}. *)
let pp_float_2darray ppf a =
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array f) a;
  Format.fprintf ppf "]@]"

(** Prints the number of rows of an array of float arrays, then each row with
    {!pp_float_array_size}. *)
let pp_float_2darray_size ppf a =
  Format.fprintf ppf "@[<2>@[ %d:@[" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array_size f) a;
  (* Three boxes were opened in the header; close all three. *)
  Format.fprintf ppf "]@]@]@]"

(** [pp_bitstring n ppf bs] prints the [n] lowest bits of [bs], starting from
    bit 0, as a string in which a plus sign marks a set bit and a minus sign
    a cleared bit. *)
let pp_bitstring n ppf bs =
  String.init n (fun i -> if Z.testbit bs i then '+' else '-')
  |> Format.fprintf ppf "@[%s@]"