(* QCaml/Utils/FourIdxStorage.ml

217 lines
6.3 KiB
OCaml *)

open Util
(** Exclusive upper bound on the [~size] accepted by [create]
    (2{^14} = 16384). *)
let max_index = 16_384
(** A pair of indices. Every element is addressed by two such pairs,
    [~r1] and [~r2]; e.g. [get_phys t i j k l] uses [r1 = (i, k)] and
    [r2 = (j, l)]. Indices are expected to be 1-based (the dense storage is
    a Fortran-layout Bigarray). *)
type index_pair = { first : int ; second : int }
(** Backing store for the elements.
    - [Dense]: a [(size*size) x (size*size)] float64 matrix in Fortran
      (1-based) layout, zero-filled by [create]; rows and columns are
      combined index pairs built with [dense_index].
    - [Sparse]: a hash table keyed by [key_of_indices]; keys that were never
      stored read as [0.]. *)
type storage_t =
| Dense of (float, Bigarray.float64_elt, Bigarray.fortran_layout) Bigarray.Array2.t
| Sparse of (int, float) Hashtbl.t
(** A four-index storage: its index range and its dense or sparse backing
    store. *)
type t =
{
size : int ; (** Upper bound of every index (asserted by the dense accessors). *)
four_index : storage_t ; (** Where the values live. *)
}
(** [key_of_indices ~r1 ~r2] computes the sparse-storage key of the element
    addressed by [r1] and [r2].

    Each pair is first folded into a single index with the triangular
    formula [lo + hi*(hi-1)/2] (where [lo <= hi]), which ignores the order
    inside the pair; the two resulting indices are folded the same way,
    which ignores the order of the pairs. So swapping inside [r1], inside
    [r2], or swapping [r1] with [r2] all yield the same key. *)
let key_of_indices ~r1 ~r2 =
  let pair_index (a : int) b =
    let lo, hi = if a <= b then (a, b) else (b, a) in
    lo + ((hi * hi) - hi) / 2
  in
  let left = pair_index r1.first r1.second in
  let right = pair_index r2.first r2.second in
  pair_index left right
(** [dense_index i j size] flattens the 1-based pair [(i, j)] into the
    1-based position [i + (j - 1) * size] ([i] varies fastest, column-major).
    Used to build the row and column indices of the dense matrix. *)
let dense_index i j size =
  let column_offset = size * (j - 1) in
  column_offset + i
(** [get_four_index ~r1 ~r2 t] returns the element addressed by the index
    pairs [r1 = (i, k)] and [r2 = (j, l)].

    - Dense storage: reads cell [(dense_index i j, dense_index k l)].
      Indices are 1-based and must lie in [\[1, t.size\]]; this is checked
      only by [assert] before an unchecked [unsafe_get].
    - Sparse storage: looks up [key_of_indices ~r1 ~r2]; a missing key
      reads as [0.]. *)
let get_four_index ~r1 ~r2 t =
  match t.four_index with
  | Dense a ->
    let { first = i ; second = k } = r1 and { first = j ; second = l } = r2 in
    let size = t.size in
    (* Each index must be >= 1: the Bigarray uses Fortran (1-based) layout
       and is accessed without bounds checks. *)
    assert (i >= 1 && j >= 1 && k >= 1 && l >= 1);
    assert (i <= size && j <= size && k <= size && l <= size);
    Bigarray.Array2.unsafe_get a (dense_index i j size) (dense_index k l size)
  | Sparse a ->
    let key = key_of_indices ~r1 ~r2 in
    (match Hashtbl.find a key with
     | v -> v
     | exception Not_found -> 0.)
(** [set_four_index ~r1 ~r2 ~value t] stores [value] for the index pairs
    [r1 = (i, k)] and [r2 = (j, l)].

    - Dense storage: [value] is written to the eight cells obtained from
      [((i,j), (k,l))] by the index permutations listed below (presumably
      the 8-fold permutational symmetry of real four-index quantities), so
      [get_four_index] returns it for any of those orderings. Indices are
      1-based, must lie in [\[1, t.size\]], and are checked only by [assert]
      before the unchecked [unsafe_set] writes.
    - Sparse storage: [key_of_indices] maps all those orderings to the same
      key, so a single [Hashtbl.replace] is enough. *)
let set_four_index ~r1 ~r2 ~value t =
  match t.four_index with
  | Dense a ->
    let { first = i ; second = k } = r1 and { first = j ; second = l } = r2 in
    let size = t.size in
    (* Each index must be >= 1: Fortran layout, no bounds checks below. *)
    assert (i >= 1 && j >= 1 && k >= 1 && l >= 1);
    assert (i <= size && j <= size && k <= size && l <= size);
    let ij = dense_index i j size
    and kl = dense_index k l size
    and il = dense_index i l size
    and kj = dense_index k j size
    and ji = dense_index j i size
    and lk = dense_index l k size
    and li = dense_index l i size
    and jk = dense_index j k size
    in
    let open Bigarray.Array2 in
    unsafe_set a ij kl value;
    unsafe_set a kj il value;
    unsafe_set a il kj value;
    unsafe_set a kl ij value;
    unsafe_set a ji lk value;
    unsafe_set a li jk value;
    unsafe_set a jk li value;
    unsafe_set a lk ji value
  | Sparse a ->
    Hashtbl.replace a (key_of_indices ~r1 ~r2) value
(** [increment_four_index ~r1 ~r2 ~value t] adds [value] to the element
    addressed by [r1 = (i, k)] and [r2 = (j, l)].

    - Dense storage: the same eight symmetric cells as in [set_four_index]
      receive the new sum. Indices are 1-based, must lie in [\[1, t.size\]],
      and are checked only by [assert] before unchecked accesses.
    - Sparse storage: the value under [key_of_indices ~r1 ~r2] (or [0.] when
      absent) is incremented once. *)
let increment_four_index ~r1 ~r2 ~value t =
  match t.four_index with
  | Dense a ->
    let { first = i ; second = k } = r1 and { first = j ; second = l } = r2 in
    let size = t.size in
    (* Each index must be >= 1: Fortran layout, no bounds checks below. *)
    assert (i >= 1 && j >= 1 && k >= 1 && l >= 1);
    assert (i <= size && j <= size && k <= size && l <= size);
    let ij = dense_index i j size
    and kl = dense_index k l size
    and il = dense_index i l size
    and kj = dense_index k j size
    and ji = dense_index j i size
    and lk = dense_index l k size
    and li = dense_index l i size
    and jk = dense_index j k size
    in
    let open Bigarray.Array2 in
    (* The eight cells hold the same number as long as the storage is only
       modified through set_four_index / increment_four_index (both write
       all eight together). Read it once and add [value] once: incrementing
       each cell in place would add [value] several times whenever some of
       the (row, column) pairs coincide (8 times when i = j = k = l), and
       would disagree with the Sparse branch. *)
    let updated = value +. unsafe_get a ij kl in
    unsafe_set a ij kl updated;
    unsafe_set a kj il updated;
    unsafe_set a il kj updated;
    unsafe_set a kl ij updated;
    unsafe_set a ji lk updated;
    unsafe_set a li jk updated;
    unsafe_set a jk li updated;
    unsafe_set a lk ji updated
  | Sparse a ->
    let key = key_of_indices ~r1 ~r2 in
    let old_value =
      match Hashtbl.find a key with
      | v -> v
      | exception Not_found -> 0.
    in
    Hashtbl.replace a key (old_value +. value)
(** [get ~r1 ~r2 t] reads the element addressed by the index pairs [r1] and
    [r2]; an alias for [get_four_index]. *)
let get ~r1 ~r2 storage =
  storage |> get_four_index ~r1 ~r2
(** [set ~r1 ~r2 ~value t] stores [value] for the index pairs [r1], [r2].

    - Normal floats are stored with [set_four_index].
    - Zeros and subnormals are flushed to zero: nothing is stored for them,
      but an entry written earlier for the same indices is cleared, so a
      later [get] returns [0.] rather than the stale value.

    @raise Invalid_argument if [value] is infinite or NaN. *)
let set ~r1 ~r2 ~value t =
  match classify_float value with
  | FP_normal -> set_four_index ~r1 ~r2 ~value t
  | FP_zero
  | FP_subnormal ->
    begin match t.four_index with
      | Dense _ -> set_four_index ~r1 ~r2 ~value:0. t
      | Sparse a -> Hashtbl.remove a (key_of_indices ~r1 ~r2)
    end
  | FP_infinite
  | FP_nan ->
    let msg =
      Printf.sprintf "FourIdxStorage.ml : set : r1 = (%d,%d) ; r2 = (%d,%d)"
        r1.first r1.second r2.first r2.second
    in
    raise (Invalid_argument msg)
(** [increment ~r1 ~r2 ~value t] adds [value] to the element addressed by
    [r1] and [r2]; an alias for [increment_four_index]. *)
let increment ~r1 ~r2 ~value t =
  increment_four_index ~r1 ~r2 ~value t
(** [create ~size sparsity] makes an empty storage for indices up to [size].

    - [`Dense]: allocates a zero-filled [(size*size) x (size*size)] float64
      Bigarray in Fortran layout.
    - [`Sparse]: allocates a hash table with initial size [size*size + 13].

    [size] must be below [max_index] (checked with [assert]). *)
let create ~size sparsity =
  assert (size < max_index);
  let n = size * size in
  let four_index =
    match sparsity with
    | `Sparse -> Sparse (Hashtbl.create (n + 13))
    | `Dense ->
      let matrix =
        Bigarray.Array2.create Bigarray.float64 Bigarray.fortran_layout n n
      in
      Bigarray.Array2.fill matrix 0.;
      Dense matrix
  in
  { size ; four_index }
(** [size t] is the index upper bound [t] was created with. *)
let size { size ; _ } = size
(* TODO : remove epsilons *)

(** [get_chem t i j k l] reads the element with [r1 = (i, j)] and
    [r2 = (k, l)] (chemist ordering, per the name). *)
let get_chem t i j k l =
  let r1 = { first = i ; second = j }
  and r2 = { first = k ; second = l } in
  get ~r1 ~r2 t

(** [get_phys t i j k l] reads the element with [r1 = (i, k)] and
    [r2 = (j, l)] (physicist ordering, per the name). *)
let get_phys t i j k l =
  let r1 = { first = i ; second = k }
  and r2 = { first = j ; second = l } in
  get ~r1 ~r2 t
(** [set_chem t i j k l value] stores [value] with [r1 = (i, j)] and
    [r2 = (k, l)] (chemist ordering, per the name). *)
let set_chem t i j k l value =
  let r1 = { first = i ; second = j }
  and r2 = { first = k ; second = l } in
  set ~r1 ~r2 ~value t

(** [set_phys t i j k l value] stores [value] with [r1 = (i, k)] and
    [r2 = (j, l)] (physicist ordering, per the name). *)
let set_phys t i j k l value =
  let r1 = { first = i ; second = k }
  and r2 = { first = j ; second = l } in
  set ~r1 ~r2 ~value t
(** One element produced by [to_stream]: the indices passed to
    [get_phys d i j k l] and the value it returned. [i] and [k] form the
    pair [r1], [j] and [l] the pair [r2]. *)
type element =
{
i_r1: int ; (** [i], first index of [r1]. *)
j_r2: int ; (** [j], first index of [r2]. *)
k_r1: int ; (** [k], second index of [r1]. *)
l_r2: int ; (** [l], second index of [r2]. *)
value: float (** [get_phys d i j k l]. *)
}
(** [to_stream d] lazily enumerates the elements of [d] as [element]s whose
    value is [get_phys d i j k l], for all [1 <= i <= k <= d.size] and
    [1 <= j <= l <= d.size].

    Order: [i] varies fastest, then [j], then [k], then [l]. An empty
    storage ([d.size = 0]) yields an empty stream.

    NOTE(review): [Stream] is deprecated since OCaml 4.14 and removed in
    5.0; consider migrating callers to [Seq]. *)
let to_stream d =
  let i = ref 0
  and j = ref 1
  and k = ref 1
  and l = ref 1
  in
  (* Advance (i, j, k, l) like an odometer with bounds i <= k, j <= l,
     k <= size; the stream ends once l exceeds [d.size]. *)
  let f_dense _ =
    i := !i + 1;
    if !i > !k then begin
      i := 1;
      j := !j + 1;
      if !j > !l then begin
        j := 1;
        k := !k + 1;
        if !k > d.size then begin
          k := 1;
          l := !l + 1;
        end;
      end;
    end;
    if !l <= d.size then
      Some { i_r1 = !i ; j_r2 = !j ;
             k_r1 = !k ; l_r2 = !l ;
             value = get_phys d !i !j !k !l
           }
    else
      None
  in
  Stream.from f_dense
(** Write all integrals to a file with the <ij|kl> convention.

    [to_file ?cutoff ~filename data] writes one line
    [" i j k l value"] for every element of [to_stream data] whose absolute
    value is strictly greater than [cutoff] (default [Constants.epsilon]).
    The output file is closed even when writing or reading an element
    raises; the exception is then re-raised. *)
let to_file ?(cutoff=Constants.epsilon) ~filename data =
  let oc = open_out filename in
  let write_element { i_r1 ; j_r2 ; k_r1 ; l_r2 ; value } =
    if abs_float value > cutoff then
      Printf.fprintf oc " %5d %5d %5d %5d%20.15f\n" i_r1 j_r2 k_r1 l_r2 value
  in
  match Stream.iter write_element (to_stream data) with
  | () -> close_out oc
  | exception exn ->
    close_out_noerr oc;
    raise exn