open Util
open Lacaml.D
open Constants

(** Exclusive upper bound on the [size] of a storage: [create] asserts
    [size < max_index]. *)
let max_index = 1 lsl 14

(** A pair of indices. Throughout this file, [r1] is destructured as the
    pair (i,k) and [r2] as the pair (j,l). *)
type index_pair = { first : int ; second : int }

(** Two-dimensional array of 64-bit floats in Fortran layout
    (1-based indices). *)
type array2 = (float, Bigarray.float64_elt, Bigarray.fortran_layout) Bigarray.Array2.t

(** Backing store for the elements whose four indices match none of the
    coincidence patterns handled by the two- and three-index arrays. *)
type storage_t =
  | Dense  of array2                    (** rows: [dense_index], columns: [sym_index] *)
  | Sparse of (int, float) Hashtbl.t    (** keyed by [key_of_indices] *)

(** Four-index storage. Elements are dispatched to one of the arrays
    according to which indices coincide (see [unsafe_get_four_index]). *)
type t =
  { size             : int ;       (** largest valid index *)
    two_index        : array2;     (** used when i=k and j=l *)
    two_index_anti   : array2;     (** used when (i=l and j=k) or (i=j and k=l) *)
    three_index      : array2;     (** used when exactly one of i=k, j=l holds *)
    three_index_anti : array2;     (** used when one of i=l, j=k, i=j, k=l holds *)
    four_index       : storage_t ; (** every other combination *)
  }


(** [key_of_indices ~r1 ~r2] maps the four indices to a single integer used
    as the key of the sparse hash table. Each pair is first folded into
    [min + max*(max-1)/2], then the two resulting integers are folded the
    same way, so the key is invariant under swapping the indices inside a
    pair and under swapping the two pairs. *)
let key_of_indices ~r1 ~r2 =
  let { first=i ; second=k } = r1 and { first=j ; second=l } = r2 in
  let f i k =
    let p, r =
      if i <= k then i, k else k, i
    in
    p + (r*(r-1))/2
  in
  let p = f i k and q = f j l in
  f p q


(** [check_bounds r1 r2 t] checks that the four indices are all in the
    range 1..[t.size].

    The arrays use the Fortran layout, so index 0 is out of bounds. Every
    index is therefore tested individually: a combined test on the bitwise
    OR of the indices would let a single zero index through whenever
    another index is positive, and that index would then reach the
    unchecked accessors.

    @raise Assert_failure if an index is out of range. *)
let check_bounds r1 r2 t =
  let { first=i ; second=k } = r1 and { first=j ; second=l } = r2 in
  let size = t.size in
  assert ( i > 0 && j > 0 && k > 0 && l > 0 );
  assert ( i <= size && j <= size && k <= size && l <= size )


(** [dense_index i j size] is the column-major linear index of the pair
    (i,j), with 1-based [i] and [j]: it ranges over 1..[size*size]. *)
let dense_index i j size =
  (j-1)*size + i


(** [sym_index i j] is a linear index which is symmetric in [i] and [j]:
    [max*(max-1)/2 + min]. For 1-based indices bounded by [size] it ranges
    over 1..[size*(size+1)/2]. *)
let sym_index i j =
  if i < j then
    (j*(j-1))/2 + i
  else
    (i*(i-1))/2 + j


(** [unsafe_get_four_index ~r1 ~r2 t] reads the element indexed by
    [r1] = (i,k) and [r2] = (j,l) without any bounds checking.

    The first matching coincidence pattern, in the order of the tests
    below, selects the array which is read. When no indices coincide, the
    value is read from [t.four_index]; a key absent from a sparse table
    yields [0.]. *)
let unsafe_get_four_index ~r1 ~r2 t =

  let open Bigarray.Array2 in
  let { first=i ; second=k } = r1 and { first=j ; second=l } = r2 in

  if i=k then
    if j=l then
      unsafe_get t.two_index i j
    else
      unsafe_get t.three_index (dense_index j l t.size) i
  else if j=l then
    unsafe_get t.three_index (dense_index i k t.size) j

  else if i=l then
    if k=j then
      unsafe_get t.two_index_anti i j
    else
      unsafe_get t.three_index_anti (dense_index j k t.size) i
  else if j=k then
    unsafe_get t.three_index_anti (dense_index i l t.size) j

  else if i=j then
    if k=l then
      unsafe_get t.two_index_anti i k
    else
      unsafe_get t.three_index_anti (dense_index k l t.size) i
  else if k=l then
    unsafe_get t.three_index_anti (dense_index i j t.size) k

  else
    (* All four indices are pairwise distinct in the tested positions. *)
    match t.four_index with
    | Dense  a -> unsafe_get a (dense_index i k t.size) (sym_index j l)
    | Sparse a -> let key = key_of_indices ~r1 ~r2 in
        try Hashtbl.find a key
        with Not_found -> 0.

(** Bounds-checked read of the element indexed by [r1] = (i,k) and
    [r2] = (j,l).
    @raise Assert_failure if an index is rejected by [check_bounds]. *)
let get_four_index ~r1 ~r2 t =
  check_bounds r1 r2 t;
  unsafe_get_four_index ~r1 ~r2 t


(** [unsafe_set_four_index ~r1 ~r2 ~value t] writes [value] at the element
    indexed by [r1] = (i,k) and [r2] = (j,l) without any bounds checking.

    The dispatch mirrors [unsafe_get_four_index]. In the array-backed cases
    the value is written at every position from which the getter may read
    an index permutation of the same element. *)
let unsafe_set_four_index ~r1 ~r2 ~value t =

  let open Bigarray.Array2 in
  let { first=i ; second=k } = r1 and { first=j ; second=l } = r2 in

  if i=k then
    if j=l then
      begin
        unsafe_set t.two_index i j value;
        unsafe_set t.two_index j i value;
      end
    else
      begin
        unsafe_set t.three_index (dense_index j l t.size) i value;
        unsafe_set t.three_index (dense_index l j t.size) i value;
      end
  else if j=l then
    begin
      unsafe_set t.three_index (dense_index i k t.size) j value;
      unsafe_set t.three_index (dense_index k i t.size) j value;
    end

  else if i=l then
    if j=k then
      begin
        unsafe_set t.two_index_anti i j value;
        unsafe_set t.two_index_anti j i value;
      end
    else
      begin
        unsafe_set t.three_index_anti (dense_index j k t.size) i value;
        unsafe_set t.three_index_anti (dense_index k j t.size) i value;
      end
  else if j=k then
    begin
      unsafe_set t.three_index_anti (dense_index i l t.size) j value;
      unsafe_set t.three_index_anti (dense_index l i t.size) j value;
    end

  else if i=j then
    if k=l then
      begin
        unsafe_set t.two_index_anti i k value;
        unsafe_set t.two_index_anti k i value;
      end
    else
      begin
        unsafe_set t.three_index_anti (dense_index k l t.size) i value;
        unsafe_set t.three_index_anti (dense_index l k t.size) i value;
      end
  else if k=l then
    begin
      unsafe_set t.three_index_anti (dense_index i j t.size) k value;
      unsafe_set t.three_index_anti (dense_index j i t.size) k value;
    end

  else
    match t.four_index with
    | Dense a ->
        let ik = (dense_index i k t.size)
        and jl = (dense_index j l t.size)
        and ki = (dense_index k i t.size)
        and lj = (dense_index l j t.size)
        and ik_s = (sym_index i k)
        and jl_s = (sym_index j l)
        in
        begin
          (* The column index is symmetric, the row index is not: write
             both row orderings, for both assignments of the pairs. *)
          unsafe_set a ik jl_s value;
          unsafe_set a ki jl_s value;
          unsafe_set a jl ik_s value;
          unsafe_set a lj ik_s value;
        end
    | Sparse a ->
        let key = key_of_indices ~r1 ~r2 in
        Hashtbl.replace a key value


(** Bounds-checked version of [unsafe_set_four_index].
    @raise Assert_failure if an index is rejected by [check_bounds]. *)
let set_four_index ~r1 ~r2 ~value t =
  check_bounds r1 r2 t;
  unsafe_set_four_index ~r1 ~r2 ~value t


(** Adds [value] to the stored element (read, then write back), without
    bounds checking. *)
let unsafe_increment_four_index ~r1 ~r2 ~value t =
  let updated_value =
    value +. unsafe_get_four_index ~r1 ~r2 t
  in
  unsafe_set_four_index ~r1 ~r2 ~value:updated_value t


(** Bounds-checked version of [unsafe_increment_four_index].
    @raise Assert_failure if an index is rejected by [check_bounds]. *)
let increment_four_index ~r1 ~r2 ~value t =
  check_bounds r1 r2 t;
  unsafe_increment_four_index ~r1 ~r2 ~value t


(** Alias of [get_four_index]. *)
let get ~r1 ~r2 a =
  get_four_index ~r1 ~r2 a


(** [set ~r1 ~r2 ~value t] stores [value] when it is a normal float.
    Zero and subnormal values are silently ignored: nothing is written, so
    a previously stored value is left untouched.
    @raise Invalid_argument if [value] is infinite or NaN. The exception is
    raised as soon as the labelled arguments are applied, before the
    storage is passed. *)
let set ~r1 ~r2 ~value =
  match classify_float value with
  | FP_normal -> set_four_index ~r1 ~r2 ~value
  | FP_zero
  | FP_subnormal -> fun _ -> ()
  | FP_infinite
  | FP_nan ->
      let msg =
        Printf.sprintf "FourIdxStorage.ml : set : r1 = (%d,%d) ; r2 = (%d,%d)"
          r1.first r1.second r2.first r2.second
      in
      raise (Invalid_argument msg)


(** Alias of [increment_four_index] (no classification of [value]). *)
let increment ~r1 ~r2 =
  increment_four_index ~r1 ~r2


(** [create ~size ?temp_dir sparsity] allocates a storage for indices in
    1..[size]. The two- and three-index arrays, and the four-index array
    when [sparsity] is [`Dense], are obtained from [SharedMemory.create]
    with [temp_dir] (default [/dev/shm]). With [`Sparse], the four-index
    part is an initially empty hash table.
    NOTE(review): whether the arrays returned by [SharedMemory.create] are
    zero-initialised is not visible from this file -- confirm.
    @raise Assert_failure if [size >= max_index]. *)
let create ~size ?(temp_dir="/dev/shm") sparsity =
  assert (size < max_index);
  let two_index =
    SharedMemory.create ~temp_dir Float64 [| size ; size |]
    |> Bigarray.array2_of_genarray
  in
  let two_index_anti =
    SharedMemory.create ~temp_dir Float64 [| size ; size |]
    |> Bigarray.array2_of_genarray
  in
  let three_index =
    SharedMemory.create ~temp_dir Float64 [| size * size ; size |]
    |> Bigarray.array2_of_genarray
  in
  let three_index_anti =
    SharedMemory.create ~temp_dir Float64 [| size * size ; size |]
    |> Bigarray.array2_of_genarray
  in
  let four_index =
    match sparsity with
    | `Dense ->
        (* Rows are addressed by dense_index, columns by sym_index. *)
        let result =
          SharedMemory.create ~temp_dir Float64 [| size*size ; (size*(size+1))/2 |]
          |> Bigarray.array2_of_genarray
        in
        Dense result
    | `Sparse ->
        let result = Hashtbl.create (size*size+13) in
        Sparse result
  in
  { size ; two_index ; two_index_anti ; three_index ; three_index_anti ; four_index }


(** Largest valid index of the storage. *)
let size t = t.size


(** TODO : remove epsilons *)

(** [get_chem t i j k l] reads the element with [r1] = (i,j) and
    [r2] = (k,l). Presumably the chemist notation (ij|kl) -- confirm. *)
let get_chem t i j k l = get ~r1:{ first=i ; second=j } ~r2:{ first=k ; second=l } t

(** [get_phys t i j k l] reads the element with [r1] = (i,k) and
    [r2] = (j,l). Presumably the physicist notation <ij|kl> -- confirm. *)
let get_phys t i j k l = get ~r1:{ first=i ; second=k } ~r2:{ first=j ; second=l } t

(** Same index convention as [get_chem]; stores through [set]. *)
let set_chem t i j k l value = set ~r1:{ first=i ; second=j } ~r2:{ first=k ; second=l } ~value t

(** Same index convention as [get_phys]; stores through [set]. *)
let set_phys t i j k l value = set ~r1:{ first=i ; second=k } ~r2:{ first=j ; second=l } ~value t


type element = (** Element for the stream *)
  {
    i_r1: int ;
    j_r2: int ;
    k_r1: int ;
    l_r2: int ;
    value: float
  }


(** Array of [get_phys d i j k l] for i = 1..[d.size] (position [i-1]). *)
let get_phys_all_i d ~j ~k ~l =
  Array.init d.size (fun i -> get_phys d (i+1) j k l)

(** Array of [get_chem d i j k l] for i = 1..[d.size] (position [i-1]). *)
let get_chem_all_i d ~j ~k ~l =
  Array.init d.size (fun i -> get_chem d (i+1) j k l)

(** Array over j = 1..[d.size] of [get_phys_all_i d ~j ~k ~l]. *)
let get_phys_all_ji d ~k ~l =
  Array.init d.size (fun j -> get_phys_all_i d ~j:(j+1) ~k ~l)

(** Array over j = 1..[d.size] of [get_chem_all_i d ~j ~k ~l]. *)
let get_chem_all_ji d ~k ~l =
  Array.init d.size (fun j -> get_chem_all_i d ~j:(j+1) ~k ~l)


(** [to_stream d] enumerates the elements of [d], read with [get_phys],
    over the index quadruples satisfying i <= k, j <= l and k <= l, with
    l <= [d.size]. The index i varies fastest, then j, then k, then l.
    Values are read lazily, when the stream is consumed. *)
let to_stream d =

  let i = ref 0
  and j = ref 1
  and k = ref 1
  and l = ref 1
  in

  (* Advances the odometer (i, j, k, l) by one step and returns the
     corresponding element, or None once l exceeds the size. *)
  let f_dense _ =
    i := !i+1;
    if !i > !k then begin
      i := 1;
      j := !j + 1;
      if !j > !l then begin
        j := 1;
        k := !k + 1;
        if !k > !l then begin
          k := 1;
          l := !l + 1;
        end;
      end;
    end;
    if !l <= d.size then
      Some { i_r1 = !i ; j_r2 = !j ;
             k_r1 = !k ; l_r2 = !l ;
             value = get_phys d !i !j !k !l
           }
    else
      None
  in
  Stream.from f_dense


(** Write all integrals to a file with the convention of [get_phys]:
    one line per element of [to_stream data] whose absolute value exceeds
    [cutoff] (default [Constants.epsilon]), holding the four indices
    followed by the value.

    The channel is closed even when writing fails; the exception is then
    re-raised. *)
let to_file ?(cutoff=Constants.epsilon) ~filename data =
  let oc = open_out filename in
  let write_all () =
    to_stream data
    |> Stream.iter (fun {i_r1 ; j_r2 ; k_r1 ; l_r2 ; value} ->
        if (abs_float value > cutoff) then
          Printf.fprintf oc " %5d %5d %5d %5d%20.15f\n" i_r1 j_r2 k_r1 l_r2 value)
  in
  (* Do not leak the file descriptor on failure. *)
  (try write_all () with e -> close_out_noerr oc; raise e);
  close_out oc


(** [of_file ~size ~sparsity filename] creates a storage and fills it, via
    [set_phys], with the records (four integers and a float) read from
    [filename] until the end of the file.

    The input channel is closed even when reading fails; the exception
    (for instance [Scanf.Scan_failure] on a malformed record, or
    [Invalid_argument] raised by [set]) is then re-raised. *)
let of_file ~size ~sparsity filename =
  let result = create ~size sparsity in
  let ic = Scanf.Scanning.open_in filename in
  let rec read_line () =
    let status =
      try
        Some (Scanf.bscanf ic " %d %d %d %d %f" (fun i j k l v ->
            set_phys result i j k l v))
      with End_of_file -> None
    in
    match status with
    | Some () -> read_line ()
    | None -> ()
  in
  (* Do not leak the file descriptor on failure. *)
  (try read_line () with e -> Scanf.Scanning.close_in ic; raise e);
  Scanf.Scanning.close_in ic;
  result


(** All the elements of [to_stream data], in stream order. *)
let to_list data =
  let s =
    to_stream data
  in
  let rec append accu =
    let d =
      try Some (Stream.next s) with
      | Stream.Failure -> None
    in
    match d with
    | None -> List.rev accu
    | Some d -> append (d :: accu)
  in
  append []


(** Currently the identity. The previous implementation, which sent the
    elements to the other processes through [Parallel.broadcast], is kept
    below as a comment. *)
let broadcast t = t
(*
  let size = Parallel.broadcast (lazy t.size) in
  let bufsize = size * size * size in
  let stream = to_stream t in
  let rec iterate () =
    let buffer =
      Parallel.broadcast (lazy (
          if Stream.peek stream = None then None else
            Some (Array.init bufsize (fun _ ->
                try Some (Stream.next stream) with _ -> None ))
        ) )
    in
    match buffer with
    | None -> ()
    | Some buffer ->
      begin
        if not Parallel.master then
          Array.iter (fun x ->
              match x with
              | Some {i_r1 ; j_r2 ; k_r1 ; l_r2 ; value} -> set_phys t i_r1 j_r2 k_r1 l_r2 value
              | None -> () ) buffer;
        iterate ()
      end
  in iterate ();
  t
*)


(** [four_index_transform coef source] returns a new storage of size
    [Mat.dim2 coef], with the same sparsity kind as [source], obtained by
    contracting each of the four indices of [source] (read with
    [get_chem_all_i]) with the matrix [coef]. The number of rows of [coef]
    is used as the range of the indices of [source].
    NOTE(review): presumably an AO to MO integral transformation where
    [coef] holds the MO coefficients -- confirm against the callers.

    One task per last index [delta] is submitted through [Farm.run]; the
    results are stored in the destination with [set_chem], so zero and
    subnormal results are not stored. Progress is printed on stderr by the
    master process. *)
let four_index_transform coef source =

  let ao_num = Mat.dim1 coef in
  let mo_num = Mat.dim2 coef in
  let destination =
    match source.four_index with
    | Dense  _ -> create ~size:mo_num `Dense
    | Sparse _ -> create ~size:mo_num `Sparse
  in

  let mo_num_2 = mo_num * mo_num in
  let ao_num_2 = ao_num * ao_num in
  let ao_mo_num = ao_num * mo_num in

  let range_mo = list_range 1 mo_num in
  let range_ao = list_range 1 ao_num in

  (* Work matrices, allocated once and reused by every task. *)
  let u = Mat.create mo_num_2 mo_num
  and o = Mat.create ao_num ao_num_2
  and p = Mat.create ao_num_2 mo_num
  and q = Mat.create ao_mo_num mo_num
  in

  if Parallel.master then Printf.eprintf "4-idx transformation \n%!";

  (* Computes, for a fixed delta, the array of the non-zero transformed
     elements (alpha, beta, gamma, delta, value) with alpha <= beta and
     gamma <= delta. *)
  let task delta =
    Mat.fill u 0.;

    List.iter (fun l ->
        (* Skip the values of l whose coefficient is negligible. *)
        if abs_float coef.{l,delta} > epsilon then
          begin
            let jk = ref 0 in
            List.iter (fun k ->
                List.iter (fun j ->
                    incr jk;
                    get_chem_all_i source ~j ~k ~l
                    |> Array.iteri (fun i x -> o.{i+1,!jk} <- x)
                    (* lacpy ~bc:!jk ~b:o (Mat.of_col_vecs [| Vec.of_array (get_chem_all_i source ~j ~k ~l) |] ) |> ignore *)
                  ) range_ao
              ) range_ao;
            (* o_i_jk *)

            let p =
              gemm ~transa:`T ~c:p o coef
              (* p_jk_alpha = \sum_i o_i_jk c_i_alpha *)
            in
            let p' =
              Bigarray.reshape_2 (Bigarray.genarray_of_array2 p) ao_num ao_mo_num
              (* p_j_kalpha *)
            in

            let q =
              gemm ~transa:`T ~c:q p' coef
              (* q_kalpha_beta = \sum_j p_j_kalpha c_j_beta *)
            in
            let q' =
              Bigarray.reshape_2 (Bigarray.genarray_of_array2 q) ao_num mo_num_2
              (* q_k_alphabeta = \sum_j p_j_kalpha c_j_beta *)
            in

            (* Accumulate into u, weighted by the coefficient c_l_delta. *)
            ignore @@
            gemm ~transa:`T ~beta:1. ~alpha:coef.{l,delta} ~c:u q' coef ;
            (* u_alphabeta_gamma = \sum_k q_k_alphabeta c_k_gamma *)
          end
      ) range_ao;

    let u =
      Bigarray.reshape
        (Bigarray.genarray_of_array2 u)
        [| mo_num ; mo_num ; mo_num |]
      |> Bigarray.array3_of_genarray
    in
    let result = ref [] in
    List.iter (fun gamma ->
        List.iter (fun beta ->
            List.iter (fun alpha ->
                let x = u.{alpha,beta,gamma} in
                if x <> 0. then
                  result := (alpha, beta, gamma, delta, x) :: !result;
              ) (list_range 1 beta)
          ) range_mo
      ) (list_range 1 delta);
    Array.of_list !result
  in

  let n = ref 0 in
  Stream.of_list range_mo
  |> Farm.run ~f:task ~ordered:true ~comm:Parallel.Node.comm
  |> Stream.iter (fun l ->
      if Parallel.master then (incr n ; Printf.eprintf "\r%d / %d%!" !n mo_num);
      Array.iter (fun (alpha, beta, gamma, delta, x) ->
          set_chem destination alpha beta gamma delta x) l);

  if Parallel.master then Printf.eprintf "\n%!";
  broadcast destination