- LAPACKE_dgetrf/ri replaced with cusolverDnDgetrf/rs.

- Solved sign bug in computation of determinant.

Most code is now executed on the device. Some openMP pragmas can be consolidated.
This commit is contained in:
Francois Coppens 2022-09-26 17:02:58 +02:00
parent 5a61ccc6b1
commit 4e7a334b78
2 changed files with 98 additions and 91 deletions

View File

@ -7,7 +7,6 @@
#ifdef HAVE_CUBLAS_OFFLOAD
#include <stdio.h>
#include <omp.h>
#include <cublas_v2.h>
#include <cusolverDn.h>
#include <cusolver_common.h>

View File

@ -1,5 +1,6 @@
#include <math.h>
#include <stdint.h>
#include <stdbool.h>
#include "kernels.h"
#include "debug.h"
@ -356,19 +357,20 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
int* ipiv = calloc(1, sizeof *ipiv * N_updates);
double* C = calloc(1, sizeof *C * Dim * N_updates);
double* B = calloc(1, sizeof *B * N_updates * N_updates);
double* Binv = calloc(1, sizeof *Binv * N_updates * N_updates);
double* D = calloc(1, sizeof *D * N_updates * Lds);
double* T1 = calloc(1, sizeof *T1 * N_updates * Lds);
double* T2 = calloc(1, sizeof *T2 * Dim * Lds);
int lwork = 0, *info = NULL; double* d_work = NULL;
cusolverDnDgetrf_bufferSize(s_handle, N_updates, N_updates, B, N_updates, &lwork);
printf("Size of lwork = %d\n", lwork);
d_work = calloc(1, sizeof *d_work * lwork);
#pragma omp target enter data map(to: Updates[0:Lds*N_updates], \
Updates_index[0:N_updates], \
Slater_inv[0:Dim*Lds]) \
map(alloc: B[0:N_updates*N_updates], \
Binv[0:N_updates*N_updates], \
C[0:Dim*N_updates], \
D[0:N_updates*Lds], \
T1[0:N_updates*Lds], \
@ -408,35 +410,39 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
(void) cusolverDnDgetrf(s_handle, N_updates, N_updates, B, N_updates, d_work, ipiv, info);
}
double det = 1.0f; uint32_t j = 0;
#pragma omp target teams distribute parallel for // compute det ON DEVICE
bool swap = false; uint32_t j = 0; double det = 1.0f;
#pragma omp target teams distribute parallel for reduction(+: j) reduction(*: det)
for (uint32_t i = 0; i < N_updates; i++)
{
int row = ipiv[i] - i;
j += min(row, 1);
det *= B[(N_updates + 1) * i]; // update determinant
swap = (bool)(ipiv[i] - (i + 1)); // swap = {0->false: no swap, >0->true: swap}
j += (uint32_t)swap; // count # of swaps
det *= B[i * (N_updates + 1)]; // prod. of diag elm. of B
}
if ((j & 1) == 0) det = -det; // multiply det with -1 if j is even
// Check if determinant of B is not too close to zero
if (fabs(det) < breakdown) return det;
// Update det(Slater) if passed
if (determinant) *determinant *= det;
if (fabs(det) < breakdown) return det; // check if determinant of B is too close to zero. If so, exit early.
if ((j & 1) != 0) det = -det; // multiply det with -1 if # of swaps is odd
if (determinant) *determinant *= det; // update det(Slater) if determinant!=NULL
// Compute B^{-1}
#pragma omp target update from(B[0:N_updates*N_updates], ipiv[0:N_updates])
(void) LAPACKE_dgetri(LAPACK_COL_MAJOR, N_updates, B, N_updates, ipiv); // compute B^-1 ON HOST
#pragma omp target update to(B[:N_updates*N_updates]) // Update B^-1 TO DEVICE
#pragma omp target teams distribute parallel for
for (int i = 0; i < N_updates; ++i) {
for (int j = 0; j < N_updates; ++j) {
Binv[i * N_updates + j] = (i == j);
}
}
#pragma omp target data use_device_ptr(B, ipiv, Binv) // correct result Binv, but in CM!
{
(void) cusolverDnDgetrs(s_handle, CUBLAS_OP_N, N_updates, N_updates, B, N_updates, ipiv, Binv, N_updates, info);
}
// T1 = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
#pragma omp target data use_device_ptr(D, B, T1) // compute T1 ON DEVICE
#pragma omp target data use_device_ptr(D, Binv, T1) // compute T1 ON DEVICE
{
alpha = 1.0, beta = 0.0;
(void) cublasDgemm(b_handle,
CUBLAS_OP_N, CUBLAS_OP_T, // REMEMBER THIS IS TMP TRANSPOSED BECAUSE OF LAPACKE CALL ON l429 !!!
Lds, N_updates, N_updates,
&alpha, D, Lds, B, N_updates,
&alpha, D, Lds, Binv, N_updates,
&beta, T1, Lds);
}
@ -463,14 +469,16 @@ uint32_t qmckl_woodbury_k_cublas_offload(cublasHandle_t b_handle, cusolverDnHand
Updates_index[0:N_updates], \
Slater_inv[0:Dim*Lds], \
B[0:N_updates*N_updates], \
Binv[0:N_updates*N_updates], \
C[0:Dim*N_updates], \
D[0:N_updates*Lds], \
T1[0:N_updates*Lds], \
T2[0:Dim*Lds], \
ipiv[0:N_updates])
// free(ipiv);
free(ipiv);
free(B);
free(Binv);
free(C);
free(D);
free(T1);