Added cuBLAS offloaded kernel for Woodbury KxK

This commit is contained in:
Francois Coppens 2022-07-21 12:21:51 +02:00
parent f35ad6a777
commit ebe38e79e3
5 changed files with 203 additions and 120 deletions

View File

@ -2,22 +2,26 @@
# CC = gcc
# FFLAGS=-O0 -finline -g -lm -Wall -pedantic
# CFLAGS=-std=c99 -O0 -finline -g -lm -Wall -pedantic
FC = ifort
CC = icc
FC = ifx
CC = icx
# FFLAGS=-O0 -warn all -g -pedantic
# CFLAGS=-std=c99 -O0 -Wall -g -pedantic
FFLAGS=-O3 -warn all -ip -finline -ftz -xCORE-AVX2 -g
CFLAGS=-std=c99 -O3 -Wall -ip -finline -ftz -xCORE-AVX2 -g
INCLUDE=-I/usr/include/hdf5/serial
LFLAGS=-L/usr/lib/x86_64-linux-gnu/hdf5/serial -lhdf5 -lhdf5_hl -qmkl=sequential
FFLAGS=-O3 -warn all -finline -xCORE-AVX2 -g -qopenmp -fopenmp-targets=spir64
CFLAGS=-std=c99 -O3 -Wall -finline -xCORE-AVX2 -g -qopenmp -fopenmp-targets=spir64
INCLUDE=-I/usr/include/hdf5/serial -I/usr/local/cuda/include
LFLAGS=-L/usr/lib/x86_64-linux-gnu/hdf5/serial -lhdf5 -lhdf5_hl -qmkl=sequential -L/usr/local/cuda-11.7/targets/x86_64-linux/lib -lcublas
#FC = verificarlo-f
#CC = verificarlo-c
#FFLAGS=-O3 -finline -g
#CFLAGS=-O3 -finline -g
## Link with icc
test: sm.o test.o detupdate21.o meuk.o
$(CC) $(LFLAGS) -o test sm.o detupdate21.o test.o meuk.o
# test: sm.o test.o detupdate21.o meuk.o
# $(CC) $(LFLAGS) -o test sm.o detupdate21.o test.o meuk.o
test: sm.o test.o meuk.o
$(CC) $(LFLAGS) -o test sm.o test.o meuk.o
## Link with ifort
# test: sm.o test.o detupdate21.o meuk.o

View File

@ -1,6 +1,14 @@
#include <mkl_lapacke.h>
#include <mkl.h>
#define HAVE_CUBLAS_OFFLOAD
#ifdef HAVE_CUBLAS_OFFLOAD
#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#endif
lapack_int inverse(double *A, uint64_t Dim, uint64_t LDS);
int min(int a, int b);
@ -44,6 +52,17 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
double *__restrict __attribute__((aligned(8))) Slater_inv,
double *__restrict determinant);
#ifdef HAVE_CUBLAS_OFFLOAD
uint32_t qmckl_woodbury_k_cublas_offload(const uint64_t vLDS,
const uint64_t vDim,
const uint64_t N_updates,
const double *__restrict __attribute__((aligned(8))) Updates,
const uint64_t *__restrict Updates_index,
const double breakdown,
double *__restrict __attribute__((aligned(8))) Slater_inv,
double *__restrict determinant);
#endif
uint32_t qmckl_woodbury_2(const uint64_t vLDS, const uint64_t vDim,
const double *__restrict __attribute__((aligned(8)))
Updates,

View File

@ -126,21 +126,22 @@ uint32_t test_kernel(char *version, const uint64_t LDS, const uint64_t Dim,
const uint64_t *Updates_index, const double breakdown, const double tolerance,
double *Slater, double *Slater_inv, double *determinant) {
uint32_t rc = 0;
if (version[0] == 'a') { // Anthony
const double *Upds;
const uint64_t *Ui;
for (int i = 0; i < LDS * Dim; i++) Slater_inv[i] *= *determinant;
for (int j = 0; j < N_updates; j++) {
Upds = &Updates[j * LDS];
Ui = &Updates_index[j];
detupd(Dim, LDS, Upds, Ui, Slater_inv, determinant);
if (determinant == 0) printf("TEST_KERNEL: det_update21 failed\n");
}
for (int i = 0; i < LDS * Dim; i++) Slater_inv[i] /= *determinant;
update_slater_matrix(LDS, Dim, N_updates, Updates, Updates_index, Slater);
rc = check_error(LDS, Dim, Slater_inv, Slater, tolerance);
if (rc != 0) printf("TEST_KERNEL: check_error failed\n");
} else if (version[0] == 'n') { // Naive
// if (version[0] == 'a') { // Anthony
// const double *Upds;
// const uint64_t *Ui;
// for (int i = 0; i < LDS * Dim; i++) Slater_inv[i] *= *determinant;
// for (int j = 0; j < N_updates; j++) {
// Upds = &Updates[j * LDS];
// Ui = &Updates_index[j];
// detupd(Dim, LDS, Upds, Ui, Slater_inv, determinant);
// if (determinant == 0) printf("TEST_KERNEL: det_update21 failed\n");
// }
// for (int i = 0; i < LDS * Dim; i++) Slater_inv[i] /= *determinant;
// update_slater_matrix(LDS, Dim, N_updates, Updates, Updates_index, Slater);
// rc = check_error(LDS, Dim, Slater_inv, Slater, tolerance);
// if (rc != 0) printf("TEST_KERNEL: check_error failed\n");
// } else if (version[0] == 'n') { // Naive
if (version[0] == 'n') { // Naive
rc = qmckl_sherman_morrison(LDS, Dim, N_updates, Updates, Updates_index,
breakdown, Slater_inv, determinant);
if (rc != 0) printf("TEST_KERNEL: qmckl_sherman_morrison failed\n");

View File

@ -1,10 +1,7 @@
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "kernels.h"
extern uint64_t n_splits;
extern uint64_t block_fail;
extern uint64_t recursive_calls;
@ -107,17 +104,6 @@ uint32_t qmckl_woodbury_2(const uint64_t vLDS, const uint64_t vDim,
C[i * 2 + 1] += Slater_inv[i * LDS + k] * Updates[LDS + k];
}
}
// const double alpha = 1.0, beta = 0.0;
// const bool TransA = true, TransB = false;
// (void) cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
// Dim, 2, LDS, alpha, Slater_inv, LDS, Updates, LDS, beta,
// C, 2);
// (void) qmckl_dgemm(context, CblasNoTrans, CblasTrans,
// 2, Dim, LDS, alpha, Updates, LDS, Slater_inv, LDS, beta,
// C, 2);
// (void) qmckl_dgemm(context, TransA, TransB,
// 2, Dim, LDS, alpha, Updates, LDS, Slater_inv, LDS,
// beta, C, 2);
// Compute B = 1 + VC : 2 x 2
const double B0 = C[row1 * 2] + 1;
@ -204,10 +190,6 @@ uint32_t qmckl_woodbury_3(const uint64_t vLDS, const uint64_t vDim,
C[i * 3 + 2] += Slater_inv[i * LDS + k] * Updates[2 * LDS + k];
}
}
// double alpha = 1.0, beta = 0.0;
// cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
// Dim, 3, LDS, alpha, Slater_inv, LDS, Updates, LDS, beta,
// C, 3);
// Compute B = 1 + VC : 3 x 3
const double B0 = C[row1 * 3] + 1;
@ -322,7 +304,7 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
j += min(abs(ipiv[i] - i), 1);
det *= B[(N_updates + 1) * i];
}
if (j & 1 == 0) det = -det; // multiply det with -1 if j is even
if ((j & 1) == 0) det = -det; // multiply det with -1 if j is even
// Check if determinant of B is not too close to zero
if (fabs(det) < breakdown) {
@ -353,6 +335,104 @@ uint32_t qmckl_woodbury_k(const uint64_t vLDS,
return 0;
}
#ifdef HAVE_CUBLAS_OFFLOAD
uint32_t qmckl_woodbury_k_cublas_offload(const uint64_t vLDS,
const uint64_t vDim,
const uint64_t N_updates,
const double *__restrict __attribute__((aligned(8))) Updates,
const uint64_t *__restrict Updates_index,
const double breakdown,
double *__restrict __attribute__((aligned(8))) Slater_inv,
double *__restrict determinant) {
const uint32_t Dim = 21;
const uint32_t LDS = 24;
//cuBLAS initialization
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
fprintf(stdout, "cuBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
// Compute C = S^{-1} U : Dim x K : standard dgemm
double C[Dim * N_updates];
double alpha = 1.0, beta = 0.0;
// #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num], een_rescaled_n[0:M*N*walk_num], tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
// #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
// {
// for (int nw=0; nw < walk_num; ++nw) {
// int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
// &(een_rescaled_e[nw*(cord_num+1)]),
// LDA, af,
// &(een_rescaled_n[bf*nw]),
// LDB, 0,
// &beta,
// &(tmp_c[nw*cord_num]),
// LDC, cf, cord_num);
// }
// }
// #pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
cublasDestroy(handle);
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
Dim, N_updates, LDS,
alpha, Slater_inv, LDS, Updates, LDS,
beta, C, N_updates);
// Construct B = 1 + V C : K x K : selecting and copying row from C into B. Can maybe be off-loaded to GPU by splitting in N_updates tiles of N_updates strides, using PARALLEL and SIMD
// Construct D = V S^{-1} : K x LDS
double B[N_updates * N_updates], D[N_updates * LDS];
for (uint32_t i = 0; i < N_updates; i++) {
const uint32_t row = Updates_index[i] - 1;
for (uint32_t j = 0; j < N_updates ; j++) B[i * N_updates + j] = C[row * N_updates + j] + (i == j);
for (uint32_t j = 0; j < LDS; j++) D[i * LDS + j] = Slater_inv[row * LDS + j];
}
// Compute determinant by LU decomposition
int ipiv[N_updates];
lapack_int ret;
ret = LAPACKE_dgetrf(LAPACK_ROW_MAJOR, N_updates, N_updates, B, N_updates, ipiv);
if (ret != 0) return ret;
double det = 1.0;
int j = 0;
for (uint32_t i = 0; i < N_updates; i++) {
j += min(abs(ipiv[i] - i), 1);
det *= B[(N_updates + 1) * i];
}
if ((j & 1) == 0) det = -det; // multiply det with -1 if j is even
// Check if determinant of B is not too close to zero
if (fabs(det) < breakdown) {
return 1;
}
// Update det(Slater) if passed
if (determinant) *determinant *= det;
// Compute B^{-1} with explicit formula for K x K inversion
ret = LAPACKE_dgetri(LAPACK_ROW_MAJOR, N_updates, B, N_updates, ipiv);
if (ret != 0) return ret;
// tmp = B^{-1} D : KxLDS = KxK X KxLDS : standard dgemm
double tmp[N_updates * LDS];
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
N_updates, LDS, N_updates,
alpha, B, N_updates, D, LDS,
beta, tmp, LDS);
// Compute S^{-1} - C * tmp : Dim x LDS : standard dgemm
alpha = -1.0, beta = 1.0;
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
Dim, LDS, N_updates,
alpha, C, N_updates, tmp, LDS,
beta, Slater_inv, LDS);
return 0;
}
#endif
uint32_t qmckl_slagel_splitting(
const uint64_t vLDS, const uint64_t vDim, uint64_t N_updates,
@ -442,19 +522,26 @@ uint32_t qmckl_sherman_morrison_splitting(
double __attribute__((aligned(8))) later_updates[LDS * N_updates];
uint64_t later_index[N_updates];
uint64_t later = 0;
uint32_t rc;
// uint32_t rc;
rc = qmckl_slagel_splitting(LDS, Dim, N_updates, Updates, Updates_index,
(void) qmckl_slagel_splitting(LDS, Dim, N_updates, Updates, Updates_index,
breakdown, Slater_inv, later_updates, later_index,
&later, determinant);
// rc = qmckl_slagel_splitting(LDS, Dim, N_updates, Updates, Updates_index,
// breakdown, Slater_inv, later_updates, later_index,
// &later, determinant);
// if (rc != 0) printf("Something when catastrophically wrong in QMCKL_SLAGEL_SPLITTING\n");
if (later > 0) {
recursive_calls++;
// printf("Later > 0\n");
rc = qmckl_sherman_morrison_splitting(LDS, Dim, later, later_updates,
(void) qmckl_sherman_morrison_splitting(LDS, Dim, later, later_updates,
later_index, breakdown, Slater_inv,
determinant);
// rc = qmckl_sherman_morrison_splitting(LDS, Dim, later, later_updates,
// later_index, breakdown, Slater_inv,
// determinant);
// if (rc != 0) printf("Something when catastrophically wrong in QMCKL_SHERMAN_MORRISON_SPLITTING\n");
}
@ -508,49 +595,6 @@ uint32_t qmckl_sherman_morrison_smw32s(
return 0;
}
// if (N_updates == 6) { // Special case for 6 rank-1 updates: 2+2+2
// rc = qmckl_woodbury_2(LDS, Dim, Updates, Updates_index,
// breakdown, Slater_inv, determinant);
// if (rc != 0) { // Send the entire block to slagel_splitting
// block_fail += 1;
// uint64_t l = 0;
// rc = qmckl_slagel_splitting(LDS, Dim, 2, Updates,
// Updates_index, breakdown, Slater_inv,
// later_updates + (LDS * later),
// later_index + later, &l, determinant);
// later += l;
// }
// rc = qmckl_woodbury_2(LDS, Dim, &Updates[2*LDS], &Updates_index[2],
// breakdown, Slater_inv, determinant);
// if (rc != 0) { // Send the entire block to slagel_splitting
// block_fail += 1;
// uint64_t l = 0;
// rc = qmckl_slagel_splitting(LDS, Dim, 2, &Updates[2*LDS],
// &Updates_index[2], breakdown, Slater_inv,
// later_updates + (LDS * later),
// later_index + later, &l, determinant);
// later += l;
// }
// rc = qmckl_woodbury_2(LDS, Dim, &Updates[4*LDS], &Updates_index[4],
// breakdown, Slater_inv, determinant);
// if (rc != 0) { // Send the entire block to slagel_splitting
// block_fail += 1;
// uint64_t l = 0;
// rc = qmckl_slagel_splitting(LDS, Dim, 2, &Updates[4*LDS],
// &Updates_index[4], breakdown, Slater_inv,
// later_updates + (LDS * later),
// later_index + later, &l, determinant);
// later += l;
// }
// if (later > 0) {
// recursive_calls++;
// rc = qmckl_sherman_morrison_splitting(LDS, Dim, later, later_updates,
// later_index, breakdown, Slater_inv,
// determinant);
// }
// return 0;
// }
// And for the other cases != 4, 6
// Apply first 3*n_of_3blocks updates in n_of_3blocks blocks of 3 updates with
// Woodbury 3x3 kernel

View File

@ -103,47 +103,47 @@ printf("#-----------------------------------------------------------------------
determinant_copy = determinant;
// ### CHOOSE A KERNEL:
if (version[0] == 'a') { // Anthony
const double *Upds;
const uint64_t *Ui;
double determinant_previous;
// if (version[0] == 'a') { // Anthony
// const double *Upds;
// const uint64_t *Ui;
// double determinant_previous;
err_break = 0;
// err_break = 0;
for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] *= determinant_copy; // Multiply inv(Slater-mat) by det(Slater-mat) to get adj(Slater_mat)
// for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] *= determinant_copy; // Multiply inv(Slater-mat) by det(Slater-mat) to get adj(Slater_mat)
for (int i = 0; i < N_updates; i++) {
Upds = &Updates[i * LDS];
Ui = &Updates_index[i];
determinant_previous = determinant_copy;
// for (int i = 0; i < N_updates; i++) {
// Upds = &Updates[i * LDS];
// Ui = &Updates_index[i];
// determinant_previous = determinant_copy;
// 1. FETCH START TIME
uint64_t before = rdtsc();
// // 1. FETCH START TIME
// uint64_t before = rdtsc();
// 2. EXECUTE KERNEL AND REMEMBER EXIT STATUS
detupd(Dim, LDS, Upds, Ui, Slater_invT_copy, &determinant_copy);
// // 2. EXECUTE KERNEL AND REMEMBER EXIT STATUS
// detupd(Dim, LDS, Upds, Ui, Slater_invT_copy, &determinant_copy);
// 3. FETCH FINISH TIME
uint64_t after = rdtsc();
// // 3. FETCH FINISH TIME
// uint64_t after = rdtsc();
// 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
accumulator += (double)(after - before);
// // 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
// accumulator += (double)(after - before);
// 5. STOP APPLYING UPDATES IF BREAKDOWN DETECTED
double lambda = determinant_copy / determinant_previous; // should be id. to lambda in detupd
if (fabs(lambda) < breakdown) {
err_break = 1;
break;
}
}
if (err_break == 1) { // Divide adj(Slater-mat) by OLD det(Slater-mat) to get inv(Slater_mat) again
for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] /= determinant_previous;
} else { // Divide adj(Slater-mat) by NEW det(Slater-mat) to get inv(Slater_mat) again
for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] /= determinant_copy;
}
} else if (version[0] == 'n') { // Naive
// // 5. STOP APPLYING UPDATES IF BREAKDOWN DETECTED
// double lambda = determinant_copy / determinant_previous; // should be id. to lambda in detupd
// if (fabs(lambda) < breakdown) {
// err_break = 1;
// break;
// }
// }
// if (err_break == 1) { // Divide adj(Slater-mat) by OLD det(Slater-mat) to get inv(Slater_mat) again
// for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] /= determinant_previous;
// } else { // Divide adj(Slater-mat) by NEW det(Slater-mat) to get inv(Slater_mat) again
// for (int i = 0; i < LDS * Dim; i++) Slater_invT_copy[i] /= determinant_copy;
// }
// } else if (version[0] == 'n') { // Naive
if (version[0] == 'n') { // Naive
// 1. FETCH START TIME
uint64_t before = rdtsc();
@ -215,6 +215,21 @@ printf("#-----------------------------------------------------------------------
// 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
accumulator += (double)(after - before);
} else if (version[0] == 'c') { // Woodbury K cuBLAS
// 1. FETCH START TIME
uint64_t before = rdtsc();
// 2. EXECUTE KERNEL AND REMEMBER EXIT STATUS
err_break = qmckl_woodbury_k_cublas_offload(LDS, Dim, N_updates, Updates,
Updates_index, breakdown, Slater_invT_copy, &determinant);
// 3. FETCH FINISH TIME
uint64_t after = rdtsc();
// 4. ADD TIME DIFFERENCE TO TIME CUMMULATOR
accumulator += (double)(after - before);
} else if (version[0] == 's') { // Splitting
// 1. FETCH START TIME