(** Storage for four-index quantities with an 8-fold index symmetry, backed
    either by a dense [Bigarray] or by a sparse hash table. *)

(** Exclusive upper bound on the [size] accepted by {!create}. *)
let max_index = 1 lsl 14

(** A pair of orbital indices. For the element read by
    [get_phys t i j k l], [r1 = (i, k)] and [r2 = (j, l)]. Indices are
    1-based (required by the dense Fortran-layout storage). *)
type index_pair = { first : int ; second : int }

type storage_t =
  | Dense of (float, Bigarray.float64_elt, Bigarray.fortran_layout) Bigarray.Array2.t
  | Sparse of (int, float) Hashtbl.t

type t = {
  size : int ;
  four_index : storage_t ;
}

(** [key_of_indices ~r1 ~r2] packs the four indices into one integer key.
    The key is unchanged by swapping the two indices inside [r1], the two
    inside [r2], or [r1] with [r2]. *)
let key_of_indices ~r1 ~r2 =
  let { first=i ; second=k } = r1
  and { first=j ; second=l } = r2
  in
  (* Triangular packing of an unordered pair: min + max*(max-1)/2. *)
  let f i k =
    let p, r = if i <= k then i, k else k, i in
    p + (r*r - r)/2
  in
  let p = f i k and q = f j l in
  f p q

(** Column-major linear index of the 1-based entry [(i, j)] in a
    [size] x [size] matrix; the result is 1-based as well. *)
let dense_index i j size = (j-1)*size + i

(* Validates indices before dense access. [unsafe_get]/[unsafe_set] do no
   bounds checking, so every index must lie in [1, size]; an index of 0 or
   below would address memory outside the array. *)
let check_dense_indices i j k l size =
  assert (i > 0 && j > 0 && k > 0 && l > 0);
  assert (i <= size && j <= size && k <= size && l <= size)

(* Writes [value] into the 8 dense slots related by index symmetry, so all
   symmetric copies of an element always hold the same value. *)
let dense_set_symmetric a size ~r1 ~r2 value =
  let { first=i ; second=k } = r1
  and { first=j ; second=l } = r2
  in
  check_dense_indices i j k l size;
  let ij = dense_index i j size and kl = dense_index k l size
  and il = dense_index i l size and kj = dense_index k j size
  and ji = dense_index j i size and lk = dense_index l k size
  and li = dense_index l i size and jk = dense_index j k size
  in
  let open Bigarray.Array2 in
  unsafe_set a ij kl value;
  unsafe_set a kj il value;
  unsafe_set a il kj value;
  unsafe_set a kl ij value;
  unsafe_set a ji lk value;
  unsafe_set a li jk value;
  unsafe_set a jk li value;
  unsafe_set a lk ji value

(** Reads the element for [r1], [r2]. A missing sparse entry reads as [0.]. *)
let get_four_index ~r1 ~r2 t =
  match t.four_index with
  | Dense a ->
    let { first=i ; second=k } = r1
    and { first=j ; second=l } = r2
    in
    let size = t.size in
    check_dense_indices i j k l size;
    Bigarray.Array2.unsafe_get a (dense_index i j size) (dense_index k l size)
  | Sparse a ->
    let key = key_of_indices ~r1 ~r2 in
    try Hashtbl.find a key with Not_found -> 0.

(** Stores [value] for [r1], [r2] (and, in the dense case, all 8 symmetric
    slots). *)
let set_four_index ~r1 ~r2 ~value t =
  match t.four_index with
  | Dense a -> dense_set_symmetric a t.size ~r1 ~r2 value
  | Sparse a -> Hashtbl.replace a (key_of_indices ~r1 ~r2) value

(** Adds [value] exactly once to the element for [r1], [r2]. *)
let increment_four_index ~r1 ~r2 ~value t =
  match t.four_index with
  | Dense a ->
    (* Every symmetric slot holds the same value (all writes go through
       dense_set_symmetric), so read one, add once and write the sum back.
       Adding [value] to each slot separately would count it several times
       whenever slots coincide, e.g. when i = k. *)
    let old_value = get_four_index ~r1 ~r2 t in
    dense_set_symmetric a t.size ~r1 ~r2 (old_value +. value)
  | Sparse a ->
    let key = key_of_indices ~r1 ~r2 in
    let old_value = try Hashtbl.find a key with Not_found -> 0. in
    Hashtbl.replace a key (old_value +. value)

(* Resets an element to zero: 0. in every dense symmetric slot, or no entry
   in the sparse table (which reads back as 0.). *)
let clear_four_index ~r1 ~r2 t =
  match t.four_index with
  | Dense a -> dense_set_symmetric a t.size ~r1 ~r2 0.
  | Sparse a -> Hashtbl.remove a (key_of_indices ~r1 ~r2)

(** [get ~r1 ~r2 t] reads the element for [r1], [r2]. *)
let get ~r1 ~r2 a = get_four_index ~r1 ~r2 a

(** [set ~r1 ~r2 ~value t] stores [value] for [r1], [r2]. Zero and subnormal
    values are stored as exactly [0.] (the sparse entry is removed), so an
    earlier non-zero value is not left behind.
    @raise Invalid_argument if [value] is infinite or NaN. *)
let set ~r1 ~r2 ~value =
  match classify_float value with
  | FP_normal -> set_four_index ~r1 ~r2 ~value
  | FP_zero | FP_subnormal -> clear_four_index ~r1 ~r2
  | FP_infinite | FP_nan ->
    let msg = Printf.sprintf "FourIdxStorage.ml : set : r1 = (%d,%d) ; r2 = (%d,%d)"
        r1.first r1.second r2.first r2.second
    in
    raise (Invalid_argument msg)

(** [increment ~r1 ~r2 ~value t] adds [value] to the element for [r1], [r2]. *)
let increment ~r1 ~r2 = increment_four_index ~r1 ~r2

(** [create ~size sparsity] makes an all-zero storage for indices in
    [1..size]; [sparsity] is [`Dense] or [`Sparse]. The dense variant
    allocates a [(size*size)] x [(size*size)] float64 matrix.
    Asserts [size < max_index]. *)
let create ~size sparsity =
  assert (size < max_index);
  let four_index =
    match sparsity with
    | `Dense ->
      let result =
        Bigarray.Array2.create Bigarray.float64 Bigarray.fortran_layout
          (size*size) (size*size)
      in
      Bigarray.Array2.fill result 0.;
      Dense result
    | `Sparse ->
      let result = Hashtbl.create (size*size+13) in
      Sparse result
  in
  { size ; four_index }

(** Number of orbitals (maximum index) the storage was created for. *)
let size t = t.size

(* TODO : remove epsilons *)

(** [get_chem t i j k l] reads the element with [r1 = (i, j)] and
    [r2 = (k, l)] (presumably chemists' notation (ij|kl)). *)
let get_chem t i j k l = get ~r1:{ first=i ; second=j } ~r2:{ first=k ; second=l } t

(** [get_phys t i j k l] reads the element with [r1 = (i, k)] and
    [r2 = (j, l)] (presumably physicists' notation <ij|kl>). *)
let get_phys t i j k l = get ~r1:{ first=i ; second=k } ~r2:{ first=j ; second=l } t

(** [set_chem t i j k l value] is {!set} with [r1 = (i, j)], [r2 = (k, l)]. *)
let set_chem t i j k l value = set ~r1:{ first=i ; second=j } ~r2:{ first=k ; second=l } ~value t

(** [set_phys t i j k l value] is {!set} with [r1 = (i, k)], [r2 = (j, l)]. *)
let set_phys t i j k l value = set ~r1:{ first=i ; second=k } ~r2:{ first=j ; second=l } ~value t

(** [to_file ?cutoff ~filename data] writes one line ["i j k l value"] for
    each [value = get_phys data i j k l] with [|value| > cutoff].
    Loop bounds: [l] in [1..size data], [k] and [j] in [1..l], [i] in
    [1..k]. The output channel is closed even if writing raises. *)
let to_file ?(cutoff=Constants.epsilon) ~filename data =
  let oc = open_out filename in
  let write () =
    for l_c=1 to size data do
      for k_c=1 to l_c do
        for j_c=1 to l_c do
          for i_c=1 to k_c do
            let value = get_phys data i_c j_c k_c l_c in
            if abs_float value > cutoff then
              Printf.fprintf oc " %5d %5d %5d %5d%20.15f\n" i_c j_c k_c l_c value
          done
        done
      done
    done
  in
  (try write () with e -> close_out_noerr oc; raise e);
  close_out oc