Optimised matrix multiplication with V.

This commit is contained in:
François Coppens 2021-07-01 13:44:28 +02:00
parent e3dc3632a4
commit f6f8746bef
3 changed files with 46 additions and 31 deletions

View File

@ -56,6 +56,7 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
#ifdef DEBUG1
std::cerr << "Called SM2 with " << N_updates << " updates" << std::endl;
#endif
double C[Dim];
double D[Dim];
@ -118,7 +119,10 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
// Sherman Morrison, leaving zero denominators for later
void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
double *Updates, unsigned int *Updates_index) {
#ifdef DEBUG1
std::cerr << "Called SM3 with " << N_updates << " updates" << std::endl;
#endif
double C[Dim];
double D[Dim];
@ -185,7 +189,10 @@ void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
// (SM2)
void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
double *Updates, unsigned int *Updates_index) {
#ifdef DEBUG1
std::cerr << "Called SM4 with " << N_updates << " updates" << std::endl;
#endif
double C[Dim];
double D[Dim];

View File

@ -22,11 +22,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
std::cerr << "Called Woodbury 2x2 kernel" << std::endl;
#endif
// Construct V from Updates_index
unsigned int V[2 * Dim]; // 2 x Dim matrix stored in row-major order
std::memset(V, 0, 2 * Dim * sizeof(unsigned int));
V[Updates_index[0] - 1] = 1;
V[Dim + Updates_index[1] - 1] = 1;
// Compute D = V.S^{-1}
double D[2 * Dim];
unsigned int row1, row2;
row1 = Updates_index[0] - 1;
row2 = Updates_index[1] - 1;
for (unsigned int i = 0; i < Dim; i++) {
D[i] = Slater_inv[row1 * Dim + i];
D[Dim + i] = Slater_inv[row2 * Dim + i];
}
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
// OF LAYOUT OF 'Updates' !!
@ -39,16 +43,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
}
}
}
// Compute D = V * S^{-1}
double D[2 * Dim];
matMul2(V, Slater_inv, D, 2, Dim, Dim);
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
// Compute B = 1 + V * C
double B[4];
matMul2(V, C, B, 2, Dim, 2);
B[0] += 1;
B[3] += 1;
B[0] = C[row1 * 2];
B[1] = C[row1 * 2 + 1];
B[2] = C[row2 * 2];
B[3] = C[row2 * 2 + 1];
B[0] += 1, B[3] += 1;
// Compute B^{-1} with explicit formula for 2x2 inversion
double idet = 1.0 / (B[0] * B[3] - B[1] * B[2]);
@ -101,16 +104,17 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
showMatrix2(Updates_index, 1, 3, "Updates_index");
#endif
// Construct V from Updates_index
unsigned int V[3 * Dim]; // 3 x Dim matrix stored in row-major order
std::memset(V, 0, 3 * Dim * sizeof(unsigned int));
V[Updates_index[0] - 1] = 1;
V[Dim + Updates_index[1] - 1] = 1;
V[2 * Dim + Updates_index[2] - 1] = 1;
#ifdef DEBUG2
showMatrix2(V, 3, Dim, "V");
#endif
// Compute D = V * S^{-1}
double D[3 * Dim];
unsigned int row1, row2, row3;
row1 = Updates_index[0] - 1;
row2 = Updates_index[1] - 1;
row3 = Updates_index[2] - 1;
for (unsigned int i = 0; i < Dim; i++) {
D[i] = Slater_inv[row1 * Dim + i];
D[Dim + i] = Slater_inv[row2 * Dim + i];
D[2 * Dim + i] = Slater_inv[row3 * Dim + i];
}
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
// OF LAYOUT OF 'Updates' !!
@ -123,24 +127,28 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
}
}
}
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
#ifdef DEBUG2
showMatrix2(C, Dim, 3, "C = S_inv * U");
#endif
// Compute D = V * S^{-1}
double D[3 * Dim];
matMul2(V, Slater_inv, D, 3, Dim, Dim);
#ifdef DEBUG2
showMatrix2(D, 3, Dim, "D = V * S_inv");
#endif
// Compute B = 1 + V * C
// Compute B = 1 + V.C
double B[9];
matMul2(V, C, B, 3, Dim, 3);
B[0] += 1;
B[4] += 1;
B[8] += 1;
B[0] = C[row1 * 3];
B[1] = C[row1 * 3 + 1];
B[2] = C[row1 * 3 + 2];
B[3] = C[row2 * 3];
B[4] = C[row2 * 3 + 1];
B[5] = C[row2 * 3 + 2];
B[6] = C[row3 * 3];
B[7] = C[row3 * 3 + 1];
B[8] = C[row3 * 3 + 2];
B[0] += 1, B[4] += 1, B[8] += 1;
#ifdef DEBUG2
showMatrix2(B, 3, 3, "B = 1 + V * C");

View File

@ -7,7 +7,7 @@
#include "Woodbury.hpp"
#include "SMWB.hpp"
// #define PERF
#define PERF
#ifdef PERF
unsigned int repetition_number;