10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2024-11-07 14:43:41 +01:00
QCaml/Utils/FourIdxStorage.ml

515 lines
13 KiB
OCaml
Raw Normal View History

2018-06-28 14:43:24 +02:00
open Util
2019-03-21 00:44:10 +01:00
open Lacaml.D
open Constants
(** Exclusive upper bound on the number of orbitals: [create] asserts
    [size < max_index]. *)
let max_index = 1 lsl 14

(** A pair of 1-based orbital indices. [get_phys t i j k l] passes [(i,k)]
    as [r1] and [(j,l)] as [r2]; [get_chem t i j k l] passes [(i,j)] and
    [(k,l)]. *)
type index_pair = {
  first : int ;
  second : int ;
}

(** Double-precision, Fortran-layout (1-based) matrix. *)
type array2 =
  (float, Bigarray.float64_elt, Bigarray.fortran_layout) Bigarray.Array2.t

(** Backing store for the generic four-index values. *)
type storage_t =
  | Dense of array2
    (** rows: [dense_index] of the [r1] pair; columns: [sym_index] of the
        [r2] pair *)
  | Sparse of (int, float) Hashtbl.t
    (** keyed by [key_of_indices]; missing keys read as [0.] *)

(** Four-index storage. Values whose indices repeat are additionally kept in
    the dedicated two- and three-index arrays, which the getters read first. *)
type t = {
  size : int ;
  (** number of orbitals; valid indices are [1 .. size] *)
  two_index : array2 ;
  (** [size x size]: <ij|ij> *)
  two_index_anti : array2 ;
  (** [size x size]: <ij|ji> and <ii|jj> share this storage *)
  three_index : array2 ;
  (** [(size*size) x size]: the [r1] or the [r2] pair has equal indices *)
  three_index_anti : array2 ;
  (** [(size*size) x size]: an index is shared between [r1] and [r2] *)
  four_index : storage_t ;
  (** receives every value written by [unsafe_set_four_index] *)
}
(** [key_of_indices ~r1 ~r2] returns a hash key that is invariant under
    swapping the two indices inside [r1], inside [r2], and swapping [r1]
    with [r2]. Each unordered pair [(a,b)] of 1-based indices is packed
    triangularly as [min + max*(max-1)/2]; the two packed pairs are then
    packed the same way. *)
let key_of_indices ~r1 ~r2 =
  let pack a b =
    let lo = min a b and hi = max a b in
    lo + hi * (hi - 1) / 2
  in
  let first_pair = pack r1.first r1.second
  and second_pair = pack r2.first r2.second in
  pack first_pair second_pair
(** [check_bounds r1 r2 t] asserts that every index of [r1] and [r2] lies in
    the 1-based range [1 .. t.size] expected by the unchecked Fortran-layout
    accessors.
    @raise Assert_failure if an index is out of range (when assertions are
    enabled). *)
let check_bounds r1 r2 t =
  let { first = i ; second = k } = r1 and { first = j ; second = l } = r2 in
  let size = t.size in
  (* Each index must be checked on its own: an OR of the indices is positive
     as soon as one of them is, so a single 0 index would slip through and
     reach [unsafe_get]/[unsafe_set] out of bounds. *)
  assert (i > 0 && j > 0 && k > 0 && l > 0);
  assert (i <= size && j <= size && k <= size && l <= size)
(** [dense_index i j size] is the 1-based column-major linear index of
    entry [(i, j)] in a matrix with [size] rows. *)
let dense_index i j size =
  i + size * (j - 1)
(** [sym_index i j] is the 1-based packed index of the unordered pair
    [(i, j)] in a lower-triangular layout: [hi*(hi-1)/2 + lo]. It is
    symmetric in its arguments. *)
let sym_index i j =
  let hi, lo = if i < j then (j, i) else (i, j) in
  (hi * (hi - 1)) / 2 + lo
(** [unsafe_get_four_index ~r1 ~r2 t] reads the value for the index pairs
    [r1 = (i,k)] and [r2 = (j,l)] without any bounds checking.
    Index patterns with a repeated index are read from the dedicated two-
    and three-index arrays; every other pattern is read from
    [t.four_index]. A key absent from [Sparse] storage reads as [0.]. *)
let unsafe_get_four_index ~r1 ~r2 t =
  let open Bigarray.Array2 in
  let { first = i ; second = k } = r1
  and { first = j ; second = l } = r2 in
  (* Packed row of the ordered pair (a,b) in the three-index arrays. *)
  let row a b = dense_index a b t.size in
  if i = k then begin
    if j = l then unsafe_get t.two_index i j
    else unsafe_get t.three_index (row j l) i
  end
  else if j = l then unsafe_get t.three_index (row i k) j
  else if i = l then begin
    if k = j then unsafe_get t.two_index_anti i j
    else unsafe_get t.three_index_anti (row j k) i
  end
  else if j = k then unsafe_get t.three_index_anti (row i l) j
  else if i = j then begin
    if k = l then unsafe_get t.two_index_anti i k
    else unsafe_get t.three_index_anti (row k l) i
  end
  else if k = l then
    (* <ij|kk> *)
    unsafe_get t.three_index_anti (row i j) k
  else
    match t.four_index with
    | Dense a -> unsafe_get a (row i k) (sym_index j l)
    | Sparse a ->
      begin match Hashtbl.find a (key_of_indices ~r1 ~r2) with
      | v -> v
      | exception Not_found -> 0.
      end
(** [get_four_index ~r1 ~r2 t] validates the indices with [check_bounds]
    and then reads the stored value. *)
let get_four_index ~r1 ~r2 t =
  let () = check_bounds r1 r2 t in
  unsafe_get_four_index ~r1 ~r2 t
(** [unsafe_set_four_index ~r1 ~r2 ~value t] stores [value] for the index
    pairs [r1 = (i,k)] and [r2 = (j,l)] without bounds checking.
    For index patterns with a repeated index the value is also written, in
    every symmetric position, into the matching two- or three-index array.
    In all cases it is then written into [t.four_index] (at all four
    symmetric positions for [Dense], under [key_of_indices] for [Sparse]). *)
let unsafe_set_four_index ~r1 ~r2 ~value t =
  let open Bigarray.Array2 in
  let { first = i ; second = k } = r1
  and { first = j ; second = l } = r2 in
  let n = t.size in
  (* Write [value] at (a,b) and (b,a) of a square array. *)
  let set_sym (arr : array2) a b =
    unsafe_set arr a b value;
    unsafe_set arr b a value
  in
  (* Write [value] in column [c] at the packed rows (a,b) and (b,a). *)
  let set_rows (arr : array2) a b c =
    unsafe_set arr (dense_index a b n) c value;
    unsafe_set arr (dense_index b a n) c value
  in
  if i = k then begin
    if j = l then begin
      set_sym t.two_index i j;
      unsafe_set t.three_index (dense_index i i n) j value
    end;
    set_rows t.three_index j l i
  end
  else if j = l then
    set_rows t.three_index i k j
  else if i = l then begin
    if j = k then begin
      set_sym t.two_index_anti i j;
      unsafe_set t.three_index_anti (dense_index i i n) j value
    end;
    set_rows t.three_index_anti j k i
  end
  else if j = k then
    set_rows t.three_index_anti i l j
  else if i = j then begin
    if k = l then begin
      set_sym t.two_index_anti i k;
      unsafe_set t.three_index_anti (dense_index i i n) k value
    end;
    set_rows t.three_index_anti k l i
  end
  else if k = l then
    (* <ij|kk> *)
    set_rows t.three_index_anti i j k;
  (* The generic store receives every value, whatever the index pattern. *)
  match t.four_index with
  | Dense a ->
    let jl_s = sym_index j l and ik_s = sym_index i k in
    unsafe_set a (dense_index i k n) jl_s value;
    unsafe_set a (dense_index k i n) jl_s value;
    unsafe_set a (dense_index j l n) ik_s value;
    unsafe_set a (dense_index l j n) ik_s value
  | Sparse a ->
    Hashtbl.replace a (key_of_indices ~r1 ~r2) value
(** [set_four_index ~r1 ~r2 ~value t] validates the indices with
    [check_bounds] and then stores [value]. *)
let set_four_index ~r1 ~r2 ~value t =
  let () = check_bounds r1 r2 t in
  unsafe_set_four_index ~r1 ~r2 ~value t
(** [unsafe_increment_four_index ~r1 ~r2 ~value t] adds [value] to the
    currently stored value, without bounds checking. *)
let unsafe_increment_four_index ~r1 ~r2 ~value t =
  let current = unsafe_get_four_index ~r1 ~r2 t in
  unsafe_set_four_index ~r1 ~r2 ~value:(value +. current) t
(** [increment_four_index ~r1 ~r2 ~value t] validates the indices with
    [check_bounds] and then adds [value] to the stored value. *)
let increment_four_index ~r1 ~r2 ~value t =
  let () = check_bounds r1 r2 t in
  unsafe_increment_four_index ~r1 ~r2 ~value t
(** [get ~r1 ~r2 t] is the bounds-checked read [get_four_index ~r1 ~r2 t]. *)
let get ~r1 ~r2 t =
  get_four_index ~r1 ~r2 t
(** [set ~r1 ~r2 ~value t] stores a normal float with [set_four_index].
    Zero and subnormal values are ignored (nothing is written).
    @raise Invalid_argument if [value] is infinite or NaN; this happens as
    soon as [~r1 ~r2 ~value] are supplied, before [t] is given. *)
let set ~r1 ~r2 ~value =
  match classify_float value with
  | FP_zero | FP_subnormal -> (fun _ -> ())
  | FP_normal -> set_four_index ~r1 ~r2 ~value
  | FP_infinite | FP_nan ->
    raise (Invalid_argument
             (Printf.sprintf "FourIdxStorage.ml : set : r1 = (%d,%d) ; r2 = (%d,%d)"
                r1.first r1.second r2.first r2.second))
(** [increment ~r1 ~r2 ~value t] is [increment_four_index ~r1 ~r2 ~value t]. *)
let increment ~r1 ~r2 ~value t =
  increment_four_index ~r1 ~r2 ~value t
(** [create ~size ?temp_dir sparsity] allocates an empty storage for [size]
    orbitals. The two- and three-index arrays (and, for [`Dense], the
    four-index matrix) are created with [SharedMemory.create] in [temp_dir]
    (default ["/dev/shm"]); [`Sparse] uses an in-process hash table.
    @raise Assert_failure if [size >= max_index]. *)
let create ~size ?(temp_dir="/dev/shm") sparsity =
  assert (size < max_index);
  let shared_matrix dims =
    SharedMemory.create ~temp_dir Float64 dims
    |> Bigarray.array2_of_genarray
  in
  let two_index = shared_matrix [| size ; size |] in
  let two_index_anti = shared_matrix [| size ; size |] in
  let three_index = shared_matrix [| size * size ; size |] in
  let three_index_anti = shared_matrix [| size * size ; size |] in
  let four_index =
    match sparsity with
    | `Dense ->
      (* rows: ordered pair (dense_index); columns: unordered pair (sym_index) *)
      Dense (shared_matrix [| size * size ; (size * (size + 1)) / 2 |])
    | `Sparse ->
      Sparse (Hashtbl.create (size * size + 13))
  in
  { size ; two_index ; two_index_anti ; three_index ; three_index_anti ; four_index }
(** Number of orbitals the storage was created for. *)
let size { size ; _ } = size
(* TODO : remove epsilons *)

(** [get_chem t i j k l]: chemists' ordering, [(i,j)] is [r1] and [(k,l)] is [r2]. *)
let get_chem t i j k l =
  let r1 = { first = i ; second = j } and r2 = { first = k ; second = l } in
  get ~r1 ~r2 t

(** [get_phys t i j k l]: physicists' ordering, [(i,k)] is [r1] and [(j,l)] is [r2]. *)
let get_phys t i j k l =
  let r1 = { first = i ; second = k } and r2 = { first = j ; second = l } in
  get ~r1 ~r2 t

(** [set_chem t i j k l value]: store with the same index ordering as [get_chem]. *)
let set_chem t i j k l value =
  let r1 = { first = i ; second = j } and r2 = { first = k ; second = l } in
  set ~r1 ~r2 ~value t

(** [set_phys t i j k l value]: store with the same index ordering as [get_phys]. *)
let set_phys t i j k l value =
  let r1 = { first = i ; second = k } and r2 = { first = j ; second = l } in
  set ~r1 ~r2 ~value t
type element = (** Element of the stream built by [to_stream]: one value
    read with [get_phys d i j k l], together with its four indices. *)
{
i_r1: int ; (** i : first index of [r1] *)
j_r2: int ; (** j : first index of [r2] *)
k_r1: int ; (** k : second index of [r1] *)
l_r2: int ; (** l : second index of [r2] *)
value: float (** the stored value [get_phys d i j k l] *)
}
(** [get_phys_all_i d ~j ~k ~l] is the vector [get_phys d i j k l] for
    [i = 1 .. d.size]. *)
let get_phys_all_i d ~j ~k ~l =
  let element i = get_phys d i j k l in
  Vec.init d.size element
(** [get_chem_all_i d ~j ~k ~l] is the vector [get_chem d i j k l] for
    [i = 1 .. d.size]. *)
let get_chem_all_i d ~j ~k ~l =
  let element i = get_chem d i j k l in
  Vec.init d.size element
(** [get_phys_all_ij d ~k ~l] is the [d.size x d.size] matrix whose
    [(i,j)] entry is [get_phys d i j k l]. *)
let get_phys_all_ij d ~k ~l =
  let n = d.size in
  Mat.init_cols n n (fun i j -> get_phys d i j k l)
(** [get_chem_all_ij d ~k ~l] is the [d.size x d.size] matrix whose [(i,j)]
    entry is [get_chem d i j k l].
    When [k = l], or when the storage is [Dense], the result is a reshaped
    column of the internal storage. It presumably shares memory with [d]
    (Lacaml [Mat.col] + [Bigarray.reshape_2]), so callers should not
    mutate it. For [Sparse] storage with [k <> l], a fresh matrix is
    built. *)
let get_chem_all_ij d ~k ~l =
  let n = d.size in
  (* Reinterpret a length-(n*n) column as an n x n matrix. *)
  let as_square column =
    Bigarray.reshape_2 (Bigarray.genarray_of_array1 column) n n
  in
  if k = l then
    as_square (Mat.col d.three_index k)
  else
    match d.four_index with
    | Dense a -> as_square (Mat.col a (sym_index k l))
    | Sparse _ -> Mat.init_cols n n (fun i j -> get_chem d i j k l)
(** [to_stream d] enumerates the quadruples [(i,j,k,l)] with [i <= k],
    [j <= l], [k <= l] and [l <= d.size]. [i] varies fastest, then [j],
    then [k], then [l]. Each element carries [get_phys d i j k l].
    NOTE(review): [i] and [j] range independently, so when [k = l] the same
    stored value (pairs swapped) can be emitted twice. *)
let to_stream d =
  let i = ref 0 and j = ref 1 and k = ref 1 and l = ref 1 in
  (* Move the counters to the next quadruple in enumeration order. *)
  let advance () =
    incr i;
    if !i > !k then begin
      i := 1;
      incr j;
      if !j > !l then begin
        j := 1;
        incr k;
        if !k > !l then begin
          k := 1;
          incr l
        end
      end
    end
  in
  let next _ =
    advance ();
    if !l > d.size then None
    else
      Some { i_r1 = !i ; j_r2 = !j ; k_r1 = !k ; l_r2 = !l ;
             value = get_phys d !i !j !k !l }
  in
  Stream.from next
(** Write all integrals to a file with the <ij|kl> convention.
    [to_file ?cutoff ~filename data] writes one line per element of
    [to_stream data] whose absolute value exceeds [cutoff] (default
    [Constants.epsilon]), formatted as the four indices followed by the
    value. The output channel is closed even if writing fails; the
    exception is then re-raised. *)
let to_file ?(cutoff=Constants.epsilon) ~filename data =
  let oc = open_out filename in
  let write_element { i_r1 ; j_r2 ; k_r1 ; l_r2 ; value } =
    if abs_float value > cutoff then
      Printf.fprintf oc " %5d %5d %5d %5d%20.15f\n" i_r1 j_r2 k_r1 l_r2 value
  in
  match Stream.iter write_element (to_stream data) with
  | () -> close_out oc
  | exception exn ->
    (* Do not leak the descriptor on failure; keep the original error. *)
    close_out_noerr oc;
    raise exn
(** [of_file ~size ~sparsity filename] creates a storage with [create] and
    fills it from [filename]. Each record is four integers and a float,
    stored with [set_phys]. Reading stops at end of input. If reading or
    storing fails (e.g. a malformed line, or a non-finite value rejected by
    [set]), the input channel is closed and the exception re-raised. *)
let of_file ~size ~sparsity filename =
  let storage = create ~size sparsity in
  let ic = Scanf.Scanning.open_in filename in
  let rec read_all () =
    match
      Scanf.bscanf ic " %d %d %d %d %f"
        (fun i j k l v -> set_phys storage i j k l v)
    with
    | () -> read_all ()
    | exception End_of_file -> ()
  in
  begin match read_all () with
    | () -> Scanf.Scanning.close_in ic
    | exception exn ->
      (* Close the input before propagating the error. *)
      Scanf.Scanning.close_in ic;
      raise exn
  end;
  storage
(** [to_list data] collects every element of [to_stream data], in stream
    order. *)
let to_list data =
  let acc = ref [] in
  Stream.iter (fun elt -> acc := elt :: !acc) (to_stream data);
  List.rev !acc
(** [broadcast t] currently returns [t] unchanged.
    The commented-out code below is an earlier implementation. It appears
    to have sent the contents of [t] in buffers via [Parallel.broadcast],
    with non-master processes storing them with [set_phys]. It is kept for
    reference only and is not compiled. *)
let broadcast t =
t
(*
let size =
Parallel.broadcast (lazy t.size)
in
let bufsize = size * size * size in
let stream = to_stream t in
let rec iterate () =
let buffer =
Parallel.broadcast (lazy (
if Stream.peek stream = None then None else
Some (Array.init bufsize (fun _ ->
try Some (Stream.next stream)
with _ -> None
))
) )
in
match buffer with
| None -> ()
| Some buffer ->
begin
if not Parallel.master then
Array.iter (fun x ->
match x with
| Some {i_r1 ; j_r2 ; k_r1 ; l_r2 ; value} ->
set_phys t i_r1 j_r2 k_r1 l_r2 value
| None -> () ) buffer;
iterate ()
end
in iterate ();
t
*)
2019-03-21 00:44:10 +01:00
(** [four_index_transform coef source] transforms the values stored in
    [source] with the coefficient matrix [coef] (dimensions
    [ao_num x mo_num]). It returns a new storage of size [mo_num] with the
    same sparsity kind as [source].
    One task is run per [delta] in [1 .. mo_num] (distributed with
    [Farm.run]). Each task accumulates its partial transform into the
    buffer [u] through successive [gemm] calls and returns the non-zero
    entries with [alpha <= beta] and [gamma <= delta]. These are written
    into the result with [set_chem]. Progress is printed to stderr on the
    master process, and the result goes through [broadcast]. *)
let four_index_transform coef source =
let ao_num = Mat.dim1 coef in
let mo_num = Mat.dim2 coef in
let destination =
match source.four_index with
| Dense _ -> create ~size:mo_num `Dense
| Sparse _ -> create ~size:mo_num `Sparse
in
let mo_num_2 = mo_num * mo_num in
let ao_num_2 = ao_num * ao_num in
let ao_mo_num = ao_num * mo_num in
let range_mo = list_range 1 mo_num in
let range_ao = list_range 1 ao_num in
(* Work buffers allocated once and overwritten by every task:
   u : accumulator, o : gathered input, p and q : gemm outputs. *)
let u = Mat.create mo_num_2 mo_num
and o = Mat.create ao_num ao_num_2
and p = Mat.create ao_num_2 mo_num
and q = Mat.create ao_mo_num mo_num
in
if Parallel.master then Printf.eprintf "4-idx transformation \n%!";
(* Computes every transformed value whose fourth index is [delta]. *)
let task delta =
Mat.fill u 0.;
List.iter (fun l ->
(* AOs l with a negligible coefficient on delta contribute nothing. *)
if abs_float coef.{l,delta} > epsilon then
begin
2019-04-05 16:54:38 +02:00
(* First column of o that receives the ao_num x ao_num block for k. *)
let jk = ref 1 in
2019-03-21 00:44:10 +01:00
List.iter (fun k ->
2019-04-05 16:54:38 +02:00
get_chem_all_ij source ~k ~l
|> lacpy ~b:o ~bc:!jk
|> ignore;
jk := !jk + ao_num;
2019-03-21 00:44:10 +01:00
) range_ao;
(* o_i_jk *)
let p =
gemm ~transa:`T ~c:p o coef
(* p_jk_alpha = \sum_i o_i_jk c_i_alpha *)
in
let p' =
Bigarray.reshape_2 (Bigarray.genarray_of_array2 p) ao_num ao_mo_num
(* p_j_kalpha *)
in
let q =
gemm ~transa:`T ~c:q p' coef
(* q_kalpha_beta = \sum_j p_j_kalpha c_j_beta *)
in
let q' =
Bigarray.reshape_2 (Bigarray.genarray_of_array2 q) ao_num mo_num_2
(* q_k_alphabeta : the same data as q, viewed with k as the row index *)
in
(* Accumulate (beta = 1.) the contribution of l, weighted by coef.(l,delta). *)
ignore @@
gemm ~transa:`T ~beta:1. ~alpha:coef.{l,delta} ~c:u q' coef ;
(* u_alphabeta_gamma = \sum_k q_k_alphabeta c_k_gamma *)
end
) range_ao;
(* View u as a mo_num^3 array indexed alpha, beta, gamma (shares data with u). *)
let u =
Bigarray.reshape
(Bigarray.genarray_of_array2 u)
[| mo_num ; mo_num ; mo_num |]
|> Bigarray.array3_of_genarray
in
(* Keep non-zero entries with alpha <= beta and gamma <= delta. *)
let result = ref [] in
List.iter (fun gamma ->
List.iter (fun beta ->
List.iter (fun alpha ->
let x = u.{alpha,beta,gamma} in
if x <> 0. then
result := (alpha, beta, gamma, delta, x) :: !result;
) (list_range 1 beta)
) range_mo
) (list_range 1 delta);
Array.of_list !result
in
let n = ref 0 in
(* Run one task per delta; each returned batch is written into destination.
   NOTE(review): ~ordered:true presumably yields results in task order. *)
Stream.of_list range_mo
2019-04-05 09:46:23 +02:00
|> Farm.run ~f:task ~ordered:true ~comm:Parallel.Node.comm
2019-03-21 00:44:10 +01:00
|> Stream.iter (fun l ->
2019-03-26 10:38:50 +01:00
if Parallel.master then (incr n ; Printf.eprintf "\r%d / %d%!" !n mo_num);
2019-03-21 00:44:10 +01:00
Array.iter (fun (alpha, beta, gamma, delta, x) ->
set_chem destination alpha beta gamma delta x) l);
2019-03-29 17:38:19 +01:00
if Parallel.master then Printf.eprintf "\n%!";
broadcast destination
2019-03-21 00:44:10 +01:00