F12CI with parallel dressing

This commit is contained in:
Anthony Scemama 2019-03-28 10:32:54 +01:00
parent f77b7f07d5
commit 0321b53ec3
1 changed files with 39 additions and 22 deletions

View File

@ -158,15 +158,13 @@ let dressing_vector ~frozen_core aux_basis f12_amplitudes ci =
)
in
let make_h_and_f n =
let make_h_and_f alpha_list =
let rec col_vecs_list accu_H accu_F = function
| 0 ->
| [] ->
List.rev accu_H,
List.rev accu_F
| n ->
try
let ki = Stream.next out_dets_stream in
| ki :: rest ->
let h, f =
List.map (fun kj ->
match hf_ij aux_basis ki kj with
@ -179,12 +177,10 @@ let dressing_vector ~frozen_core aux_basis f12_amplitudes ci =
and f =
Vec.of_list f
in
col_vecs_list (h::accu_H) (f::accu_F) (n-1)
with
| Stream.Failure -> col_vecs_list accu_H accu_F 0
col_vecs_list (h::accu_H) (f::accu_F) rest
in
let h, f =
col_vecs_list [] [] n
col_vecs_list [] [] alpha_list
in
Mat.of_col_vecs_list h,
Mat.of_col_vecs_list f
@ -193,20 +189,42 @@ let dressing_vector ~frozen_core aux_basis f12_amplitudes ci =
Printf.printf "Matrix product\n%!";
let m_HF =
let batch_size = 10_000_000 / (Mat.dim1 f12_amplitudes) in
let result =
let m_H_aux, m_F_aux = make_h_and_f batch_size in
gemm m_H_aux m_F_aux ~transb:`T
let batch_size = 1 + 10_000_000 / (Mat.dim1 f12_amplitudes) in
let input_stream =
Stream.from (fun i ->
let rec make_batch accu = function
| 0 -> accu
| n -> try
let alpha = Stream.next out_dets_stream in
let accu = alpha :: accu in
make_batch accu (n-1)
with Stream.Failure -> accu
in
let result = make_batch [] batch_size in
if result = [] then None else Some result
)
in
while (Stream.peek out_dets_stream <> None)
do
Printf.printf "gemm\n%!";
let m_H_aux, m_F_aux = make_h_and_f batch_size in
let hf =
let result =
let m_H_aux, m_F_aux = make_h_and_f [(Stream.next out_dets_stream)] in
let m_HF =
gemm m_H_aux m_F_aux ~transb:`T
in
ignore @@ Mat.add result hf ~c:result
done;
gemm m_HF f12_amplitudes
in
let iteration input =
Printf.printf "gemm\n%!";
let m_H_aux, m_F_aux = make_h_and_f input in
let m_HF =
gemm m_H_aux m_F_aux ~transb:`T
in
gemm m_HF f12_amplitudes
in
input_stream
|> Farm.run ~ordered:false ~f:iteration
|> Stream.iter (fun hf ->
ignore @@ Mat.add result hf ~c:result );
result
in
@ -246,8 +264,7 @@ let dressing_vector ~frozen_core aux_basis f12_amplitudes ci =
*)
Printf.printf "Done\n%!";
gemm m_HF f12_amplitudes
|> Matrix.dense_of_mat
Matrix.dense_of_mat m_HF