(** All utilities which should be included in all source files are defined here *)

(** {1 Functions from libm} *)

open Constants
open Lacaml.D

(** [erf_float x] is the error function erf(x) (C stub). *)
external erf_float : float -> float
  = "erf_float_bytecode" "erf_float"
  [@@unboxed] [@@noalloc]

(** [erfc_float x] is the complementary error function 1 - erf(x) (C stub). *)
external erfc_float : float -> float
  = "erfc_float_bytecode" "erfc_float"
  [@@unboxed] [@@noalloc]

(** [gamma_float x] is the Gamma function (C stub). *)
external gamma_float : float -> float
  = "gamma_float_bytecode" "gamma_float"
  [@@unboxed] [@@noalloc]

external popcnt : int64 -> int32
  = "popcnt_bytecode" "popcnt"
  [@@unboxed] [@@noalloc]

(** [popcnt i] is the number of set bits of the 64-bit integer [i]
    (popcnt instruction, through a C stub). *)
let popcnt i = popcnt i |> Int32.to_int

external trailz : int64 -> int32
  = "trailz_bytecode" "trailz"
  [@@unboxed] [@@noalloc]

(** [trailz i] is the number of trailing zero bits of [i]
    (ctz instruction, through a C stub). *)
let trailz i = trailz i |> Int32.to_int

external leadz : int64 -> int32
  = "leadz_bytecode" "leadz"
  [@@unboxed] [@@noalloc]

(** [leadz i] is the number of leading zero bits of [i], through the C stub
    [leadz]. NOTE(review): presumably lzcnt/clz; an older comment said
    "bsf", which scans from the low end -- confirm against the stub. *)
let leadz i = leadz i |> Int32.to_int

exception SIGTERM

(* Turn Ctrl-C (SIGINT) into an OCaml exception so that long computations
   can be interrupted and caught.
   NOTE(review): the exception is named [SIGTERM] although the handler is
   installed for [Sys.sigint]; kept as-is because callers may match on it. *)
let () =
  let f _ = raise SIGTERM in
  Sys.set_signal Sys.sigint (Sys.Signal_handle f)

(** Largest [n] whose factorial is tabulated by [fact]. *)
let factmax = 150

(** Lower incomplete gamma function:
    [incomplete_gamma ~alpha x] = Int_0^x exp(-t) t^(alpha-1) dt.

    It is computed as [Gamma(alpha) *. p x], with the regularized functions
    - p: 1 / Gamma(a) * Int_0^x exp(-t) t^(a-1) dt
    - q: 1 / Gamma(a) * Int_x^inf exp(-t) t^(a-1) dt
    where p is evaluated by a power series when [x < 1 + a], and q by a
    continued fraction otherwise, using [p = 1 - q].

    Reference: Haruhiko Okumura, C-gengo niyoru saishin algorithm jiten
    (New Algorithm handbook in C language), Gijyutsu hyouron sha, Tokyo,
    1991, p.227 [in Japanese].

    @raise Failure if either expansion needs more than 1000 terms.

    NOTE(review): [log (gamma_float alpha)] is NaN when Gamma(alpha) < 0,
    so this presumably expects [alpha > 0]. *)
let incomplete_gamma ~alpha x =
  let a = alpha in
  let a_inv = 1. /. a in
  let gf = gamma_float alpha in
  let loggamma_a = log gf in
  let rec p_gamma x =
    if x >= 1. +. a then 1. -. q_gamma x
    else if x = 0. then 0.
    else
      (* Power series; stops as soon as adding a term leaves [res] unchanged. *)
      let rec pg_loop prev res term k =
        if k > 1000. then failwith "p_gamma did not converge."
        else if prev = res then res
        else
          let term = term *. x /. (a +. k) in
          (pg_loop [@tailcall]) res (res +. term) term (k +. 1.)
      in
      let r0 = exp (a *. log x -. x -. loggamma_a) *. a_inv in
      pg_loop min_float r0 r0 1.
  and q_gamma x =
    if x < 1. +. a then 1. -. p_gamma x
    else
      (* Continued fraction, evaluated through the recurrence on [la], [lb];
         stops as soon as [res] no longer changes. *)
      let rec qg_loop prev res la lb w k =
        if k > 1000. then failwith "q_gamma did not converge."
        else if prev = res then res
        else
          let k_inv = 1. /. k in
          let kma = (k -. 1. -. a) *. k_inv in
          let la, lb = lb, kma *. (lb -. la) +. (k +. x) *. lb *. k_inv in
          let w = w *. kma in
          let prev, res = res, res +. w /. (la *. lb) in
          (qg_loop [@tailcall]) prev res la lb w (k +. 1.)
      in
      let w = exp (a *. log x -. x -. loggamma_a) in
      let lb = 1. +. x -. a in
      qg_loop min_float (w /. lb) 1. lb w 2.0
  in
  gf *. p_gamma x

(** Table of factorials as floats: [fact_memo.(i) = i!] for [0 <= i <= factmax]. *)
let fact_memo =
  let table = Array.make (factmax + 1) 1. in
  for i = 1 to factmax do
    table.(i) <- float_of_int i *. table.(i - 1)
  done;
  table

(** [fact i] is [i!] as a float, read from [fact_memo].
    @raise Invalid_argument if [i < 0] or [i > factmax]. *)
let fact = function
  | i when i < 0 ->
    raise (Invalid_argument "Argument of factorial should be non-negative")
  | i when i > factmax ->
    raise (Invalid_argument "Result of factorial is infinite")
  | i -> fact_memo.(i)

(** [binom n k] is the binomial coefficient C(n, k), computed exactly in
    O(n * min(k, n-k)) time with a single row of Pascal's triangle.
    Intermediate values never exceed the result, and additions wrap modulo
    the native int size exactly as the naive recursion would.
    @raise Assert_failure if [n < k].
    @raise Invalid_argument if [k < 0] (and [k <> n]). *)
let binom n k =
  assert (n >= k);
  if k = 0 || k = n then 1
  else if k < 0 then invalid_arg "binom: k must be non-negative"
  else begin
    (* C(n, k) = C(n, n-k): use the smaller index to keep the row short. *)
    let k = min k (n - k) in
    let row = Array.make (k + 1) 0 in
    row.(0) <- 1;
    for i = 1 to n do
      (* Update right-to-left so row.(j - 1) still holds C(i-1, j-1). *)
      for j = min i k downto 1 do
        row.(j) <- row.(j) + row.(j - 1)
      done
    done;
    row.(k)
  end

(** [pow a n] is [a] raised to the integer power [n], by binary
    exponentiation; negative [n] uses [(1 /. a)] raised to [-n]. *)
let rec pow a = function
  | 0 -> 1.
  | 1 -> a
  | 2 -> a *. a
  | 3 -> a *. a *. a
  | -1 -> 1. /. a
  | n when n > 0 ->
    let b = pow a (n / 2) in
    b *. b *. (if n mod 2 = 0 then 1. else a)
  | n when n < 0 -> (pow [@tailcall]) (1. /. a) (-n)
  | _ -> assert false

(** [chop f g] is [0.] when [abs_float f < Constants.epsilon], and
    [f *. g ()] otherwise; [g] is only evaluated in the second case. *)
let chop f g =
  if abs_float f < Constants.epsilon then 0. else f *. g ()

(** Generalized Boys function.
    [boys_function ~maxm t] returns [[| F_0(t); ...; F_maxm(t) |]] with
    F_m(t) = Int_0^1 u^(2m) exp(-t u^2) du; at [t = 0.] this is 1/(2m+1).
    maxm : Maximum total angular momentum
    @raise Invalid_argument ["zero_m overflow"] when t^-(maxm+1/2) is not a
    finite float (for example for [t < 0.]). *)
let boys_function ~maxm t =
  match maxm with
  | 0 ->
    begin
      if t = 0. then [| 1. |]
      else
        let sq_t = sqrt t in
        [| (sq_pi_over_two /. sq_t) *. erf_float sq_t |]
    end
  | _ ->
    begin
      (* Start from the t = 0 values 1/(2m+1); they also act as the 1/(2n+1)
         factors of the downward recursion below. *)
      let result =
        Array.init (maxm + 1) (fun m -> 1. /. float_of_int (2 * m + 1))
      in
      if t <> 0. then begin
        (* F_maxm(t) = gamma_lower(maxm + 1/2, t) / (2 t^(maxm + 1/2)) *)
        let fmax =
          let t_inv = sqrt (1. /. t) in
          let n = float_of_int maxm in
          let dm = 0.5 +. n in
          let f = pow t_inv (maxm + maxm + 1) in
          match classify_float f with
          | FP_zero | FP_subnormal | FP_normal ->
            incomplete_gamma ~alpha:dm t *. 0.5 *. f
          | _ -> invalid_arg "zero_m overflow"
        in
        let emt = exp (-. t) in
        result.(maxm) <- fmax;
        (* Downward recursion: F_n = (2t F_(n+1) + exp(-t)) / (2n+1). *)
        for n = maxm - 1 downto 0 do
          result.(n) <- ((t +. t) *. result.(n + 1) +. emt) *. result.(n)
        done
      end;
      result
    end

(** [of_some o] unwraps [Some a].
    @raise Assert_failure on [None]. *)
let of_some = function
  | Some a -> a
  | None -> assert false

(** {2 List functions} *)

(** [list_some l] keeps the payloads of the [Some] elements of [l], in order. *)
let list_some l =
  List.filter (function None -> false | _ -> true) l
  |> List.map (function Some x -> x | _ -> assert false)

(** [list_range first last] is [[first; first+1; ...; last]], or [[]] when
    [last < first]. *)
let list_range first last =
  if last < first then []
  else
    let rec aux accu = function
      | 0 -> first :: accu
      | i -> (aux [@tailcall]) ((first + i) :: accu) (i - 1)
    in
    aux [] (last - first)

(** [list_pack n l] splits [l] into consecutive chunks of [n] elements; the
    last chunk may be shorter. *)
let list_pack n l =
  let rec aux i accu1 accu2 = function
    | [] ->
      if accu1 = [] then List.rev accu2
      else List.rev (List.rev accu1 :: accu2)
    | a :: rest ->
      (match i with
       | 0 -> (aux [@tailcall]) (n - 1) [] (List.rev (a :: accu1) :: accu2) rest
       | _ -> (aux [@tailcall]) (i - 1) (a :: accu1) accu2 rest)
  in
  aux (n - 1) [] [] l

(** {2 Stream functions} *)

(** [stream_range first last] streams [first], [first+1], ..., [last]. *)
let stream_range first last =
  Stream.from (fun i ->
      let result = i + first in
      if result <= last then Some result else None)

(** [stream_to_list stream] drains [stream] into a list, in order. *)
let stream_to_list stream =
  let rec aux accu =
    let new_accu =
      try Some (Stream.next stream :: accu) with Stream.Failure -> None
    in
    match new_accu with
    | Some new_accu -> (aux [@tailcall]) new_accu
    | None -> accu
  in
  List.rev @@ aux []

(** [stream_fold f init stream] folds [f] left-to-right over the remaining
    elements of [stream]. *)
let stream_fold f init stream =
  let rec aux accu =
    let new_accu =
      try
        let element = Stream.next stream in
        Some (f accu element)
      with Stream.Failure -> None
    in
    match new_accu with
    | Some new_accu -> (aux [@tailcall]) new_accu
    | None -> accu
  in
  aux init

(** {2 Linear algebra} *)

(** Sum of the elements of a float array ([0.] when empty). *)
let array_sum a = Array.fold_left ( +. ) 0. a

(** Product of the elements of a float array ([1.] when empty).
    The fold must start from the multiplicative identity: starting from [0.]
    (as previously) made the result always [0.]. *)
let array_product a = Array.fold_left ( *. ) 1. a

(** [diagonalize_symm m_H] diagonalizes the symmetric matrix [m_H] with
    [syevd ~vectors:true]. [m_H] is copied first, so it is not modified.
    Returns the overwritten copy (presumably the eigenvectors, per LAPACK
    [syevd]) and the vector returned by [syevd] (the eigenvalues). *)
let diagonalize_symm m_H =
  let m_V = lacpy m_H in
  let result = syevd ~vectors:true m_V in
  m_V, result

(** [xt_o_x ~o ~x] is X^T O X. *)
let xt_o_x ~o ~x =
  gemm o x
  |> gemm ~transa:`T x

(** [x_o_xt ~o ~x] is X O X^T. *)
let x_o_xt ~o ~x =
  gemm o x ~transb:`T
  |> gemm x

(** Canonical orthogonalization: returns [c U D^{-1/2}], where
    [overlap = U D V^T] is the SVD of [overlap]. Columns whose
    (square-root) singular value is below [thresh] are scaled to zero
    rather than removed, so the column count is unchanged.
    NOTE(review): [n] counts [d > thresh] while the inverse keeps
    [sqrt d >= thresh] (i.e. [d >= thresh^2]); the two tests disagree.
    Also [d.{n}] is read only when [n < dim], and raises if [n = 0].
    Behaviour kept as-is pending confirmation. *)
let canonical_ortho ?thresh:(thresh = 1.e-6) ~overlap c =
  let d, u, _ = gesvd (lacpy overlap) in
  let d_sqrt = Vec.sqrt d in
  let n =
    (* Number of non-negligible singular vectors *)
    Vec.fold (fun accu x -> if x > thresh then accu + 1 else accu) 0 d
  in
  let d_inv_sq =
    (* D^{-1/2}, written in place into [d] through [~y:d] *)
    Vec.map (fun x -> if x >= thresh then 1. /. x else 0.) ~y:d d_sqrt
  in
  (* [d] now holds D^{-1/2}, so [1. /. d.{n}] is sqrt of the n-th singular value. *)
  if n < Vec.dim d_sqrt then
    Printf.printf "Removed linear dependencies below %f\n" (1. /. d.{n});
  Mat.scal_cols u d_inv_sq;
  gemm c u

(** Orthonormalizes the columns of [m] with a QR factorization (the input is
    copied, not modified). *)
let qr_ortho m =
  (* The QR orthogonalization is performed twice for precision. *)
  let result = lacpy m in
  let tau = geqrf result in
  orgqr ~tau result;
  let tau = geqrf result in
  orgqr ~tau result;
  result

(** [normalize v] is a copy of [v] scaled to unit Euclidean norm. *)
let normalize v =
  let result = copy v in
  scal (1. /. nrm2 v) result;
  result

(** Normalizes every column of [m] independently (returns a new matrix). *)
let normalize_mat m =
  Mat.to_col_vecs m
  |> Array.map normalize
  |> Mat.of_col_vecs

(** {2 Bitstring functions} *)

(** [bit_permtutations m n] lists the [binom n m] Zarith integers with
    exactly [m] bits set, in increasing order, starting from [2^m - 1]
    (next permutation computed with Gosper's hack). The misspelled name is
    kept for compatibility with existing callers. *)
let bit_permtutations m n =
  let rec aux k u rest =
    if k = 1 then List.rev (u :: rest)
    else
      let t = Z.(logor u (u - one)) in
      let t' = Z.(t + one) in
      let t'' =
        Z.(shift_right ((logand (lognot t) t') - one)) (Z.trailing_zeros u + 1)
      in
      (aux [@tailcall]) (k - 1) (Z.logor t' t'') (u :: rest)
  in
  aux (binom n m) Z.(shift_left one m - one) []

(** {2 Printers} *)

(** Prints a float array as [[ size: x1 x2 ... ]]. The opening bracket used to
    be missing and the format opened three boxes but closed only two. *)
let pp_float_array_size ppf a =
  Format.fprintf ppf "@[<2>[ %d:@[<2>" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  Format.fprintf ppf "]@]@]"

(** Prints a float array as [[ x1 x2 ... ]]. *)
let pp_float_array ppf a =
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter (fun f -> Format.fprintf ppf "@[%10f@]@ " f) a;
  Format.fprintf ppf "]@]"

(** Prints an array of float arrays, each with [pp_float_array]. *)
let pp_float_2darray ppf a =
  Format.fprintf ppf "@[<2>[@ ";
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array f) a;
  Format.fprintf ppf "]@]"

(** Prints an array of float arrays with sizes, each row with
    [pp_float_array_size]; the boxes and brackets are now balanced. *)
let pp_float_2darray_size ppf a =
  Format.fprintf ppf "@[<2>[ %d:@[" (Array.length a);
  Array.iter (fun f -> Format.fprintf ppf "@[%a@]@ " pp_float_array_size f) a;
  Format.fprintf ppf "]@]@]"

(** Prints a matrix in slabs of at most 5 columns, with 1-based row and
    column labels. *)
let pp_matrix ppf m =
  let open Lacaml.Io in
  let rows = Mat.dim1 m
  and cols = Mat.dim2 m
  in
  let rec aux first last =
    if first <= last then begin
      Format.fprintf ppf "@[\n\n %a@]@ "
        (Lacaml.Io.pp_lfmat
           ~row_labels:
             (Array.init rows (fun i -> Printf.sprintf "%d " (i + 1)))
           ~col_labels:
             (Array.init (min 5 (cols - first + 1))
                (fun i -> Printf.sprintf "-- %d --" (i + first)))
           ~print_right:false
           ~print_foot:false
           ())
        (lacpy ~ac:first ~n:(min 5 (cols - first + 1)) m);
      (aux [@tailcall]) (first + 5) last
    end
  in
  aux 1 cols

(** [pp_bitstring n ppf bs] prints the [n] lowest bits of [bs], lowest bit
    first, as ['+'] (set) or ['-'] (unset). *)
let pp_bitstring n ppf bs =
  String.init n (fun i -> if Z.testbit bs i then '+' else '-')
  |> Format.fprintf ppf "@[%s@]"

(** Renders a matrix to a string with [pp_matrix]. *)
let string_of_matrix m =
  Format.asprintf "%a" pp_matrix m

(** Prints [name = <matrix>] on standard output and flushes. *)
let debug_matrix name a =
  Format.printf "@[%s =\n@[%a@]@]@." name pp_matrix a

(** Reads a sparse matrix from [filename], given as whitespace-separated
    "i j value" triplets with 1-based indices. The matrix dimensions are the
    largest indices found; missing entries are [0.].
    The input channel is now closed even when parsing fails (it used to leak
    on e.g. [Scanf.Scan_failure]); the exception is re-raised. *)
let matrix_of_file filename =
  let ic = Scanf.Scanning.open_in filename in
  let rec read_line accu =
    match Scanf.bscanf ic " %d %d %f" (fun i j v -> (i, j, v)) with
    | entry -> (read_line [@tailcall]) (entry :: accu)
    | exception End_of_file -> List.rev accu
  in
  let data =
    try read_line [] with
    | exn ->
      Scanf.Scanning.close_in ic;
      raise exn
  in
  Scanf.Scanning.close_in ic;
  let isize, jsize =
    List.fold_left
      (fun (accu_i, accu_j) (i, j, _) -> (max i accu_i, max j accu_j))
      (0, 0) data
  in
  let result = Lacaml.D.Mat.of_array (Array.make_matrix isize jsize 0.) in
  List.iter (fun (i, j, v) -> result.{i, j} <- v) data;
  result

(** Like [matrix_of_file], then copies the upper triangle (i <= j) onto the
    lower one, making the matrix symmetric. *)
let sym_matrix_of_file filename =
  let result = matrix_of_file filename in
  for j = 1 to Mat.dim1 result do
    for i = 1 to j do
      result.{j, i} <- result.{i, j}
    done
  done;
  result

(** Alcotest test cases for this module. *)
let test_case () =
  let test_external () =
    Alcotest.(check (float 1.e-15)) "erf" 0.842700792949715 (erf_float 1.0);
    Alcotest.(check (float 1.e-15)) "erf" 0.112462916018285 (erf_float 0.1);
    Alcotest.(check (float 1.e-15)) "erf" (-0.112462916018285) (erf_float (-0.1));
    Alcotest.(check (float 1.e-15)) "erfc" 0.157299207050285 (erfc_float 1.0);
    Alcotest.(check (float 1.e-15)) "erfc" 0.887537083981715 (erfc_float 0.1);
    Alcotest.(check (float 1.e-15)) "erfc" (1.112462916018285) (erfc_float (-0.1));
    Alcotest.(check (float 1.e-14)) "gamma" (1.77245385090552) (gamma_float 0.5);
    Alcotest.(check (float 1.e-14)) "gamma" (9.51350769866873) (gamma_float (0.1));
    Alcotest.(check (float 1.e-14)) "gamma" (-3.54490770181103) (gamma_float (-0.5));
    Alcotest.(check int) "popcnt" 6 (popcnt @@ Int64.of_int 63);
    Alcotest.(check int) "popcnt" 8 (popcnt @@ Int64.of_int 299605);
    Alcotest.(check int) "popcnt" 1 (popcnt @@ Int64.of_int 65536);
    Alcotest.(check int) "popcnt" 0 (popcnt @@ Int64.of_int 0)
  in
  let test_numerics () =
    Alcotest.(check (float 1.e-12)) "array_product" 24. (array_product [| 2.; 3.; 4. |]);
    Alcotest.(check (float 1.e-12)) "array_product" 1. (array_product [||]);
    Alcotest.(check int) "binom" 10 (binom 5 2);
    Alcotest.(check int) "binom" 1 (binom 7 7);
    Alcotest.(check int) "binom" 1 (binom 7 0);
    Alcotest.(check (float 1.e-12)) "fact" 120. (fact 5)
  in
  [
    "External", `Quick, test_external;
    "Numerics", `Quick, test_numerics;
  ]