mirror of
https://github.com/TREX-CoE/qmckl.git
synced 2025-01-10 13:08:29 +01:00
commit
54f60480fa
@ -2351,6 +2351,7 @@ integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e_doc &
|
|||||||
const double* ee_distance_rescaled,
|
const double* ee_distance_rescaled,
|
||||||
const double* ee_distance_rescaled_deriv_e,
|
const double* ee_distance_rescaled_deriv_e,
|
||||||
double* const factor_ee_deriv_e );
|
double* const factor_ee_deriv_e );
|
||||||
|
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
#+begin_src c :tangle (eval h_private_func) :comments org
|
#+begin_src c :tangle (eval h_private_func) :comments org
|
||||||
@ -2366,6 +2367,7 @@ integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e_doc &
|
|||||||
double* const factor_ee_deriv_e );
|
double* const factor_ee_deriv_e );
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
|
|
||||||
#+begin_src c :comments org :tangle (eval c) :noweb yes
|
#+begin_src c :comments org :tangle (eval c) :noweb yes
|
||||||
qmckl_exit_code qmckl_compute_factor_ee_deriv_e (
|
qmckl_exit_code qmckl_compute_factor_ee_deriv_e (
|
||||||
const qmckl_context context,
|
const qmckl_context context,
|
||||||
@ -6225,6 +6227,18 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
|
|||||||
const double* een_rescaled_n,
|
const double* een_rescaled_n,
|
||||||
double* const tmp_c )
|
double* const tmp_c )
|
||||||
{
|
{
|
||||||
|
qmckl_exit_code info;
|
||||||
|
|
||||||
|
//Initialisation of cublas
|
||||||
|
|
||||||
|
cublasHandle_t handle;
|
||||||
|
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||||
|
{
|
||||||
|
fprintf(stdout, "CUBLAS initialization failed!\n");
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
qmckl_exit_code info;
|
qmckl_exit_code info;
|
||||||
|
|
||||||
@ -6270,40 +6284,35 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
|
|||||||
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
||||||
const int64_t cf = bf;
|
const int64_t cf = bf;
|
||||||
|
|
||||||
info = QMCKL_SUCCESS;
|
|
||||||
|
|
||||||
|
|
||||||
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
|
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
|
||||||
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
|
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
|
||||||
{
|
{
|
||||||
|
|
||||||
#pragma omp target teams distribute parallel for collapse(2)
|
|
||||||
for (int nw=0; nw < walk_num; ++nw) {
|
for (int nw=0; nw < walk_num; ++nw) {
|
||||||
for (int i=0; i<cord_num; ++i){
|
|
||||||
|
|
||||||
cublasStatus_t cublasError =
|
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
|
||||||
cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
|
&(een_rescaled_e[nw*(cord_num+1)]), \
|
||||||
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
|
LDA, af, \
|
||||||
LDA, \
|
|
||||||
&(een_rescaled_n[bf*nw]), \
|
&(een_rescaled_n[bf*nw]), \
|
||||||
LDB, \
|
LDB, 0, \
|
||||||
&beta, \
|
&beta, \
|
||||||
&(tmp_c[cf*(i+nw*cord_num)]), \
|
&(tmp_c[nw*cord_num]), \
|
||||||
LDC);
|
LDC, cf, cord_num);
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
//Manage cublas ERROR
|
//Manage cublas ERROR
|
||||||
if(cublasError != CUBLAS_STATUS_SUCCESS){
|
if(cublasError != CUBLAS_STATUS_SUCCESS){
|
||||||
printf("CUBLAS ERROR %d", cublasError);
|
printf("CUBLAS ERROR %d", cublasError);
|
||||||
info = QMCKL_FAILURE;
|
info = QMCKL_FAILURE;
|
||||||
|
|
||||||
return info;
|
return info;
|
||||||
}else{
|
}else{
|
||||||
info = QMCKL_SUCCESS;
|
info = QMCKL_SUCCESS;
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cudaDeviceSynchronize();
|
cudaDeviceSynchronize();
|
||||||
@ -6313,9 +6322,11 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
|
|||||||
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
|
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
|
|
||||||
@ -6790,6 +6801,15 @@ qmckl_compute_dtmp_c_cublas_offload (
|
|||||||
const double* een_rescaled_n,
|
const double* een_rescaled_n,
|
||||||
double* const dtmp_c ) {
|
double* const dtmp_c ) {
|
||||||
|
|
||||||
|
cublasHandle_t handle;
|
||||||
|
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||||
|
{
|
||||||
|
fprintf(stdout, "CUBLAS initialization failed!\n");
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if (context == QMCKL_NULL_CONTEXT) {
|
if (context == QMCKL_NULL_CONTEXT) {
|
||||||
return QMCKL_INVALID_CONTEXT;
|
return QMCKL_INVALID_CONTEXT;
|
||||||
}
|
}
|
||||||
@ -6812,8 +6832,6 @@ qmckl_compute_dtmp_c_cublas_offload (
|
|||||||
|
|
||||||
qmckl_exit_code info = QMCKL_SUCCESS;
|
qmckl_exit_code info = QMCKL_SUCCESS;
|
||||||
|
|
||||||
const char TransA = 'N';
|
|
||||||
const char TransB = 'N';
|
|
||||||
const double alpha = 1.0;
|
const double alpha = 1.0;
|
||||||
const double beta = 0.0;
|
const double beta = 0.0;
|
||||||
|
|
||||||
@ -6829,19 +6847,37 @@ qmckl_compute_dtmp_c_cublas_offload (
|
|||||||
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
||||||
const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
|
const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
|
||||||
|
|
||||||
// TODO Replace with calls to cuBLAS
|
#pragma omp target enter data map(to:een_rescaled_e_deriv_e[0:elec_num*4*elec_num*(cord_num+1)*walk_num], een_rescaled_n[0:elec_num*nucl_num*(cord_num+1)*walk_num], dtmp_c[0:elec_num*4*nucl_num*(cord_num+1)*cord_num*walk_num])
|
||||||
|
#pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c)
|
||||||
|
{
|
||||||
for (int64_t nw=0; nw < walk_num; ++nw) {
|
for (int64_t nw=0; nw < walk_num; ++nw) {
|
||||||
for (int64_t i=0; i < cord_num; ++i) {
|
//Manage CUBLAS ERRORS
|
||||||
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \
|
|
||||||
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \
|
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
|
||||||
LDA, \
|
&(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \
|
||||||
|
LDA, af, \
|
||||||
&(een_rescaled_n[bf*nw]), \
|
&(een_rescaled_n[bf*nw]), \
|
||||||
LDB, \
|
LDB, 0, \
|
||||||
beta, \
|
&beta, \
|
||||||
&(dtmp_c[cf*(i+nw*cord_num)]), \
|
&(dtmp_c[(nw*cord_num)]), \
|
||||||
LDC);
|
LDC, cf, cord_num);
|
||||||
|
|
||||||
|
|
||||||
|
if(cublasError != CUBLAS_STATUS_SUCCESS){
|
||||||
|
printf("CUBLAS ERROR %d", cublasError);
|
||||||
|
info = QMCKL_FAILURE;
|
||||||
|
return info;
|
||||||
|
}else{
|
||||||
|
info = QMCKL_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
cudaDeviceSynchronize();
|
||||||
|
cublasDestroy(handle);
|
||||||
|
|
||||||
|
#pragma omp target exit data map(from:dtmp_c[0:cf*cord_num*walk_num])
|
||||||
|
|
||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user