From 5e399dac4434a43b6cc1159c6d18efd9cad04b56 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Wed, 3 Apr 2019 22:17:20 +0200 Subject: [PATCH] Accelerated direct FCI --- CI/CI.ml | 62 +++++++++++++++++++++++++++--------------- CI/Spindeterminant.ml | 5 ++-- Makefile.include | 2 +- Utils/Bitstring.ml | 3 +- Utils/Matrix.ml | 42 +++++++++++++++++++++++++--- Utils/math_functions.c | 1 + 6 files changed, 83 insertions(+), 32 deletions(-) diff --git a/CI/CI.ml b/CI/CI.ml index 0bd935f..efd26bf 100644 --- a/CI/CI.ml +++ b/CI/CI.ml @@ -349,21 +349,34 @@ let create_matrix_spin_computed f det_space = | _ -> assert false in let n_beta = Array.length b in + let n_alfa = Array.length a in - let h i_alfa = - let deg_a = Spindeterminant.degree i_alfa in - fun j_alfa -> - match deg_a j_alfa with - | 0 | 1 | 2 -> - (fun i_beta -> - let deg_b = Spindeterminant.degree i_beta in - let ki = Determinant.of_spindeterminants i_alfa i_beta in - fun j_beta -> - match deg_b j_beta with - | 0 | 1 | 2 -> ( - let kj = Determinant.of_spindeterminants j_alfa j_beta in - f ki kj) - | _ -> 0. + let h i_alfa j_alfa = + match Spindeterminant.degree a.(i_alfa) a.(j_alfa) with + | 2 -> + let ai, aj = a.(i_alfa), a.(j_alfa) in + (fun i_beta j_beta -> + if i_beta <> j_beta then 0. else + let ki = Determinant.of_spindeterminants ai b.(i_beta) in + let kj = Determinant.of_spindeterminants aj b.(j_beta) in + f ki kj + ) + | 1 -> + let ai, aj = a.(i_alfa), a.(j_alfa) in + (fun i_beta j_beta -> + match Spindeterminant.degree b.(i_beta) b.(j_beta) with + | 0 | 1 -> + let ki = Determinant.of_spindeterminants ai b.(i_beta) in + let kj = Determinant.of_spindeterminants aj b.(j_beta) in + f ki kj + | _ -> 0. + ) + | 0 -> + let ai, aj = a.(i_alfa), a.(j_alfa) in + (fun i_beta j_beta -> + let ki = Determinant.of_spindeterminants ai b.(i_beta) in + let kj = Determinant.of_spindeterminants aj b.(j_beta) in + f ki kj ) | _ -> (fun _ _ -> 0.) in @@ -379,25 +392,30 @@ let create_matrix_spin_computed f det_space = let i_a = (i-1)/n_beta in let i_alfa = i_a + 1 in let h1 = - h a.(i_alfa-1) + h (i_alfa-1) in let i_beta = i - i_a*n_beta in - let bi = b.(i_beta-1) in + let bi = (i_beta-1) in let h123_prev = ref (fun _ -> 0.) in + let j_a = ref (-n_alfa) in let j_alfa_prev = ref (-10) in result := fun j -> - let j_a = (j-1)/n_beta in - let j_alfa = j_a + 1 in + let j0 = !j_a * n_beta in + if j > j0 + n_beta + || j < j0 + then + j_a := (j-1)/n_beta; + let j_alfa = !j_a + 1 in let h123 = if j_alfa <> !j_alfa_prev then begin j_alfa_prev := j_alfa ; - h123_prev := (h1 a.(j_alfa-1) bi) + h123_prev := (h1 (j_alfa-1) bi) end; !h123_prev in - let j_beta = j - j_a*n_beta in - h123 b.(j_beta-1) + let j_beta = j - !j_a*n_beta in + h123 (j_beta-1) end; !result in @@ -490,7 +508,7 @@ let make ?(n_states=1) ?(algo=`Direct) det_space = in let matrix_prod psi = let result = - Matrix.parallel_mm ~transa:`T ~transb:`T psi m_H + Matrix.parallel_mm ~transa:`T ~transb:`N psi m_H |> Matrix.transpose in Parallel.broadcast (lazy result) diff --git a/CI/Spindeterminant.ml b/CI/Spindeterminant.ml index 1f6aff2..2a254b9 100644 --- a/CI/Spindeterminant.ml +++ b/CI/Spindeterminant.ml @@ -77,9 +77,8 @@ let double_excitation h' p' h p = double_excitation_reference h' p' h p -let degree t = - let bt = bitstring t in - fun t' -> Bitstring.hamdist bt (bitstring t') / 2 +let degree t t' = + Bitstring.hamdist (bitstring t) (bitstring t') / 2 let holes_of t t' = Bitstring.logand (bitstring t) (Bitstring.logxor (bitstring t) (bitstring t')) diff --git a/Makefile.include b/Makefile.include index ddb8f49..0e950ea 100644 --- a/Makefile.include +++ b/Makefile.include @@ -3,7 +3,7 @@ INCLUDE_DIRS=Parallel,Nuclei,Utils,Basis,SCF,MOBasis,CI,F12,Perturbation LIBS= PKGS= -OCAMLBUILD=ocamlbuild -j 0 -cflags $(ocamlcflags) -lflags $(ocamllflags) $(ocamldocflags) -Is $(INCLUDE_DIRS) -ocamlopt $(ocamloptflags) $(mpi) +OCAMLBUILD=ocamlbuild -j 0 -cflags $(ocamlcflags) -lflags $(ocamllflags) $(ocamldocflags) -Is $(INCLUDE_DIRS) -ocamlopt $(ocamloptflags) $(mpi) MLLFILES=$(filter-out $(wildcard _build/*), $(wildcard */*.mll) $(wildcard *.mll)) Utils/math_functions.c MLYFILES=$(filter-out $(wildcard _build/*), $(wildcard */*.mly) $(wildcard *.mly)) diff --git a/Utils/Bitstring.ml b/Utils/Bitstring.ml index ff86504..1497136 100644 --- a/Utils/Bitstring.ml +++ b/Utils/Bitstring.ml @@ -30,8 +30,7 @@ module One = struct let hamdist a b = a lxor b - |> Int64.of_int - |> Util.popcnt + |> popcount let pp ppf s = diff --git a/Utils/Matrix.ml b/Utils/Matrix.ml index 0dbfad4..0fbb228 100644 --- a/Utils/Matrix.ml +++ b/Utils/Matrix.ml @@ -331,16 +331,50 @@ let rec mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b = Computed {m=m' ; n=n ; f=g} in + let mmccde transa transb a b = + let m', n', f' = + if transa = `T then + match transpose (Computed a) with + | Computed {m ; n ; f} -> m, n, f + | _ -> assert false + else + let {m ; n ; f} = a in + m, n, f + in + let m, n = + match transb with + | `N -> Mat.dim1 b , Mat.dim2 b + | `T -> Mat.dim2 b , Mat.dim1 b + in + if n' <> m then + invalid_arg "Inconsistent dimensions"; + + let matrix = + Array.init n (fun j -> + let bj = + if transb = `T then + (Mat.copy_row b (j+1)) + else + (Mat.to_col_vecs b).(j) + in + let accu = Vec.make0 m' in + Vec.iteri (fun k a -> + if a <> 0. then + Vec.iteri (fun i vi -> accu.{i} <- vi +. (f' i k) *. a) accu + ) bj; + accu + ) + |> Mat.of_col_vecs + in + Dense matrix + in match a, b with | (Dense a), (Dense b) -> Dense (gemm ~transa ~transb a b) | (Sparse a), (Dense b) -> spmm transa transb a b | (Dense a), (Sparse b) -> mmsp transa transb a b | (Sparse a), (Sparse b) -> mmspmm transa transb a b | (Computed a), (Computed b) -> mmcc transa transb a b - | (Computed a), (Dense _) -> - let b = { m = dim1 b ; n = dim2 b ; f = get b } in - mmcc transa transb a b - |> dense_of_computed + | (Computed a), (Dense b) -> mmccde transa transb a b | (Computed a), (Sparse _) -> let b = { m = dim1 b ; n = dim2 b ; f = get b } in mmcc transa transb a b diff --git a/Utils/math_functions.c b/Utils/math_functions.c index 10a4928..6b433ef 100644 --- a/Utils/math_functions.c +++ b/Utils/math_functions.c @@ -37,6 +37,7 @@ CAMLprim double gamma_float(double x) } +#include CAMLprim int32_t popcnt(int64_t i) { return __builtin_popcountll (i);