1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-05 11:00:36 +01:00

Added Woodbury 3x3 kernel template generator.

This commit is contained in:
Francois Coppens 2023-01-27 17:41:32 +01:00
parent 549413abca
commit 6c0430a509

View File

@ -921,14 +921,14 @@ assert(rc == QMCKL_SUCCESS);
*** C source *** C source
#+begin_src c :tangle (eval c) :comments org #+begin_src c :tangle (eval c) :comments org
qmckl_exit_code qmckl_woodbury_3x3(const qmckl_context context, qmckl_exit_code qmckl_woodbury_3x3_hpc(const qmckl_context context,
const uint64_t LDS, const uint64_t LDS,
const uint64_t Dim, const uint64_t Dim,
const double* Updates, const double* __restrict Updates,
const uint64_t* Updates_index, const uint64_t* __restrict Updates_index,
const double breakdown, const double breakdown,
double* Slater_inv, double* __restrict Slater_inv,
double* determinant) { double* __restrict determinant) {
/* /*
C := S^{-1} * U, dim x 3 C := S^{-1} * U, dim x 3
B := 1 + V * C, 3 x 3 B := 1 + V * C, 3 x 3
@ -938,7 +938,7 @@ qmckl_exit_code qmckl_woodbury_3x3(const qmckl_context context,
if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) { if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) {
return qmckl_failwith(context, return qmckl_failwith(context,
QMCKL_NULL_CONTEXT, QMCKL_NULL_CONTEXT,
"qmckl_woodbury_3x3", "qmckl_woodbury_3x3_hpc",
NULL); NULL);
} }
@ -1026,6 +1026,190 @@ qmckl_exit_code qmckl_woodbury_3x3(const qmckl_context context,
} }
#+end_src #+end_src
#+NAME:woodbury_3x3_kernel_template
#+begin_src c
// Woodbury 3x3 kernel generated for a fixed matrix dimension of {Dim}.
//
// Applies three rank-1 row updates to Slater_inv in a single step:
//   C := S^{-1} * U,   {Dim} x 3
//   B := 1 + V * C,    3 x 3
//   D := V * S^{-1},   3 x D{Dim}_P
//   S^{-1} := S^{-1} - C * B^{-1} * D
//
// Parameters:
//   context        QMCkl context, checked with qmckl_context_check.
//   Updates        three update vectors, read as three consecutive runs of
//                  D{Dim}_P doubles.
//   Updates_index  three one-based indices of the rows being updated; each
//                  must lie in [1, {Dim}].
//   breakdown      threshold: the update is refused when |det(B)| < breakdown.
//   Slater_inv     matrix updated in place, read and written as {Dim} rows
//                  with a row stride of D{Dim}_P doubles.
//   determinant    optional (may be NULL); multiplied by det(B) on success.
//
// Returns:
//   QMCKL_SUCCESS       the update was applied.
//   QMCKL_FAILURE       an index is out of range or |det(B)| < breakdown;
//                       Slater_inv and *determinant are left unmodified.
//   QMCKL_NULL_CONTEXT  the context check failed.
//
// NOTE(review): D{Dim}_P, IVDEP and ALIGNED are macros defined outside this
// block; D{Dim}_P is presumably the SIMD-padded row length for {Dim} and
// IVDEP/ALIGNED presumably expand to vectorisation pragmas -- confirm.
qmckl_exit_code qmckl_woodbury_3x3_{Dim}(
    const qmckl_context context,
    const double* __restrict Updates,
    const uint64_t* __restrict Updates_index,
    const double breakdown,
    double* __restrict Slater_inv,
    double* __restrict determinant) {

  if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) {
    return qmckl_failwith(context,
                          QMCKL_NULL_CONTEXT,
                          "qmckl_woodbury_3x3_{Dim}",
                          NULL);
  }

  // Convert the one-based indices to zero-based rows.
  const uint64_t row1 = (Updates_index[0] - 1);
  const uint64_t row2 = (Updates_index[1] - 1);
  const uint64_t row3 = (Updates_index[2] - 1);

  // The subtraction is unsigned, so an index of 0 wraps to UINT64_MAX and is
  // caught by the same test as an index larger than {Dim}. Without this
  // guard the rows below would index outside C and Slater_inv.
  if (row1 >= {Dim} || row2 >= {Dim} || row3 >= {Dim}) {
    return qmckl_failwith(context,
                          QMCKL_FAILURE,
                          "qmckl_woodbury_3x3_{Dim}",
                          "Updates_index out of range");
  }

  // Compute C = (S^T)^{-1}U : {Dim} x 3
  double __attribute__((aligned(8))) C[3 * {Dim}];
  for (uint64_t i = 0; i < {Dim}; i++) {
    C[i * 3] = 0;
    C[i * 3 + 1] = 0;
    C[i * 3 + 2] = 0;
    IVDEP
    ALIGNED
    for (uint64_t k = 0; k < D{Dim}_P; k++) {
      C[i * 3] += Slater_inv[i * D{Dim}_P + k] * Updates[k];
      C[i * 3 + 1] += Slater_inv[i * D{Dim}_P + k] * Updates[D{Dim}_P + k];
      C[i * 3 + 2] += Slater_inv[i * D{Dim}_P + k] * Updates[2 * D{Dim}_P + k];
    }
  }

  // Compute B = 1 + VC : 3 x 3
  // V only selects the three updated rows, so B is those rows of C plus the
  // identity.
  const double B0 = C[row1 * 3] + 1;
  const double B1 = C[row1 * 3 + 1];
  const double B2 = C[row1 * 3 + 2];
  const double B3 = C[row2 * 3];
  const double B4 = C[row2 * 3 + 1] + 1;
  const double B5 = C[row2 * 3 + 2];
  const double B6 = C[row3 * 3];
  const double B7 = C[row3 * 3 + 1];
  const double B8 = C[row3 * 3 + 2] + 1;

  // Check if determinant of B is not too close to zero
  double det;
  det = B0 * (B4 * B8 - B5 * B7) - B1 * (B3 * B8 - B5 * B6) +
        B2 * (B3 * B7 - B4 * B6);
  if (fabs(det) < breakdown) {
    // Nothing has been written yet, so the caller's data is still intact.
    return QMCKL_FAILURE;
  }

  // Update det(Slater) if passed
  if (determinant) {
    *determinant *= det;
  }

  // Compute B^{-1} with explicit formula for 3 x 3 inversion
  // (adjugate of B scaled by 1 / det).
  const double idet = 1.0 / det;
  double __attribute__((aligned(8))) Binv[9];
  Binv[0] = (B4 * B8 - B7 * B5) * idet;
  Binv[1] = -(B1 * B8 - B7 * B2) * idet;
  Binv[2] = (B1 * B5 - B4 * B2) * idet;
  Binv[3] = -(B3 * B8 - B6 * B5) * idet;
  Binv[4] = (B0 * B8 - B6 * B2) * idet;
  Binv[5] = -(B0 * B5 - B3 * B2) * idet;
  Binv[6] = (B3 * B7 - B6 * B4) * idet;
  Binv[7] = -(B0 * B7 - B6 * B1) * idet;
  Binv[8] = (B0 * B4 - B3 * B1) * idet;

  // tmp = B^{-1}D : 3 x D{Dim}_P
  // D is made of the three updated rows of Slater_inv; they are copied into
  // tmp (after mixing with Binv) before Slater_inv is overwritten below.
  double __attribute__((aligned(8))) tmp[3 * D{Dim}_P];
  double* r1dim = &(Slater_inv[row1 * D{Dim}_P]);
  double* r2dim = &(Slater_inv[row2 * D{Dim}_P]);
  double* r3dim = &(Slater_inv[row3 * D{Dim}_P]);
  IVDEP
  ALIGNED
  for (uint64_t j = 0; j < D{Dim}_P; j++) {
    tmp[j] = Binv[0] * r1dim[j] + Binv[1] * r2dim[j] + Binv[2] * r3dim[j];
    tmp[D{Dim}_P + j] =
        Binv[3] * r1dim[j] + Binv[4] * r2dim[j] + Binv[5] * r3dim[j];
    tmp[2 * D{Dim}_P + j] =
        Binv[6] * r1dim[j] + Binv[7] * r2dim[j] + Binv[8] * r3dim[j];
  }

  // Compute (S^T)^{-1} - C * tmp : {Dim} x D{Dim}_P
  // The three subtractions are kept as separate statements so the
  // floating-point evaluation order is unchanged.
  for (uint64_t i = 0; i < {Dim}; i++) {
    IVDEP
    ALIGNED
    for (uint64_t j = 0; j < D{Dim}_P; j++) {
      Slater_inv[i * D{Dim}_P + j] -= C[i * 3] * tmp[j];
      Slater_inv[i * D{Dim}_P + j] -= C[i * 3 + 1] * tmp[D{Dim}_P + j];
      Slater_inv[i * D{Dim}_P + j] -= C[i * 3 + 2] * tmp[2 * D{Dim}_P + j];
    }
  }

  return QMCKL_SUCCESS;
}
#+end_src
#+NAME:woodbury_3x3_kernel_generator
#+begin_src python :noweb yes :exports none
text="""
<<woodbury_3x3_kernel_template>>
"""
result = []
for Dim in <<kernel_generator_range>>:
Dim=str(Dim)
result.append(text.replace("{Dim}",Dim) )
return '\n'.join(result)
#+end_src
#+NAME:woodbury_3x3_switch-case_generator
#+begin_src python :noweb yes :exports none
text="""
case {Dim}:
return qmckl_woodbury_3x3_{Dim}(context,
Updates,
Updates_index,
breakdown,
Slater_inv,
determinant);
"""
result = []
for Dim in <<kernel_generator_range>>:
Dim=str(Dim)
result.append(text.replace("{Dim}",Dim) )
return '\n'.join(result)
#+end_src
#+begin_src c :tangle (eval c) :comments org :noweb yes
<<woodbury_3x3_kernel_generator()>>
// Entry point for the Woodbury 3x3 update: selects the kernel to run.
//
// When LDS equals Dim rounded up to a multiple of SIMD_LENGTH and a
// fixed-size kernel was generated for Dim, that kernel is called. In every
// other case the generic qmckl_woodbury_3x3_hpc kernel is used, so that an
// unsupported Dim no longer yields a bare QMCKL_FAILURE that callers cannot
// tell apart from a numerical breakdown.
//
// Parameters:
//   context        QMCkl context, checked with qmckl_context_check.
//   LDS            compared against the padded Dim and forwarded to the
//                  generic kernel (presumably the row stride of Slater_inv
//                  -- confirm against the generic kernel).
//   Dim            matrix dimension; must be non-zero.
//   Updates, Updates_index, breakdown, Slater_inv, determinant
//                  forwarded unchanged to the selected kernel.
//
// Returns the exit code of the selected kernel, QMCKL_NULL_CONTEXT when the
// context check fails, or QMCKL_FAILURE when Dim is zero.
qmckl_exit_code qmckl_woodbury_3x3(const qmckl_context context,
                                   const uint64_t LDS,
                                   const uint64_t Dim,
                                   const double* Updates,
                                   const uint64_t* Updates_index,
                                   const double breakdown,
                                   double* Slater_inv,
                                   double* determinant) {

  if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) {
    return qmckl_failwith(context,
                          QMCKL_NULL_CONTEXT,
                          "qmckl_woodbury_3x3",
                          NULL);
  }

  // (Dim - 1) is unsigned and would wrap for Dim == 0, making the padding
  // test below meaningless; there is nothing to update in an empty matrix.
  if (Dim == 0) {
    return QMCKL_FAILURE;
  }

  if (LDS == (1+(Dim-1)/SIMD_LENGTH)*SIMD_LENGTH) { // Most cases
    switch (Dim) {
      <<woodbury_3x3_switch-case_generator()>>
      default:
        // No fixed-size kernel was generated for this Dim.
        break;
    }
  }

  // Reached when LDS is not the padded Dim (e.g. SIMD_LENGTH > 1 and the
  // function is called with LDS == Dim), or when Dim is outside the range of
  // generated kernels.
  return qmckl_woodbury_3x3_hpc(context,
                                LDS,
                                Dim,
                                Updates,
                                Updates_index,
                                breakdown,
                                Slater_inv,
                                determinant);
}
#+end_src
*** Performance... *** Performance...
This function is most efficient when used in cases where there are only 3 rank-1 updates and This function is most efficient when used in cases where there are only 3 rank-1 updates and