1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-09-27 03:51:09 +02:00

Merge branch 'gpu' of github.com:TREX-CoE/qmckl into gpu

This commit is contained in:
Anthony Scemama 2022-04-07 13:35:08 +02:00
commit a7fac59f04

View File

@ -108,6 +108,7 @@ int main() {
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include "qmckl.h" #include "qmckl.h"
@ -116,6 +117,13 @@ int main() {
#include "qmckl_memory_private_func.h" #include "qmckl_memory_private_func.h"
#include "qmckl_jastrow_private_func.h" #include "qmckl_jastrow_private_func.h"
#include "qmckl_jastrow_private_type.h" #include "qmckl_jastrow_private_type.h"
#ifdef HAVE_CUBLAS_OFFLOAD
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#endif
#+end_src #+end_src
* Context * Context
@ -5019,6 +5027,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
ctx->jastrow.tmp_c = tmp_c; ctx->jastrow.tmp_c = tmp_c;
} }
/* Choose the correct compute function (depending on offload type) */ /* Choose the correct compute function (depending on offload type) */
#ifdef HAVE_HPC #ifdef HAVE_HPC
const bool gpu_offload = ctx->jastrow.gpu_offload; const bool gpu_offload = ctx->jastrow.gpu_offload;
@ -5068,6 +5077,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
ctx->jastrow.tmp_c); ctx->jastrow.tmp_c);
} }
ctx->jastrow.tmp_c_date = ctx->date; ctx->jastrow.tmp_c_date = ctx->date;
} }
@ -5107,6 +5117,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
ctx->jastrow.dtmp_c = dtmp_c; ctx->jastrow.dtmp_c = dtmp_c;
} }
#ifdef HAVE_HPC #ifdef HAVE_HPC
const bool gpu_offload = ctx->jastrow.gpu_offload; const bool gpu_offload = ctx->jastrow.gpu_offload;
#else #else
@ -5159,6 +5170,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
return rc; return rc;
} }
ctx->jastrow.dtmp_c_date = ctx->date; ctx->jastrow.dtmp_c_date = ctx->date;
} }
@ -5807,6 +5819,42 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (
} }
#+end_src #+end_src
#+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
#+RESULTS:
#+begin_src c :tangle (eval h_func) :comments org
/* Prototype of qmckl_compute_tmp_c.
 *
 * Arguments, in order:
 *   context         - qmckl_context handle, passed by value.
 *   cord_num        - 64-bit integer size parameter.
 *   elec_num        - 64-bit integer size parameter.
 *   nucl_num        - 64-bit integer size parameter.
 *   walk_num        - 64-bit integer size parameter.
 *   een_rescaled_e  - read-only array of doubles.
 *   een_rescaled_n  - read-only array of doubles.
 *   tmp_c           - the only writable array argument; presumably the output.
 *
 * Returns a qmckl_exit_code.
 *
 * NOTE(review): array extents are not stated by this declaration. The cuBLAS
 * variant in this file maps een_rescaled_e with
 * elec_num*elec_num*(cord_num+1)*walk_num elements and tmp_c with
 * elec_num*nucl_num*(cord_num+1)*cord_num*walk_num elements -- confirm that
 * the same extents are expected here.
 */
qmckl_exit_code qmckl_compute_tmp_c (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c );
#+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc")
#+RESULTS:
#+begin_src c :tangle (eval h_private_func) :comments org
/* Prototype of qmckl_compute_tmp_c_doc.
 *
 * The parameter list is identical to that of qmckl_compute_tmp_c:
 *   context         - qmckl_context handle, passed by value.
 *   cord_num, elec_num, nucl_num, walk_num
 *                   - 64-bit integer size parameters.
 *   een_rescaled_e  - read-only array of doubles.
 *   een_rescaled_n  - read-only array of doubles.
 *   tmp_c           - the only writable array argument; presumably the output.
 *
 * Returns a qmckl_exit_code.
 *
 * NOTE(review): the "_doc" suffix presumably marks the reference
 * implementation, as opposed to the "_hpc" variant declared alongside it --
 * confirm against the implementation.
 */
qmckl_exit_code qmckl_compute_tmp_c_doc (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c );
#+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_hpc")
#+RESULTS:
#+begin_src c :tangle (eval h_private_func) :comments org #+begin_src c :tangle (eval h_private_func) :comments org
qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context, qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
@ -6014,7 +6062,6 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code
qmckl_compute_tmp_c_cublas_offload (const qmckl_context context, qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t elec_num, const int64_t elec_num,
@ -6025,6 +6072,19 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
double* const tmp_c ) double* const tmp_c )
{ {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
@ -6041,14 +6101,6 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5;
}
qmckl_exit_code info = QMCKL_SUCCESS;
const char TransA = 'N';
const char TransB = 'N';
const double alpha = 1.0; const double alpha = 1.0;
const double beta = 0.0; const double beta = 0.0;
@ -6064,20 +6116,50 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf; const int64_t cf = bf;
for (int64_t nw=0; nw < walk_num; ++nw) { #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
for (int64_t i=0; i<cord_num; ++i){ #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, {
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), LDA,
&(een_rescaled_n[bf*nw]), LDB, beta,
&(tmp_c[cf*(i+nw*cord_num)]), LDC); for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){
//CuBlas implementation
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
&beta, \
&(tmp_c[cf*(i+nw*cord_num)]), \
LDC);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
return info;
}else{
info = QMCKL_SUCCESS;
}
} }
} }
}
cudaDeviceSynchronize();
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info; return info;
} }
#endif #endif
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_tmp_c_cublas_offload ( qmckl_exit_code qmckl_compute_tmp_c_cublas_offload (