1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-12-22 12:23:56 +01:00

Add cublas batched

This commit is contained in:
hoffer 2022-04-07 18:44:59 +02:00
parent cba6477e4a
commit 69b9e0fb89

View File

@ -110,7 +110,7 @@ int main() {
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include <cublas_v2.h>
@ -5032,10 +5032,10 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
#else
const bool gpu_offload = false;
#endif
if (gpu_offload) {
if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_tmp_c_cublas_offload(context,
rc = qmckl_compute_tmp_c_cuBlas(context,
ctx->jastrow.cord_num,
ctx->electron.num,
ctx->nucleus.num,
@ -5074,7 +5074,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
ctx->jastrow.een_rescaled_n,
ctx->jastrow.tmp_c);
}
ctx->jastrow.tmp_c_date = ctx->date;
}
@ -5121,10 +5121,10 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
#else
const bool gpu_offload = false;
#endif
if (gpu_offload) {
if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_dtmp_c_cublas_offload(context,
rc = qmckl_compute_dtmp_c_cuBlas(context,
ctx->jastrow.cord_num,
ctx->electron.num,
ctx->nucleus.num,
@ -5829,6 +5829,93 @@ qmckl_exit_code qmckl_compute_tmp_c_cuBlas (
const double* een_rescaled_n,
double* const tmp_c ) {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT;
}
if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2;
}
if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3;
}
if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4;
}
const double alpha = 1.0;
const double beta = 0.0;
const int64_t M = elec_num;
const int64_t N = nucl_num*(cord_num + 1);
const int64_t K = elec_num;
const int64_t LDA = elec_num;
const int64_t LDB = elec_num;
const int64_t LDC = elec_num;
const int64_t af = elec_num*elec_num;
const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf;
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
{
for (int nw=0; nw < walk_num; ++nw) {
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
&(een_rescaled_e[nw*(cord_num+1)]), \
LDA, af, \
&(een_rescaled_n[bf*nw]), \
LDB, 0, \
&beta, \
&(tmp_c[nw*cord_num]), \
LDC, cf, cord_num);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
}else{
info = QMCKL_SUCCESS;
}
}
}
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info;
}
#+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_tmp_c_cuBlas_batched (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c ) {
qmckl_exit_code info;
//Initialisation of cublas
@ -6708,40 +6795,47 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload (
#+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e_deriv_e,
const double* een_rescaled_n,
double* const dtmp_c ) {
qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e_deriv_e,
const double* een_rescaled_n,
double* const dtmp_c )
{
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT;
return QMCKL_INVALID_CONTEXT;
}
if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2;
return QMCKL_INVALID_ARG_2;
}
if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3;
return QMCKL_INVALID_ARG_3;
}
if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4;
return QMCKL_INVALID_ARG_4;
}
if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5;
return QMCKL_INVALID_ARG_5;
}
qmckl_exit_code info = QMCKL_SUCCESS;
const char TransA = 'N';
const char TransB = 'N';
const double alpha = 1.0;
const double beta = 0.0;
@ -6757,19 +6851,48 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
// TODO Replace with calls to cuBLAS
for (int64_t nw=0; nw < walk_num; ++nw) {
for (int64_t i=0; i < cord_num; ++i) {
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
beta, \
&(dtmp_c[cf*(i+nw*cord_num)]), \
LDC);
#pragma omp target enter data map(to:een_rescaled_e_deriv_e[0:elec_num*4*elec_num*(cord_num+1)*walk_num], een_rescaled_n[0:elec_num*nucl_num*(cord_num+1)*walk_num], dtmp_c[0:elec_num*4*nucl_num*(cord_num+1)*cord_num*walk_num])
#pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c)
{
for (int64_t nw=0; nw < walk_num; ++nw) {
/*
for (int64_t i=0; i < cord_num; ++i) {
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
&beta, \
&(dtmp_c[cf*(i+nw*cord_num)]), \
LDC);
,*/
//Manage CUBLAS ERRORS
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
&(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \
LDA, af, \
&(een_rescaled_n[bf*nw]), \
LDB, 0, \
&beta, \
&(dtmp_c[(nw*cord_num)]), \
LDC, cf, cord_num);
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
return info;
}else{
info = QMCKL_SUCCESS;
}
//}
}
}
cudaDeviceSynchronize();
cublasDestroy(handle);
#pragma omp target exit data map(from:dtmp_c[0:cf*cord_num*walk_num])
return info;
}
@ -6779,7 +6902,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
#+RESULTS:
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
#ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,