1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-12-22 20:36:01 +01:00

Merge pull request #77 from justemax/gpu

Gpu
This commit is contained in:
Aurélien Delval 2022-04-08 10:49:09 +02:00 committed by GitHub
commit 54f60480fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2351,6 +2351,7 @@ integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e_doc &
const double* ee_distance_rescaled, const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e, const double* ee_distance_rescaled_deriv_e,
double* const factor_ee_deriv_e ); double* const factor_ee_deriv_e );
#+end_src #+end_src
#+begin_src c :tangle (eval h_private_func) :comments org #+begin_src c :tangle (eval h_private_func) :comments org
@ -2366,6 +2367,7 @@ integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e_doc &
double* const factor_ee_deriv_e ); double* const factor_ee_deriv_e );
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_factor_ee_deriv_e ( qmckl_exit_code qmckl_compute_factor_ee_deriv_e (
const qmckl_context context, const qmckl_context context,
@ -6225,6 +6227,18 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const tmp_c ) double* const tmp_c )
{ {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
qmckl_exit_code info; qmckl_exit_code info;
@ -6270,40 +6284,35 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf; const int64_t cf = bf;
info = QMCKL_SUCCESS;
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num]) #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c) #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
{ {
#pragma omp target teams distribute parallel for collapse(2)
for (int nw=0; nw < walk_num; ++nw) { for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){
cublasStatus_t cublasError = int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, &(een_rescaled_e[nw*(cord_num+1)]), \
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \ LDA, af, \
LDA, \
&(een_rescaled_n[bf*nw]), \ &(een_rescaled_n[bf*nw]), \
LDB, \ LDB, 0, \
&beta, \ &beta, \
&(tmp_c[cf*(i+nw*cord_num)]), \ &(tmp_c[nw*cord_num]), \
LDC); LDC, cf, cord_num);
/*
//Manage cublas ERROR //Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){ if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError); printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE; info = QMCKL_FAILURE;
return info; return info;
}else{ }else{
info = QMCKL_SUCCESS; info = QMCKL_SUCCESS;
} }
*/
}
} }
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
@ -6313,9 +6322,11 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num]) #pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info; return info;
} }
#endif #endif
#+end_src #+end_src
@ -6790,30 +6801,37 @@ qmckl_compute_dtmp_c_cublas_offload (
const double* een_rescaled_n, const double* een_rescaled_n,
double* const dtmp_c ) { double* const dtmp_c ) {
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (cord_num <= 0) { if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
} }
if (elec_num <= 0) { if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3; return QMCKL_INVALID_ARG_3;
} }
if (nucl_num <= 0) { if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5; return QMCKL_INVALID_ARG_5;
} }
qmckl_exit_code info = QMCKL_SUCCESS; qmckl_exit_code info = QMCKL_SUCCESS;
const char TransA = 'N';
const char TransB = 'N';
const double alpha = 1.0; const double alpha = 1.0;
const double beta = 0.0; const double beta = 0.0;
@ -6829,19 +6847,37 @@ qmckl_compute_dtmp_c_cublas_offload (
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = elec_num*4*nucl_num*(cord_num+1); const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
// TODO Replace with calls to cuBLAS #pragma omp target enter data map(to:een_rescaled_e_deriv_e[0:elec_num*4*elec_num*(cord_num+1)*walk_num], een_rescaled_n[0:elec_num*nucl_num*(cord_num+1)*walk_num], dtmp_c[0:elec_num*4*nucl_num*(cord_num+1)*cord_num*walk_num])
for (int64_t nw=0; nw < walk_num; ++nw) { #pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c)
for (int64_t i=0; i < cord_num; ++i) { {
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \ for (int64_t nw=0; nw < walk_num; ++nw) {
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \ //Manage CUBLAS ERRORS
LDA, \
&(een_rescaled_n[bf*nw]), \ int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
LDB, \ &(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \
beta, \ LDA, af, \
&(dtmp_c[cf*(i+nw*cord_num)]), \ &(een_rescaled_n[bf*nw]), \
LDC); LDB, 0, \
&beta, \
&(dtmp_c[(nw*cord_num)]), \
LDC, cf, cord_num);
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
return info;
}else{
info = QMCKL_SUCCESS;
}
//}
} }
} }
cudaDeviceSynchronize();
cublasDestroy(handle);
#pragma omp target exit data map(from:dtmp_c[0:cf*cord_num*walk_num])
return info; return info;
} }