Optimised matrix multiplication with V.

This commit is contained in:
François Coppens 2021-07-01 13:44:28 +02:00
parent e3dc3632a4
commit f6f8746bef
3 changed files with 46 additions and 31 deletions

View File

@ -56,6 +56,7 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
#ifdef DEBUG1 #ifdef DEBUG1
std::cerr << "Called SM2 with " << N_updates << " updates" << std::endl; std::cerr << "Called SM2 with " << N_updates << " updates" << std::endl;
#endif #endif
double C[Dim]; double C[Dim];
double D[Dim]; double D[Dim];
@ -118,7 +119,10 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
// Sherman Morrison, leaving zero denominators for later // Sherman Morrison, leaving zero denominators for later
void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates, void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
double *Updates, unsigned int *Updates_index) { double *Updates, unsigned int *Updates_index) {
#ifdef DEBUG1
std::cerr << "Called SM3 with " << N_updates << " updates" << std::endl; std::cerr << "Called SM3 with " << N_updates << " updates" << std::endl;
#endif
double C[Dim]; double C[Dim];
double D[Dim]; double D[Dim];
@ -185,7 +189,10 @@ void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
// (SM2) // (SM2)
void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates, void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
double *Updates, unsigned int *Updates_index) { double *Updates, unsigned int *Updates_index) {
#ifdef DEBUG1
std::cerr << "Called SM4 with " << N_updates << " updates" << std::endl; std::cerr << "Called SM4 with " << N_updates << " updates" << std::endl;
#endif
double C[Dim]; double C[Dim];
double D[Dim]; double D[Dim];

View File

@ -22,11 +22,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
std::cerr << "Called Woodbury 2x2 kernel" << std::endl; std::cerr << "Called Woodbury 2x2 kernel" << std::endl;
#endif #endif
// Construct V from Updates_index // Compute D = V.S^{-1}
unsigned int V[2 * Dim]; // 2 x Dim matrix stored in row-major order double D[2 * Dim];
std::memset(V, 0, 2 * Dim * sizeof(unsigned int)); unsigned int row1, row2;
V[Updates_index[0] - 1] = 1; row1 = Updates_index[0] - 1;
V[Dim + Updates_index[1] - 1] = 1; row2 = Updates_index[1] - 1;
for (unsigned int i = 0; i < Dim; i++) {
D[i] = Slater_inv[row1 * Dim + i];
D[Dim + i] = Slater_inv[row2 * Dim + i];
}
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE // Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
// OF LAYOUT OF 'Updates' !! // OF LAYOUT OF 'Updates' !!
@ -39,16 +43,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
} }
} }
} }
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
// Compute D = V * S^{-1}
double D[2 * Dim];
matMul2(V, Slater_inv, D, 2, Dim, Dim);
// Compute B = 1 + V * C // Compute B = 1 + V * C
double B[4]; double B[4];
matMul2(V, C, B, 2, Dim, 2); B[0] = C[row1 * 2];
B[0] += 1; B[1] = C[row1 * 2 + 1];
B[3] += 1; B[2] = C[row2 * 2];
B[3] = C[row2 * 2 + 1];
B[0] += 1, B[3] += 1;
// Compute B^{-1} with explicit formula for 2x2 inversion // Compute B^{-1} with explicit formula for 2x2 inversion
double idet = 1.0 / (B[0] * B[3] - B[1] * B[2]); double idet = 1.0 / (B[0] * B[3] - B[1] * B[2]);
@ -101,16 +104,17 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
showMatrix2(Updates_index, 1, 3, "Updates_index"); showMatrix2(Updates_index, 1, 3, "Updates_index");
#endif #endif
// Construct V from Updates_index // Compute D = V * S^{-1}
unsigned int V[3 * Dim]; // 3 x Dim matrix stored in row-major order double D[3 * Dim];
std::memset(V, 0, 3 * Dim * sizeof(unsigned int)); unsigned int row1, row2, row3;
V[Updates_index[0] - 1] = 1; row1 = Updates_index[0] - 1;
V[Dim + Updates_index[1] - 1] = 1; row2 = Updates_index[1] - 1;
V[2 * Dim + Updates_index[2] - 1] = 1; row3 = Updates_index[2] - 1;
for (unsigned int i = 0; i < Dim; i++) {
#ifdef DEBUG2 D[i] = Slater_inv[row1 * Dim + i];
showMatrix2(V, 3, Dim, "V"); D[Dim + i] = Slater_inv[row2 * Dim + i];
#endif D[2 * Dim + i] = Slater_inv[row3 * Dim + i];
}
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE // Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
// OF LAYOUT OF 'Updates' !! // OF LAYOUT OF 'Updates' !!
@ -123,24 +127,28 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
} }
} }
} }
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
#ifdef DEBUG2 #ifdef DEBUG2
showMatrix2(C, Dim, 3, "C = S_inv * U"); showMatrix2(C, Dim, 3, "C = S_inv * U");
#endif #endif
// Compute D = V * S^{-1}
double D[3 * Dim];
matMul2(V, Slater_inv, D, 3, Dim, Dim);
#ifdef DEBUG2 #ifdef DEBUG2
showMatrix2(D, 3, Dim, "D = V * S_inv"); showMatrix2(D, 3, Dim, "D = V * S_inv");
#endif #endif
// Compute B = 1 + V * C // Compute B = 1 + V.C
double B[9]; double B[9];
matMul2(V, C, B, 3, Dim, 3); B[0] = C[row1 * 3];
B[0] += 1; B[1] = C[row1 * 3 + 1];
B[4] += 1; B[2] = C[row1 * 3 + 2];
B[8] += 1; B[3] = C[row2 * 3];
B[4] = C[row2 * 3 + 1];
B[5] = C[row2 * 3 + 2];
B[6] = C[row3 * 3];
B[7] = C[row3 * 3 + 1];
B[8] = C[row3 * 3 + 2];
B[0] += 1, B[4] += 1, B[8] += 1;
#ifdef DEBUG2 #ifdef DEBUG2
showMatrix2(B, 3, 3, "B = 1 + V * C"); showMatrix2(B, 3, 3, "B = 1 + V * C");

View File

@ -7,7 +7,7 @@
#include "Woodbury.hpp" #include "Woodbury.hpp"
#include "SMWB.hpp" #include "SMWB.hpp"
// #define PERF #define PERF
#ifdef PERF #ifdef PERF
unsigned int repetition_number; unsigned int repetition_number;