mirror of
https://github.com/TREX-CoE/Sherman-Morrison.git
synced 2024-12-26 14:23:47 +01:00
Optimised matrix multiplication with V.
This commit is contained in:
parent
e3dc3632a4
commit
f6f8746bef
@ -56,6 +56,7 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
|||||||
#ifdef DEBUG1
|
#ifdef DEBUG1
|
||||||
std::cerr << "Called SM2 with " << N_updates << " updates" << std::endl;
|
std::cerr << "Called SM2 with " << N_updates << " updates" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
double C[Dim];
|
double C[Dim];
|
||||||
double D[Dim];
|
double D[Dim];
|
||||||
|
|
||||||
@ -118,7 +119,10 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
|||||||
// Sherman Morrison, leaving zero denominators for later
|
// Sherman Morrison, leaving zero denominators for later
|
||||||
void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
||||||
double *Updates, unsigned int *Updates_index) {
|
double *Updates, unsigned int *Updates_index) {
|
||||||
|
#ifdef DEBUG1
|
||||||
std::cerr << "Called SM3 with " << N_updates << " updates" << std::endl;
|
std::cerr << "Called SM3 with " << N_updates << " updates" << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
double C[Dim];
|
double C[Dim];
|
||||||
double D[Dim];
|
double D[Dim];
|
||||||
|
|
||||||
@ -185,7 +189,10 @@ void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
|||||||
// (SM2)
|
// (SM2)
|
||||||
void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates,
|
||||||
double *Updates, unsigned int *Updates_index) {
|
double *Updates, unsigned int *Updates_index) {
|
||||||
|
#ifdef DEBUG1
|
||||||
std::cerr << "Called SM4 with " << N_updates << " updates" << std::endl;
|
std::cerr << "Called SM4 with " << N_updates << " updates" << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
double C[Dim];
|
double C[Dim];
|
||||||
double D[Dim];
|
double D[Dim];
|
||||||
|
|
||||||
|
@ -22,11 +22,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
|
|||||||
std::cerr << "Called Woodbury 2x2 kernel" << std::endl;
|
std::cerr << "Called Woodbury 2x2 kernel" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Construct V from Updates_index
|
// Compute D = V.S^{-1}
|
||||||
unsigned int V[2 * Dim]; // 2 x Dim matrix stored in row-major order
|
double D[2 * Dim];
|
||||||
std::memset(V, 0, 2 * Dim * sizeof(unsigned int));
|
unsigned int row1, row2;
|
||||||
V[Updates_index[0] - 1] = 1;
|
row1 = Updates_index[0] - 1;
|
||||||
V[Dim + Updates_index[1] - 1] = 1;
|
row2 = Updates_index[1] - 1;
|
||||||
|
for (unsigned int i = 0; i < Dim; i++) {
|
||||||
|
D[i] = Slater_inv[row1 * Dim + i];
|
||||||
|
D[Dim + i] = Slater_inv[row2 * Dim + i];
|
||||||
|
}
|
||||||
|
|
||||||
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
|
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
|
||||||
// OF LAYOUT OF 'Updates' !!
|
// OF LAYOUT OF 'Updates' !!
|
||||||
@ -39,16 +43,15 @@ bool WB2(double *Slater_inv, unsigned int Dim, double *Updates,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
|
||||||
// Compute D = V * S^{-1}
|
|
||||||
double D[2 * Dim];
|
|
||||||
matMul2(V, Slater_inv, D, 2, Dim, Dim);
|
|
||||||
|
|
||||||
// Compute B = 1 + V * C
|
// Compute B = 1 + V * C
|
||||||
double B[4];
|
double B[4];
|
||||||
matMul2(V, C, B, 2, Dim, 2);
|
B[0] = C[row1 * 2];
|
||||||
B[0] += 1;
|
B[1] = C[row1 * 2 + 1];
|
||||||
B[3] += 1;
|
B[2] = C[row2 * 2];
|
||||||
|
B[3] = C[row2 * 2 + 1];
|
||||||
|
B[0] += 1, B[3] += 1;
|
||||||
|
|
||||||
// Compute B^{-1} with explicit formula for 2x2 inversion
|
// Compute B^{-1} with explicit formula for 2x2 inversion
|
||||||
double idet = 1.0 / (B[0] * B[3] - B[1] * B[2]);
|
double idet = 1.0 / (B[0] * B[3] - B[1] * B[2]);
|
||||||
@ -101,16 +104,17 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
|
|||||||
showMatrix2(Updates_index, 1, 3, "Updates_index");
|
showMatrix2(Updates_index, 1, 3, "Updates_index");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Construct V from Updates_index
|
// Compute D = V * S^{-1}
|
||||||
unsigned int V[3 * Dim]; // 3 x Dim matrix stored in row-major order
|
double D[3 * Dim];
|
||||||
std::memset(V, 0, 3 * Dim * sizeof(unsigned int));
|
unsigned int row1, row2, row3;
|
||||||
V[Updates_index[0] - 1] = 1;
|
row1 = Updates_index[0] - 1;
|
||||||
V[Dim + Updates_index[1] - 1] = 1;
|
row2 = Updates_index[1] - 1;
|
||||||
V[2 * Dim + Updates_index[2] - 1] = 1;
|
row3 = Updates_index[2] - 1;
|
||||||
|
for (unsigned int i = 0; i < Dim; i++) {
|
||||||
#ifdef DEBUG2
|
D[i] = Slater_inv[row1 * Dim + i];
|
||||||
showMatrix2(V, 3, Dim, "V");
|
D[Dim + i] = Slater_inv[row2 * Dim + i];
|
||||||
#endif
|
D[2 * Dim + i] = Slater_inv[row3 * Dim + i];
|
||||||
|
}
|
||||||
|
|
||||||
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
|
// Compute C = S_inv * U !! NON-STANDARD MATRIX MULTIPLICATION BECAUSE
|
||||||
// OF LAYOUT OF 'Updates' !!
|
// OF LAYOUT OF 'Updates' !!
|
||||||
@ -123,24 +127,28 @@ bool WB3(double *Slater_inv, unsigned int Dim, double *Updates,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// matMul2(Updates, Slater_inv, C, 2, Dim, Dim);
|
||||||
|
|
||||||
#ifdef DEBUG2
|
#ifdef DEBUG2
|
||||||
showMatrix2(C, Dim, 3, "C = S_inv * U");
|
showMatrix2(C, Dim, 3, "C = S_inv * U");
|
||||||
#endif
|
#endif
|
||||||
// Compute D = V * S^{-1}
|
|
||||||
double D[3 * Dim];
|
|
||||||
matMul2(V, Slater_inv, D, 3, Dim, Dim);
|
|
||||||
|
|
||||||
#ifdef DEBUG2
|
#ifdef DEBUG2
|
||||||
showMatrix2(D, 3, Dim, "D = V * S_inv");
|
showMatrix2(D, 3, Dim, "D = V * S_inv");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Compute B = 1 + V * C
|
// Compute B = 1 + V.C
|
||||||
double B[9];
|
double B[9];
|
||||||
matMul2(V, C, B, 3, Dim, 3);
|
B[0] = C[row1 * 3];
|
||||||
B[0] += 1;
|
B[1] = C[row1 * 3 + 1];
|
||||||
B[4] += 1;
|
B[2] = C[row1 * 3 + 2];
|
||||||
B[8] += 1;
|
B[3] = C[row2 * 3];
|
||||||
|
B[4] = C[row2 * 3 + 1];
|
||||||
|
B[5] = C[row2 * 3 + 2];
|
||||||
|
B[6] = C[row3 * 3];
|
||||||
|
B[7] = C[row3 * 3 + 1];
|
||||||
|
B[8] = C[row3 * 3 + 2];
|
||||||
|
B[0] += 1, B[4] += 1, B[8] += 1;
|
||||||
|
|
||||||
#ifdef DEBUG2
|
#ifdef DEBUG2
|
||||||
showMatrix2(B, 3, 3, "B = 1 + V * C");
|
showMatrix2(B, 3, 3, "B = 1 + V * C");
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include "Woodbury.hpp"
|
#include "Woodbury.hpp"
|
||||||
#include "SMWB.hpp"
|
#include "SMWB.hpp"
|
||||||
|
|
||||||
// #define PERF
|
#define PERF
|
||||||
|
|
||||||
#ifdef PERF
|
#ifdef PERF
|
||||||
unsigned int repetition_number;
|
unsigned int repetition_number;
|
||||||
|
Loading…
Reference in New Issue
Block a user