mirror of
https://github.com/TREX-CoE/Sherman-Morrison.git
synced 2025-01-14 06:46:12 +01:00
Cleanup: consolidated some pragmas.
This commit is contained in:
parent
4e7a334b78
commit
a63b1289d4
@ -283,14 +283,12 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
|||||||
// Compute C = S^{-1} U : Dim x K : standard dgemm
|
// Compute C = S^{-1} U : Dim x K : standard dgemm
|
||||||
double *C = calloc(1, DIM * N_updates * sizeof(double));
|
double *C = calloc(1, DIM * N_updates * sizeof(double));
|
||||||
double alpha = 1.0, beta = 0.0;
|
double alpha = 1.0, beta = 0.0;
|
||||||
|
|
||||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
Dim, N_updates, Lds,
|
Dim, N_updates, Lds,
|
||||||
alpha, Slater_inv, Lds, Updates, Lds,
|
alpha, Slater_inv, Lds, Updates, Lds,
|
||||||
beta, C, N_updates);
|
beta, C, N_updates);
|
||||||
|
|
||||||
// Construct B = 1 + V C : K x K : selecting and copying row from C into B. Can maybe be off-loaded to GPU by splitting in N_updates tiles of N_updates strides, using PARALLEL and SIMD
|
// Construct B = 1 + V C : K x K, construct D = V S^{-1} : K x LDS
|
||||||
// Construct D = V S^{-1} : K x LDS
|
|
||||||
double B[N_updates * N_updates], D[N_updates * LDS];
|
double B[N_updates * N_updates], D[N_updates * LDS];
|
||||||
for (uint32_t i = 0; i < N_updates; i++) {
|
for (uint32_t i = 0; i < N_updates; i++) {
|
||||||
const uint32_t row = Updates_index[i] - 1;
|
const uint32_t row = Updates_index[i] - 1;
|
||||||
@ -299,27 +297,23 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute determinant by LU decomposition
|
// Compute determinant by LU decomposition
|
||||||
int* ipiv = calloc(1, sizeof *ipiv * N_updates);
|
int* pivot = calloc(1, sizeof *pivot * N_updates);
|
||||||
(void) LAPACKE_dgetrf(LAPACK_ROW_MAJOR, N_updates, N_updates, B, N_updates, ipiv);
|
(void) LAPACKE_dgetrf(LAPACK_ROW_MAJOR, N_updates, N_updates, B, N_updates, pivot);
|
||||||
|
|
||||||
double det = 1.0;
|
bool swap = false; uint32_t j = 0; double det = 1.0f;
|
||||||
int j = 0;
|
|
||||||
for (uint32_t i = 0; i < N_updates; i++) {
|
for (uint32_t i = 0; i < N_updates; i++) {
|
||||||
j += min(ipiv[i] - i, 1);
|
swap = (bool)(pivot[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
|
||||||
det *= B[(N_updates + 1) * i];
|
j += (uint32_t)swap; // count # of swaps
|
||||||
|
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
||||||
}
|
}
|
||||||
if ((j & 1) == 0) det = -det; // multiply det with -1 if j is even
|
if (fabs(det) < breakdown) return 1; // check if determinant of B is too close to zero. If so, exit early.
|
||||||
|
if (determinant) { // update det(Slater) if determinant != NULL
|
||||||
// Check if determinant of B is not too close to zero
|
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
|
||||||
if (fabs(det) < breakdown) {
|
*determinant *= det;
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update det(Slater) if passed
|
|
||||||
if (determinant) *determinant *= det;
|
|
||||||
|
|
||||||
// Compute B^{-1} from the LU factors of B (LAPACKE_dgetri) : K x K
|
// Compute B^{-1} from the LU factors of B (LAPACKE_dgetri) : K x K
|
||||||
(void) LAPACKE_dgetri(LAPACK_ROW_MAJOR, N_updates, B, N_updates, ipiv);
|
(void) LAPACKE_dgetri(LAPACK_ROW_MAJOR, N_updates, B, N_updates, pivot);
|
||||||
|
|
||||||
// tmp1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
|
// tmp1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
|
||||||
double tmp1[N_updates * LDS];
|
double tmp1[N_updates * LDS];
|
||||||
@ -335,7 +329,7 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
|
|||||||
alpha, C, N_updates, tmp1, LDS,
|
alpha, C, N_updates, tmp1, LDS,
|
||||||
beta, Slater_inv, LDS);
|
beta, Slater_inv, LDS);
|
||||||
|
|
||||||
free(ipiv);
|
free(pivot);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -354,7 +348,7 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
|||||||
const uint32_t Lds = LDS;
|
const uint32_t Lds = LDS;
|
||||||
|
|
||||||
double alpha, beta;
|
double alpha, beta;
|
||||||
int* ipiv = calloc(1, sizeof *ipiv * N_updates);
|
int* pivot = calloc(1, sizeof *pivot * N_updates);
|
||||||
double* C = calloc(1, sizeof *C * Dim * N_updates);
|
double* C = calloc(1, sizeof *C * Dim * N_updates);
|
||||||
double* B = calloc(1, sizeof *B * N_updates * N_updates);
|
double* B = calloc(1, sizeof *B * N_updates * N_updates);
|
||||||
double* Binv = calloc(1, sizeof *Binv * N_updates * N_updates);
|
double* Binv = calloc(1, sizeof *Binv * N_updates * N_updates);
|
||||||
@ -362,9 +356,9 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
|||||||
double* T1 = calloc(1, sizeof *T1 * N_updates * Lds);
|
double* T1 = calloc(1, sizeof *T1 * N_updates * Lds);
|
||||||
double* T2 = calloc(1, sizeof *T2 * Dim * Lds);
|
double* T2 = calloc(1, sizeof *T2 * Dim * Lds);
|
||||||
|
|
||||||
int lwork = 0, *info = NULL; double* d_work = NULL;
|
int workspace_size = 0, *info = NULL; double* workspace = NULL;
|
||||||
cusolverDnDgetrf_bufferSize(s_handle, N_updates, N_updates, B, N_updates, &lwork);
|
cusolverDnDgetrf_bufferSize(s_handle, N_updates, N_updates, B, N_updates, &workspace_size);
|
||||||
d_work = calloc(1, sizeof *d_work * lwork);
|
workspace = calloc(1, sizeof *workspace * workspace_size);
|
||||||
|
|
||||||
#pragma omp target enter data map(to: Updates[0:Lds*N_updates], \
|
#pragma omp target enter data map(to: Updates[0:Lds*N_updates], \
|
||||||
Updates_index[0:N_updates], \
|
Updates_index[0:N_updates], \
|
||||||
@ -375,96 +369,78 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
|||||||
D[0:N_updates*Lds], \
|
D[0:N_updates*Lds], \
|
||||||
T1[0:N_updates*Lds], \
|
T1[0:N_updates*Lds], \
|
||||||
T2[0:Dim*Lds], \
|
T2[0:Dim*Lds], \
|
||||||
ipiv[0:N_updates], \
|
pivot[0:N_updates], \
|
||||||
d_work[0:lwork])
|
workspace[0:workspace_size])
|
||||||
|
|
||||||
#pragma omp target data use_device_ptr(Slater_inv, Updates, C) // compute C ON DEVICE
|
#pragma omp target data use_device_ptr(Slater_inv, Updates, C, B, workspace, pivot, Binv, D, T1, T2)
|
||||||
{
|
{
|
||||||
|
// Compute C <- S^{-1} U : Dim x K : standard dgemm
|
||||||
alpha = 1.0f, beta = 0.0f;
|
alpha = 1.0f, beta = 0.0f;
|
||||||
(void) cublasDgemm(b_handle,
|
(void) cublasDgemm(b_handle,
|
||||||
CUBLAS_OP_T, CUBLAS_OP_N,
|
CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
N_updates, Dim, Lds,
|
N_updates, Dim, Lds,
|
||||||
&alpha, Updates, Lds, Slater_inv, Lds,
|
&alpha, Updates, Lds, Slater_inv, Lds,
|
||||||
&beta, C, N_updates);
|
&beta, C, N_updates);
|
||||||
}
|
|
||||||
|
|
||||||
// Construct B = 1 + V C : K x K
|
// Construct B <- 1 + V C : K x K, construct D = V S^{-1} : K x LDS
|
||||||
// Construct D = V S^{-1} : K x LDS
|
|
||||||
#pragma omp target teams distribute parallel for // compute B, D ON DEVICE
|
#pragma omp target teams distribute parallel for // compute B, D ON DEVICE
|
||||||
for (uint32_t i = 0; i < N_updates; i++)
|
for (uint32_t i = 0; i < N_updates; i++) {
|
||||||
{
|
|
||||||
const uint32_t row = Updates_index[i] - 1;
|
const uint32_t row = Updates_index[i] - 1;
|
||||||
for (uint32_t j = 0; j < N_updates ; j++)
|
for (uint32_t j = 0; j < N_updates ; j++) {
|
||||||
{
|
|
||||||
B[j * N_updates + i] = C[row * N_updates + j] + (i == j); // B NEEDS TO BE IN COL-MAJ FOR cusolverDnDgetrf !
|
B[j * N_updates + i] = C[row * N_updates + j] + (i == j); // B NEEDS TO BE IN COL-MAJ FOR cusolverDnDgetrf !
|
||||||
}
|
}
|
||||||
for (uint32_t j = 0; j < Lds; j++)
|
for (uint32_t j = 0; j < Lds; j++) {
|
||||||
{
|
|
||||||
D[i * Lds + j] = Slater_inv[row * Lds + j];
|
D[i * Lds + j] = Slater_inv[row * Lds + j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute determinant by LU decomposition
|
// Compute determinant by LU decomposition
|
||||||
#pragma omp target data use_device_ptr(B, d_work, ipiv) // compute C ON DEVICE
|
(void) cusolverDnDgetrf(s_handle, N_updates, N_updates, B, N_updates, workspace, pivot, info);
|
||||||
{
|
|
||||||
(void) cusolverDnDgetrf(s_handle, N_updates, N_updates, B, N_updates, d_work, ipiv, info);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool swap = false; uint32_t j = 0; double det = 1.0f;
|
bool swap = false; uint32_t j = 0; double det = 1.0f;
|
||||||
#pragma omp target teams distribute parallel for reduction(+: j) reduction(*: det)
|
#pragma omp target teams distribute parallel for reduction(+: j) reduction(*: det)
|
||||||
for (uint32_t i = 0; i < N_updates; i++)
|
for (uint32_t i = 0; i < N_updates; i++) {
|
||||||
{
|
swap = (bool)(pivot[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
|
||||||
swap = (bool)(ipiv[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
|
|
||||||
j += (uint32_t)swap; // count # of swaps
|
j += (uint32_t)swap; // count # of swaps
|
||||||
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
|
||||||
}
|
}
|
||||||
if (fabs(det) < breakdown) return det; // check if determinant of B is too close to zero. If so, exit early.
|
if (fabs(det) < breakdown) return 1; // check if determinant of B is too close to zero. If so, exit early.
|
||||||
|
if (determinant) { // update det(Slater) if determinant != NULL
|
||||||
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
|
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
|
||||||
|
*determinant *= det;
|
||||||
|
}
|
||||||
|
|
||||||
if (determinant) *determinant *= det; // update det(Slater) if determinant!=NULL
|
// Compute B^{-1} : initialise as I for solving BX=I
|
||||||
// Compute B^{-1}
|
|
||||||
#pragma omp target teams distribute parallel for
|
#pragma omp target teams distribute parallel for
|
||||||
for (int i = 0; i < N_updates; ++i) {
|
for (int i = 0; i < N_updates; ++i) {
|
||||||
for (int j = 0; j < N_updates; ++j) {
|
for (int j = 0; j < N_updates; ++j) {
|
||||||
Binv[i * N_updates + j] = (i == j);
|
Binv[i * N_updates + j] = (i == j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_N, N_updates, N_updates, B, N_updates, pivot, Binv, N_updates, info);
|
||||||
|
|
||||||
#pragma omp target data use_device_ptr(B, ipiv, Binv) // correct result Binv, but in CM!
|
// T1 <- B^{-1} D : KxLDS : standard dgemm
|
||||||
{
|
|
||||||
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_N, N_updates, N_updates, B, N_updates, ipiv, Binv, N_updates, info);
|
|
||||||
}
|
|
||||||
|
|
||||||
// T1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
|
|
||||||
#pragma omp target data use_device_ptr(D, Binv, T1) // compute T1 ON DEVICE
|
|
||||||
{
|
|
||||||
alpha = 1.0, beta = 0.0;
|
alpha = 1.0, beta = 0.0;
|
||||||
(void) cublasDgemm(b_handle,
|
(void) cublasDgemm(b_handle,
|
||||||
CUBLAS_OP_N, CUBLAS_OP_T, // REMEMBER THIS IS TMP TRANSPOSED BECAUSE OF LAPACKE CALL ON l429 !!!
|
CUBLAS_OP_N, CUBLAS_OP_T, // REMEMBER THIS IS Binv TRANSPOSED BECAUSE OF cusolverDnDgetrs CALL ON l.434 !!!
|
||||||
Lds, N_updates, N_updates,
|
Lds, N_updates, N_updates,
|
||||||
&alpha, D, Lds, Binv, N_updates,
|
&alpha, D, Lds, Binv, N_updates,
|
||||||
&beta, T1, Lds);
|
&beta, T1, Lds);
|
||||||
}
|
|
||||||
|
|
||||||
// Compute T2 <- C * T1 : Dim x LDS : standard dgemm
|
// Compute T2 <- C * T1 : Dim x LDS : standard dgemm
|
||||||
#pragma omp target data use_device_ptr(T1, C, T2) // compute T2 ON DEVICE
|
|
||||||
{
|
|
||||||
alpha = 1.0f, beta = 0.0f;
|
alpha = 1.0f, beta = 0.0f;
|
||||||
(void) cublasDgemm(b_handle,
|
(void) cublasDgemm(b_handle,
|
||||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||||
Dim, Lds, N_updates,
|
Dim, Lds, N_updates,
|
||||||
&alpha, T1, Lds, C, N_updates,
|
&alpha, T1, Lds, C, N_updates,
|
||||||
&beta, T2, Lds);
|
&beta, T2, Lds);
|
||||||
}
|
|
||||||
|
|
||||||
// Compute S^{-1} <- S^{-1} - T2 : Dim x LDS : element-wise subtraction
|
// Compute S^{-1} <- S^{-1} - T2 : Dim x LDS : element-wise subtraction
|
||||||
#pragma omp target teams distribute parallel for // compute S^-1 ON DEVICE
|
#pragma omp target teams distribute parallel for // compute S^-1 ON DEVICE
|
||||||
for (uint32_t i = 0; i < Dim * Lds; i++)
|
for (uint32_t i = 0; i < Dim * Lds; i++) {
|
||||||
{
|
|
||||||
Slater_inv[i] = Slater_inv[i] - T2[i];
|
Slater_inv[i] = Slater_inv[i] - T2[i];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
#pragma omp target update from(Slater_inv[0:Dim*Lds]) // update S^-1 ON HOST
|
#pragma omp target update from(Slater_inv[0:Dim*Lds]) // update S^-1 ON HOST
|
||||||
|
|
||||||
#pragma omp target exit data map(delete: Updates[0:Lds*N_updates], \
|
#pragma omp target exit data map(delete: Updates[0:Lds*N_updates], \
|
||||||
Updates_index[0:N_updates], \
|
Updates_index[0:N_updates], \
|
||||||
Slater_inv[0:Dim*Lds], \
|
Slater_inv[0:Dim*Lds], \
|
||||||
@ -474,16 +450,14 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
|
|||||||
D[0:N_updates*Lds], \
|
D[0:N_updates*Lds], \
|
||||||
T1[0:N_updates*Lds], \
|
T1[0:N_updates*Lds], \
|
||||||
T2[0:Dim*Lds], \
|
T2[0:Dim*Lds], \
|
||||||
ipiv[0:N_updates])
|
pivot[0:N_updates])
|
||||||
|
free(pivot);
|
||||||
free(ipiv);
|
|
||||||
free(B);
|
free(B);
|
||||||
free(Binv);
|
free(Binv);
|
||||||
free(C);
|
free(C);
|
||||||
free(D);
|
free(D);
|
||||||
free(T1);
|
free(T1);
|
||||||
free(T2);
|
free(T2);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -809,10 +783,10 @@ uint32_t qmckl_sherman_morrison_later(
|
|||||||
// ret < 0 illegal argument value
|
// ret < 0 illegal argument value
|
||||||
// ret > 0 singular matrix
|
// ret > 0 singular matrix
|
||||||
lapack_int inverse(double *a, uint64_t m, uint64_t n) {
|
lapack_int inverse(double *a, uint64_t m, uint64_t n) {
|
||||||
int ipiv[m + 1];
|
int pivot[m + 1];
|
||||||
lapack_int ret;
|
lapack_int ret;
|
||||||
ret = LAPACKE_dgetrf(LAPACK_ROW_MAJOR, m, n, a, n, ipiv);
|
ret = LAPACKE_dgetrf(LAPACK_ROW_MAJOR, m, n, a, n, pivot);
|
||||||
if (ret != 0) return ret;
|
if (ret != 0) return ret;
|
||||||
ret = LAPACKE_dgetri(LAPACK_ROW_MAJOR, n, a, n, ipiv);
|
ret = LAPACKE_dgetri(LAPACK_ROW_MAJOR, n, a, n, pivot);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user