From 69b9e0fb894b861b4cf9ae79013bac705f8bba65 Mon Sep 17 00:00:00 2001 From: hoffer Date: Thu, 7 Apr 2022 18:44:59 +0200 Subject: [PATCH] Add cublas batched --- org/qmckl_jastrow.org | 195 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 159 insertions(+), 36 deletions(-) diff --git a/org/qmckl_jastrow.org b/org/qmckl_jastrow.org index 269d3fd..e13498e 100644 --- a/org/qmckl_jastrow.org +++ b/org/qmckl_jastrow.org @@ -110,7 +110,7 @@ int main() { #include -#include "cublas_v2.h" +#include @@ -5032,10 +5032,10 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context) #else const bool gpu_offload = false; #endif - - if (gpu_offload) { + + if (gpu_offload) { #ifdef HAVE_CUBLAS_OFFLOAD - rc = qmckl_compute_tmp_c_cublas_offload(context, + rc = qmckl_compute_tmp_c_cuBlas(context, ctx->jastrow.cord_num, ctx->electron.num, ctx->nucleus.num, @@ -5074,7 +5074,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context) ctx->jastrow.een_rescaled_n, ctx->jastrow.tmp_c); } - + ctx->jastrow.tmp_c_date = ctx->date; } @@ -5121,10 +5121,10 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context) #else const bool gpu_offload = false; #endif - - if (gpu_offload) { + + if (gpu_offload) { #ifdef HAVE_CUBLAS_OFFLOAD - rc = qmckl_compute_dtmp_c_cublas_offload(context, + rc = qmckl_compute_dtmp_c_cuBlas(context, ctx->jastrow.cord_num, ctx->electron.num, ctx->nucleus.num, @@ -5829,6 +5829,93 @@ qmckl_exit_code qmckl_compute_tmp_c_cuBlas ( const double* een_rescaled_n, double* const tmp_c ) { + qmckl_exit_code info; + + //Initialisation of cublas + + cublasHandle_t handle; + if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) + { + fprintf(stdout, "CUBLAS initialization failed!\n"); + exit(EXIT_FAILURE); + } + + + + if (context == QMCKL_NULL_CONTEXT) { + return QMCKL_INVALID_CONTEXT; + } + + if (cord_num <= 0) { + return QMCKL_INVALID_ARG_2; + } + + if (elec_num <= 0) { + return QMCKL_INVALID_ARG_3; + } + + if (nucl_num <= 0) { + return QMCKL_INVALID_ARG_4; + } + + const double alpha = 1.0; + const double beta = 0.0; + + const int64_t M = elec_num; + const int64_t N = nucl_num*(cord_num + 1); + const int64_t K = elec_num; + + const int64_t LDA = elec_num; + const int64_t LDB = elec_num; + const int64_t LDC = elec_num; + + const int64_t af = elec_num*elec_num; + const int64_t bf = elec_num*nucl_num*(cord_num+1); + const int64_t cf = bf; + + #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num]) + #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c) + { + for (int nw=0; nw < walk_num; ++nw) { + + int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, + &(een_rescaled_e[nw*(cord_num+1)]), \ + LDA, af, \ + &(een_rescaled_n[bf*nw]), \ + LDB, 0, \ + &beta, \ + &(tmp_c[nw*cord_num]), \ + LDC, cf, cord_num); + + //Manage cublas ERROR + if(cublasError != CUBLAS_STATUS_SUCCESS){ + printf("CUBLAS ERROR %d", cublasError); + info = QMCKL_FAILURE; + }else{ + info = QMCKL_SUCCESS; + } + } + } + cublasDestroy(handle); + #pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num]) + + return info; + } +#+end_src + + + +#+begin_src c :comments org :tangle (eval c) :noweb yes +qmckl_exit_code qmckl_compute_tmp_c_cuBlas_batched ( + const qmckl_context context, + const int64_t cord_num, + const int64_t elec_num, + const int64_t nucl_num, + const int64_t walk_num, + const double* een_rescaled_e, + const double* een_rescaled_n, + double* const tmp_c ) { + qmckl_exit_code info; //Initialisation of cublas @@ -6708,40 +6795,47 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload ( #+begin_src c :comments org :tangle (eval c) :noweb yes #ifdef HAVE_CUBLAS_OFFLOAD -qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload ( - const qmckl_context context, - const int64_t cord_num, - const int64_t elec_num, - const int64_t nucl_num, - const int64_t walk_num, - const double* een_rescaled_e_deriv_e, - const double* een_rescaled_n, - double* const dtmp_c ) { +qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (const qmckl_context context, + const int64_t cord_num, + const int64_t elec_num, + const int64_t nucl_num, + const int64_t walk_num, + const double* een_rescaled_e_deriv_e, + const double* een_rescaled_n, + double* const dtmp_c ) +{ + + cublasHandle_t handle; + if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) + { + fprintf(stdout, "CUBLAS initialization failed!\n"); + exit(EXIT_FAILURE); + } + + if (context == QMCKL_NULL_CONTEXT) { - return QMCKL_INVALID_CONTEXT; + return QMCKL_INVALID_CONTEXT; } if (cord_num <= 0) { - return QMCKL_INVALID_ARG_2; + return QMCKL_INVALID_ARG_2; } if (elec_num <= 0) { - return QMCKL_INVALID_ARG_3; + return QMCKL_INVALID_ARG_3; } if (nucl_num <= 0) { - return QMCKL_INVALID_ARG_4; + return QMCKL_INVALID_ARG_4; } if (walk_num <= 0) { - return QMCKL_INVALID_ARG_5; + return QMCKL_INVALID_ARG_5; } qmckl_exit_code info = QMCKL_SUCCESS; - const char TransA = 'N'; - const char TransB = 'N'; const double alpha = 1.0; const double beta = 0.0; @@ -6757,19 +6851,48 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload ( const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t cf = elec_num*4*nucl_num*(cord_num+1); - // TODO Replace with calls to cuBLAS - for (int64_t nw=0; nw < walk_num; ++nw) { - for (int64_t i=0; i < cord_num; ++i) { - info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \ - &(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \ - LDA, \ - &(een_rescaled_n[bf*nw]), \ - LDB, \ - beta, \ - &(dtmp_c[cf*(i+nw*cord_num)]), \ - LDC); +#pragma omp target enter data map(to:een_rescaled_e_deriv_e[0:elec_num*4*elec_num*(cord_num+1)*walk_num], een_rescaled_n[0:elec_num*nucl_num*(cord_num+1)*walk_num], dtmp_c[0:elec_num*4*nucl_num*(cord_num+1)*cord_num*walk_num]) +#pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c) + { + for (int64_t nw=0; nw < walk_num; ++nw) { + /* + for (int64_t i=0; i < cord_num; ++i) { + int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \ + &(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \ + LDA, \ + &(een_rescaled_n[bf*nw]), \ + LDB, \ + &beta, \ + &(dtmp_c[cf*(i+nw*cord_num)]), \ + LDC); + ,*/ + //Manage CUBLAS ERRORS + + int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \ + &(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \ + LDA, af, \ + &(een_rescaled_n[bf*nw]), \ + LDB, 0, \ + &beta, \ + &(dtmp_c[(nw*cord_num)]), \ + LDC, cf, cord_num); + + + if(cublasError != CUBLAS_STATUS_SUCCESS){ + printf("CUBLAS ERROR %d", cublasError); + info = QMCKL_FAILURE; + return info; + }else{ + info = QMCKL_SUCCESS; + } + + //} } } + cudaDeviceSynchronize(); + cublasDestroy(handle); + +#pragma omp target exit data map(from:dtmp_c[0:cf*cord_num*walk_num]) return info; } @@ -6779,7 +6902,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload ( #+RESULTS: #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #ifdef HAVE_CUBLAS_OFFLOAD - qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload ( + qmckl_exit_code qmckl_compute_dtmp_c_cuBlas ( const qmckl_context context, const int64_t cord_num, const int64_t elec_num,