1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-03 10:06:09 +01:00

Merge branch 'gpu' of github.com:TREX-CoE/qmckl into gpu

This commit is contained in:
Anthony Scemama 2022-04-07 13:35:08 +02:00
commit a7fac59f04

View File

@ -108,6 +108,7 @@ int main() {
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include "qmckl.h" #include "qmckl.h"
@ -116,6 +117,13 @@ int main() {
#include "qmckl_memory_private_func.h" #include "qmckl_memory_private_func.h"
#include "qmckl_jastrow_private_func.h" #include "qmckl_jastrow_private_func.h"
#include "qmckl_jastrow_private_type.h" #include "qmckl_jastrow_private_type.h"
#ifdef HAVE_CUBLAS_OFFLOAD
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#endif
#+end_src #+end_src
* Context * Context
@ -1117,7 +1125,7 @@ qmckl_exit_code qmckl_finalize_jastrow(qmckl_context context) {
#if defined(HAVE_HPC) && (defined(HAVE_CUBLAS_OFFLOAD) || defined(HAVE_OPENACC_OFFLOAD) || defined(HAVE_OPENMP_OFFLOAD)) #if defined(HAVE_HPC) && (defined(HAVE_CUBLAS_OFFLOAD) || defined(HAVE_OPENACC_OFFLOAD) || defined(HAVE_OPENMP_OFFLOAD))
ctx->jastrow.gpu_offload = true; // ctx->electron.num > 100; ctx->jastrow.gpu_offload = true; // ctx->electron.num > 100;
#endif #endif
qmckl_exit_code rc = QMCKL_SUCCESS; qmckl_exit_code rc = QMCKL_SUCCESS;
return rc; return rc;
@ -1511,7 +1519,7 @@ qmckl_exit_code qmckl_compute_asymp_jasb (
const int64_t bord_num, const int64_t bord_num,
const double* bord_vector, const double* bord_vector,
const double rescale_factor_kappa_ee, const double rescale_factor_kappa_ee,
double* const asymp_jasb ); double* const asymp_jasb );
#+end_src #+end_src
@ -1802,21 +1810,21 @@ qmckl_exit_code qmckl_compute_factor_ee (
int ipar; // can we use a smaller integer? int ipar; // can we use a smaller integer?
double x, x1, spin_fact, power_ser; double x, x1, spin_fact, power_ser;
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
} }
if (elec_num <= 0) { if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3; return QMCKL_INVALID_ARG_3;
} }
if (bord_num <= 0) { if (bord_num <= 0) {
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
for (int nw = 0; nw < walk_num; ++nw) { for (int nw = 0; nw < walk_num; ++nw) {
factor_ee[nw] = 0.0; // put init array here. factor_ee[nw] = 0.0; // put init array here.
@ -1827,9 +1835,9 @@ qmckl_exit_code qmckl_compute_factor_ee (
x1 = x; x1 = x;
power_ser = 0.0; power_ser = 0.0;
spin_fact = 1.0; spin_fact = 1.0;
ipar = 0; // index of asymp_jasb ipar = 0; // index of asymp_jasb
for (int p = 1; p < bord_num; ++p) { for (int p = 1; p < bord_num; ++p) {
x = x * x1; x = x * x1;
power_ser = power_ser + bord_vector[p + 1] * x; power_ser = power_ser + bord_vector[p + 1] * x;
} }
@ -1838,7 +1846,7 @@ qmckl_exit_code qmckl_compute_factor_ee (
spin_fact = 0.5; spin_fact = 0.5;
ipar = 1; ipar = 1;
} }
factor_ee[nw] = factor_ee[nw] + spin_fact * bord_vector[0] * \ factor_ee[nw] = factor_ee[nw] + spin_fact * bord_vector[0] * \
x1 / \ x1 / \
(1.0 + bord_vector[1] * \ (1.0 + bord_vector[1] * \
@ -1854,7 +1862,7 @@ qmckl_exit_code qmckl_compute_factor_ee (
#+end_src #+end_src
# #+CALL: generate_c_header(table=qmckl_factor_ee_args,rettyp=get_value("CRetType"),fname=get_value("Name")) # #+CALL: generate_c_header(table=qmckl_factor_ee_args,rettyp=get_value("CRetType"),fname=get_value("Name"))
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
qmckl_exit_code qmckl_compute_factor_ee ( qmckl_exit_code qmckl_compute_factor_ee (
const qmckl_context context, const qmckl_context context,
@ -1865,7 +1873,7 @@ qmckl_exit_code qmckl_compute_factor_ee (
const double* bord_vector, const double* bord_vector,
const double* ee_distance_rescaled, const double* ee_distance_rescaled,
const double* asymp_jasb, const double* asymp_jasb,
double* const factor_ee ); double* const factor_ee );
#+end_src #+end_src
@ -2177,7 +2185,7 @@ end function qmckl_compute_factor_ee_deriv_e_f
const double* ee_distance_rescaled, const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e, const double* ee_distance_rescaled_deriv_e,
const double* asymp_jasb, const double* asymp_jasb,
double* const factor_ee_deriv_e ); double* const factor_ee_deriv_e );
#+end_src #+end_src
@ -2451,7 +2459,7 @@ qmckl_exit_code qmckl_provide_factor_en(qmckl_context context)
if (rc != QMCKL_SUCCESS) { if (rc != QMCKL_SUCCESS) {
return rc; return rc;
} }
ctx->jastrow.factor_en_date = ctx->date; ctx->jastrow.factor_en_date = ctx->date;
} }
@ -2550,7 +2558,7 @@ integer function qmckl_compute_factor_en_f( &
end function qmckl_compute_factor_en_f end function qmckl_compute_factor_en_f
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_factor_en ( qmckl_exit_code qmckl_compute_factor_en (
@ -2619,7 +2627,7 @@ qmckl_exit_code qmckl_compute_factor_en (
x1 = x; x1 = x;
power_ser = 0.0; power_ser = 0.0;
for (int p = 2; p < aord_num+1; ++p) { for (int p = 2; p < aord_num+1; ++p) {
x = x * x1; x = x * x1;
power_ser = power_ser + aord_vector[(p+1)-1 + (type_nucl_vector[a]-1) * aord_num] * x; power_ser = power_ser + aord_vector[(p+1)-1 + (type_nucl_vector[a]-1) * aord_num] * x;
} }
@ -2650,7 +2658,7 @@ qmckl_exit_code qmckl_compute_factor_en (
const int64_t aord_num, const int64_t aord_num,
const double* aord_vector, const double* aord_vector,
const double* en_distance_rescaled, const double* en_distance_rescaled,
double* const factor_en ); double* const factor_en );
#+end_src #+end_src
@ -2944,7 +2952,7 @@ end function qmckl_compute_factor_en_deriv_e_f
const double* aord_vector, const double* aord_vector,
const double* en_distance_rescaled, const double* en_distance_rescaled,
const double* en_distance_rescaled_deriv_e, const double* en_distance_rescaled_deriv_e,
double* const factor_en_deriv_e ); double* const factor_en_deriv_e );
#+end_src #+end_src
@ -3337,7 +3345,7 @@ end function qmckl_compute_een_rescaled_e_doc_f
const int64_t cord_num, const int64_t cord_num,
const double rescale_factor_kappa_ee, const double rescale_factor_kappa_ee,
const double* ee_distance, const double* ee_distance,
double* const een_rescaled_e ); double* const een_rescaled_e );
#+end_src #+end_src
#+CALL: generate_c_interface(table=qmckl_factor_een_rescaled_e_args,rettyp=get_value("CRetType"),fname="qmckl_compute_een_rescaled_e_doc") #+CALL: generate_c_interface(table=qmckl_factor_een_rescaled_e_args,rettyp=get_value("CRetType"),fname="qmckl_compute_een_rescaled_e_doc")
@ -3376,13 +3384,13 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
const double rescale_factor_kappa_ee, const double rescale_factor_kappa_ee,
const double* ee_distance, const double* ee_distance,
double* const een_rescaled_e ) { double* const een_rescaled_e ) {
double *een_rescaled_e_ij; double *een_rescaled_e_ij;
double x; double x;
const int64_t elec_pairs = (elec_num * (elec_num - 1)) / 2; const int64_t elec_pairs = (elec_num * (elec_num - 1)) / 2;
const int64_t len_een_ij = elec_pairs * (cord_num + 1); const int64_t len_een_ij = elec_pairs * (cord_num + 1);
int64_t k; int64_t k;
// number of element for the een_rescaled_e_ij[N_e*(N_e-1)/2][cord+1] // number of element for the een_rescaled_e_ij[N_e*(N_e-1)/2][cord+1]
// probably in C is better [cord+1, Ne*(Ne-1)/2] // probably in C is better [cord+1, Ne*(Ne-1)/2]
//elec_pairs = (elec_num * (elec_num - 1)) / 2; //elec_pairs = (elec_num * (elec_num - 1)) / 2;
@ -3391,7 +3399,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
@ -3406,8 +3414,8 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
} }
// Prepare table of exponentiated distances raised to appropriate power // Prepare table of exponentiated distances raised to appropriate power
// init // init
for (int kk = 0; kk < walk_num*(cord_num+1)*elec_num*elec_num; ++kk) { for (int kk = 0; kk < walk_num*(cord_num+1)*elec_num*elec_num; ++kk) {
een_rescaled_e[kk]= 0.0; een_rescaled_e[kk]= 0.0;
} }
@ -3425,14 +3433,14 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
*/ */
for (int nw = 0; nw < walk_num; ++nw) { for (int nw = 0; nw < walk_num; ++nw) {
for (int kk = 0; kk < len_een_ij; ++kk) { for (int kk = 0; kk < len_een_ij; ++kk) {
// this array initialized at 0 except een_rescaled_e_ij(:, 1) = 1.0d0 // this array initialized at 0 except een_rescaled_e_ij(:, 1) = 1.0d0
// and the arrangement of indices is [cord_num+1, ne*(ne-1)/2] // and the arrangement of indices is [cord_num+1, ne*(ne-1)/2]
een_rescaled_e_ij[kk]= ( kk < (elec_pairs) ? 1.0 : 0.0 ); een_rescaled_e_ij[kk]= ( kk < (elec_pairs) ? 1.0 : 0.0 );
} }
k = 0; k = 0;
for (int i = 0; i < elec_num; ++i) { for (int i = 0; i < elec_num; ++i) {
for (int j = 0; j < i; ++j) { for (int j = 0; j < i; ++j) {
// een_rescaled_e_ij(k, 2) = dexp(-rescale_factor_kappa_ee * ee_distance(i, j, nw)); // een_rescaled_e_ij(k, 2) = dexp(-rescale_factor_kappa_ee * ee_distance(i, j, nw));
@ -3450,7 +3458,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
een_rescaled_e_ij[k + elec_pairs]; een_rescaled_e_ij[k + elec_pairs];
} }
} }
// prepare the actual een table // prepare the actual een table
for (int i = 0; i < elec_num; ++i){ for (int i = 0; i < elec_num; ++i){
@ -3458,7 +3466,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
een_rescaled_e[j + i*elec_num + 0 + nw*(cord_num+1)*elec_num*elec_num] = 1.0; een_rescaled_e[j + i*elec_num + 0 + nw*(cord_num+1)*elec_num*elec_num] = 1.0;
} }
} }
// Up to here it should work. // Up to here it should work.
for ( int l = 1; l < (cord_num+1); ++l) { for ( int l = 1; l < (cord_num+1); ++l) {
k = 0; k = 0;
@ -3481,7 +3489,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
} }
free(een_rescaled_e_ij); free(een_rescaled_e_ij);
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
} }
#+end_src #+end_src
@ -3520,7 +3528,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_e_hpc (
const double* ee_distance, const double* ee_distance,
double* const een_rescaled_e ); double* const een_rescaled_e );
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_een_rescaled_e ( qmckl_exit_code qmckl_compute_een_rescaled_e (
const qmckl_context context, const qmckl_context context,
@ -3848,7 +3856,7 @@ end function qmckl_compute_factor_een_rescaled_e_deriv_e_f
const double* coord_new, const double* coord_new,
const double* ee_distance, const double* ee_distance,
const double* een_rescaled_e, const double* een_rescaled_e,
double* const een_rescaled_e_deriv_e ); double* const een_rescaled_e_deriv_e );
#+end_src #+end_src
@ -4207,7 +4215,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_n (
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
@ -4268,7 +4276,7 @@ qmckl_exit_code qmckl_compute_een_rescaled_n (
const int64_t cord_num, const int64_t cord_num,
const double rescale_factor_kappa_en, const double rescale_factor_kappa_en,
const double* en_distance, const double* en_distance,
double* const een_rescaled_n ); double* const een_rescaled_n );
#+end_src #+end_src
*** Test *** Test
@ -4577,7 +4585,7 @@ end function qmckl_compute_factor_een_rescaled_n_deriv_e_f
const double* coord, const double* coord,
const double* en_distance, const double* en_distance,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const een_rescaled_n_deriv_e ); double* const een_rescaled_n_deriv_e );
#+end_src #+end_src
#+CALL: generate_c_interface(table=qmckl_compute_factor_een_rescaled_n_deriv_e_args,rettyp=get_value("CRetType"),fname=get_value("Name")) #+CALL: generate_c_interface(table=qmckl_compute_factor_een_rescaled_n_deriv_e_args,rettyp=get_value("CRetType"),fname=get_value("Name"))
@ -5019,14 +5027,15 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
ctx->jastrow.tmp_c = tmp_c; ctx->jastrow.tmp_c = tmp_c;
} }
/* Choose the correct compute function (depending on offload type) */ /* Choose the correct compute function (depending on offload type) */
#ifdef HAVE_HPC #ifdef HAVE_HPC
const bool gpu_offload = ctx->jastrow.gpu_offload; const bool gpu_offload = ctx->jastrow.gpu_offload;
#else #else
const bool gpu_offload = false; const bool gpu_offload = false;
#endif #endif
if (gpu_offload) { if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_tmp_c_cublas_offload(context, rc = qmckl_compute_tmp_c_cublas_offload(context,
ctx->jastrow.cord_num, ctx->jastrow.cord_num,
@ -5067,7 +5076,8 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
ctx->jastrow.een_rescaled_n, ctx->jastrow.een_rescaled_n,
ctx->jastrow.tmp_c); ctx->jastrow.tmp_c);
} }
ctx->jastrow.tmp_c_date = ctx->date; ctx->jastrow.tmp_c_date = ctx->date;
} }
@ -5107,13 +5117,14 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
ctx->jastrow.dtmp_c = dtmp_c; ctx->jastrow.dtmp_c = dtmp_c;
} }
#ifdef HAVE_HPC #ifdef HAVE_HPC
const bool gpu_offload = ctx->jastrow.gpu_offload; const bool gpu_offload = ctx->jastrow.gpu_offload;
#else #else
const bool gpu_offload = false; const bool gpu_offload = false;
#endif #endif
if (gpu_offload) { if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_dtmp_c_cublas_offload(context, rc = qmckl_compute_dtmp_c_cublas_offload(context,
ctx->jastrow.cord_num, ctx->jastrow.cord_num,
@ -5159,6 +5170,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
return rc; return rc;
} }
ctx->jastrow.dtmp_c_date = ctx->date; ctx->jastrow.dtmp_c_date = ctx->date;
} }
@ -5228,10 +5240,10 @@ qmckl_exit_code qmckl_compute_dim_cord_vect (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
int64_t* const dim_cord_vect){ int64_t* const dim_cord_vect){
int lmax; int lmax;
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
@ -5241,7 +5253,7 @@ qmckl_exit_code qmckl_compute_dim_cord_vect (
} }
*dim_cord_vect = 0; *dim_cord_vect = 0;
for (int p=2; p <= cord_num; ++p){ for (int p=2; p <= cord_num; ++p){
for (int k=p-1; k >= 0; --k) { for (int k=p-1; k >= 0; --k) {
if (k != 0) { if (k != 0) {
@ -5255,7 +5267,7 @@ qmckl_exit_code qmckl_compute_dim_cord_vect (
} }
} }
} }
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
} }
#+end_src #+end_src
@ -5266,7 +5278,7 @@ qmckl_exit_code qmckl_compute_dim_cord_vect (
qmckl_exit_code qmckl_compute_dim_cord_vect ( qmckl_exit_code qmckl_compute_dim_cord_vect (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
int64_t* const dim_cord_vect ); int64_t* const dim_cord_vect );
#+end_src #+end_src
@ -5531,15 +5543,15 @@ qmckl_exit_code qmckl_compute_lkpm_combined_index (
int kk, lmax, m; int kk, lmax, m;
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (cord_num <= 0) { if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
} }
if (dim_cord_vect <= 0) { if (dim_cord_vect <= 0) {
return QMCKL_INVALID_ARG_3; return QMCKL_INVALID_ARG_3;
} }
@ -5576,7 +5588,7 @@ qmckl_exit_code qmckl_compute_lkpm_combined_index (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t dim_cord_vect, const int64_t dim_cord_vect,
int64_t* const lkpm_combined_index ); int64_t* const lkpm_combined_index );
#+end_src #+end_src
@ -5617,7 +5629,7 @@ qmckl_exit_code qmckl_compute_tmp_c (const qmckl_context context,
#endif #endif
} }
#+end_src #+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c") # #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
@ -5629,7 +5641,7 @@ qmckl_exit_code qmckl_compute_tmp_c (const qmckl_context context,
const int64_t walk_num, const int64_t walk_num,
const double* een_rescaled_e, const double* een_rescaled_e,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const tmp_c ); double* const tmp_c );
#+end_src #+end_src
#+begin_src f90 :comments org :tangle (eval f) :noweb yes #+begin_src f90 :comments org :tangle (eval f) :noweb yes
@ -5709,11 +5721,11 @@ qmckl_exit_code qmckl_compute_tmp_c_doc (
const int64_t walk_num, const int64_t walk_num,
const double* een_rescaled_e, const double* een_rescaled_e,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const tmp_c ); double* const tmp_c );
#+end_src #+end_src
#+CALL: generate_c_interface(table=qmckl_factor_tmp_c_args,rettyp=get_value("FRetType"),fname="qmckl_compute_tmp_c_doc") #+CALL: generate_c_interface(table=qmckl_factor_tmp_c_args,rettyp=get_value("FRetType"),fname="qmckl_compute_tmp_c_doc")
#+RESULTS: #+RESULTS:
#+begin_src f90 :tangle (eval f) :comments org :exports none #+begin_src f90 :tangle (eval f) :comments org :exports none
integer(c_int32_t) function qmckl_compute_tmp_c_doc & integer(c_int32_t) function qmckl_compute_tmp_c_doc &
@ -5758,19 +5770,19 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (
if (cord_num <= 0) { if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
} }
if (elec_num <= 0) { if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3; return QMCKL_INVALID_ARG_3;
} }
if (nucl_num <= 0) { if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5; return QMCKL_INVALID_ARG_5;
} }
qmckl_exit_code info = QMCKL_SUCCESS; qmckl_exit_code info = QMCKL_SUCCESS;
@ -5807,6 +5819,42 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (
} }
#+end_src #+end_src
#+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
#+RESULTS:
#+begin_src c :tangle (eval h_func) :comments org
/* Prototype of qmckl_compute_tmp_c.
 *
 * Takes a context, four int64_t extents (cord_num, elec_num, nucl_num,
 * walk_num) and two read-only double arrays; tmp_c is the only writable
 * argument. Returns a qmckl_exit_code.
 *
 * NOTE(review): the offload variants of this kernel in this file size the
 * arrays as
 *   tmp_c          : elec_num*nucl_num*(cord_num+1)*cord_num*walk_num
 *   een_rescaled_e : walk_num*(cord_num+1)*elec_num*elec_num
 *   een_rescaled_n : walk_num*(cord_num+1)*nucl_num*elec_num
 * presumably the same extents apply here -- confirm against the definition.
 */
qmckl_exit_code qmckl_compute_tmp_c (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c );
#+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc")
#+RESULTS:
#+begin_src c :tangle (eval h_private_func) :comments org
/* Prototype of qmckl_compute_tmp_c_doc.
 *
 * Same argument list as qmckl_compute_tmp_c: a context, four int64_t
 * extents (cord_num, elec_num, nucl_num, walk_num), two read-only double
 * arrays, and the writable output tmp_c. Returns a qmckl_exit_code.
 *
 * NOTE(review): the "_doc" suffix presumably marks the reference
 * implementation that the Fortran interface block in this file binds to --
 * confirm against the definition.
 */
qmckl_exit_code qmckl_compute_tmp_c_doc (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c );
#+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_hpc")
#+RESULTS:
#+begin_src c :tangle (eval h_private_func) :comments org #+begin_src c :tangle (eval h_private_func) :comments org
qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context, qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
@ -5815,7 +5863,7 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context,
const int64_t walk_num, const int64_t walk_num,
const double* een_rescaled_e, const double* een_rescaled_e,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const tmp_c ); double* const tmp_c );
#+end_src #+end_src
**** OpenACC offload :noexport: **** OpenACC offload :noexport:
@ -5865,7 +5913,7 @@ qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
const int64_t size_tmp_c = elec_num*nucl_num*(cord_num+1)*cord_num*walk_num; const int64_t size_tmp_c = elec_num*nucl_num*(cord_num+1)*cord_num*walk_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
#pragma acc parallel copyout(tmp_c [0:size_tmp_c]) copyin(een_rescaled_e[0:size_e], een_rescaled_n[0:size_n]) #pragma acc parallel copyout(tmp_c [0:size_tmp_c]) copyin(een_rescaled_e[0:size_e], een_rescaled_n[0:size_n])
{ {
@ -5877,7 +5925,7 @@ qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
for (int j=0; j<cord_num+1; j++) { for (int j=0; j<cord_num+1; j++) {
for (int k=0; k<nucl_num; k++) { for (int k=0; k<nucl_num; k++) {
for (int l=0; l<elec_num; l++) { for (int l=0; l<elec_num; l++) {
// Single reduction // Single reduction
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.; tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.;
for (int m=0; m<elec_num; m++) { for (int m=0; m<elec_num; m++) {
@ -5886,7 +5934,7 @@ qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] * een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] *
een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n]; een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n];
} }
} }
} }
} }
@ -5961,7 +6009,7 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
const int64_t size_tmp_c = elec_num*nucl_num*(cord_num+1)*cord_num*walk_num; const int64_t size_tmp_c = elec_num*nucl_num*(cord_num+1)*cord_num*walk_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
#pragma omp target teams distribute parallel for collapse(5) \ #pragma omp target teams distribute parallel for collapse(5) \
@ -5975,7 +6023,7 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
for (int j=0; j<cord_num+1; j++) { for (int j=0; j<cord_num+1; j++) {
for (int k=0; k<nucl_num; k++) { for (int k=0; k<nucl_num; k++) {
for (int l=0; l<elec_num; l++) { for (int l=0; l<elec_num; l++) {
// Single reduction // Single reduction
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.; tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.;
for (int m=0; m<elec_num; m++) { for (int m=0; m<elec_num; m++) {
@ -5984,7 +6032,7 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] * een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] *
een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n]; een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n];
} }
} }
} }
} }
@ -6012,9 +6060,8 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
**** cuBLAS offload :noexport: **** cuBLAS offload :noexport:
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code
qmckl_compute_tmp_c_cublas_offload (const qmckl_context context, qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t elec_num, const int64_t elec_num,
@ -6025,6 +6072,19 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
double* const tmp_c ) double* const tmp_c )
{ {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
@ -6041,14 +6101,6 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5;
}
qmckl_exit_code info = QMCKL_SUCCESS;
const char TransA = 'N';
const char TransB = 'N';
const double alpha = 1.0; const double alpha = 1.0;
const double beta = 0.0; const double beta = 0.0;
@ -6064,19 +6116,49 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf; const int64_t cf = bf;
for (int64_t nw=0; nw < walk_num; ++nw) { #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
for (int64_t i=0; i<cord_num; ++i){ #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, {
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), LDA,
&(een_rescaled_n[bf*nw]), LDB, beta,
&(tmp_c[cf*(i+nw*cord_num)]), LDC); for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){
//CuBlas implementation
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
&beta, \
&(tmp_c[cf*(i+nw*cord_num)]), \
LDC);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
return info;
}else{
info = QMCKL_SUCCESS;
}
} }
} }
}
cudaDeviceSynchronize();
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info; return info;
} }
#endif #endif
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
@ -6209,7 +6291,7 @@ integer function qmckl_compute_dtmp_c_doc_f( &
dtmp_c(1,1,1,0,i,nw),LDC) dtmp_c(1,1,1,0,i,nw),LDC)
end do end do
end do end do
end function qmckl_compute_dtmp_c_doc_f end function qmckl_compute_dtmp_c_doc_f
#+end_src #+end_src
@ -6253,7 +6335,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_doc (
#+end_src #+end_src
**** CPU :noexport: **** CPU :noexport:
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_exit_code
qmckl_compute_dtmp_c_hpc (const qmckl_context context, qmckl_compute_dtmp_c_hpc (const qmckl_context context,
@ -6268,7 +6350,7 @@ qmckl_compute_dtmp_c_hpc (const qmckl_context context,
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
} }
if (cord_num <= 0) { if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2; return QMCKL_INVALID_ARG_2;
@ -6280,11 +6362,11 @@ qmckl_compute_dtmp_c_hpc (const qmckl_context context,
if (nucl_num <= 0) { if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) { if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5; return QMCKL_INVALID_ARG_5;
} }
qmckl_exit_code info = QMCKL_SUCCESS; qmckl_exit_code info = QMCKL_SUCCESS;
@ -6332,7 +6414,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_hpc (
const double* een_rescaled_n, const double* een_rescaled_n,
double* const dtmp_c ); double* const dtmp_c );
#+end_src #+end_src
**** OpenACC offload :noexport: **** OpenACC offload :noexport:
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
@ -6382,7 +6464,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_acc_offload (
const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num; const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num;
#pragma acc parallel copyout(dtmp_c [0:size_dtmp_c]) copyin(een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n]) #pragma acc parallel copyout(dtmp_c [0:size_dtmp_c]) copyin(een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n])
@ -6396,7 +6478,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_acc_offload (
for(int k=0; k<nucl_num; k++) { for(int k=0; k<nucl_num; k++) {
for(int l=0; l<4; l++) { for(int l=0; l<4; l++) {
for(int m=0; m<elec_num; m++) { for(int m=0; m<elec_num; m++) {
// Single reduction // Single reduction
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0.; dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0.;
for(int n=0; n<elec_num; n++){ for(int n=0; n<elec_num; n++){
@ -6481,7 +6563,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload (
const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num; const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num;
@ -6497,7 +6579,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload (
for(int k=0; k<nucl_num; k++) { for(int k=0; k<nucl_num; k++) {
for(int l=0; l<4; l++) { for(int l=0; l<4; l++) {
for(int m=0; m<elec_num; m++) { for(int m=0; m<elec_num; m++) {
// Single reduction // Single reduction
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0.; dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0.;
for(int n=0; n<elec_num; n++){ for(int n=0; n<elec_num; n++){
@ -6701,7 +6783,7 @@ rc = qmckl_get_jastrow_dtmp_c(context, &(dtmp_c[0][0][0][0][0][0]));
printf("%e\n%e\n", tmp_c[0][0][1][0][0], 2.7083473948352403); printf("%e\n%e\n", tmp_c[0][0][1][0][0], 2.7083473948352403);
assert(fabs(tmp_c[0][0][1][0][0] - 2.7083473948352403) < 1e-12); assert(fabs(tmp_c[0][0][1][0][0] - 2.7083473948352403) < 1e-12);
printf("%e\n%e\n", tmp_c[0][1][0][0][0],0.237440520852232); printf("%e\n%e\n", tmp_c[0][1][0][0][0],0.237440520852232);
assert(fabs(dtmp_c[0][1][0][0][0][0] - 0.237440520852232) < 1e-12); assert(fabs(dtmp_c[0][1][0][0][0][0] - 0.237440520852232) < 1e-12);
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
#+end_src #+end_src