diff --git a/org/qmckl_jastrow.org b/org/qmckl_jastrow.org index c4f2e28..269d3fd 100644 --- a/org/qmckl_jastrow.org +++ b/org/qmckl_jastrow.org @@ -108,6 +108,12 @@ int main() { #include #include + +#include +#include "cublas_v2.h" + + + #include #include "qmckl.h" @@ -5019,6 +5025,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context) ctx->jastrow.tmp_c = tmp_c; } + /* Choose the correct compute function (depending on offload type) */ #ifdef HAVE_HPC const bool gpu_offload = ctx->jastrow.gpu_offload; @@ -5068,6 +5075,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context) ctx->jastrow.tmp_c); } + ctx->jastrow.tmp_c_date = ctx->date; } @@ -5107,6 +5115,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context) ctx->jastrow.dtmp_c = dtmp_c; } + #ifdef HAVE_HPC const bool gpu_offload = ctx->jastrow.gpu_offload; #else @@ -5159,6 +5168,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context) return rc; } + ctx->jastrow.dtmp_c_date = ctx->date; } @@ -5807,6 +5817,152 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc ( } #+end_src + +#+begin_src c :comments org :tangle (eval c) :noweb yes +qmckl_exit_code qmckl_compute_tmp_c_cuBlas ( + const qmckl_context context, + const int64_t cord_num, + const int64_t elec_num, + const int64_t nucl_num, + const int64_t walk_num, + const double* een_rescaled_e, + const double* een_rescaled_n, + double* const tmp_c ) { + + qmckl_exit_code info; + + //Initialisation of cublas + + cublasHandle_t handle; + if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) + { + fprintf(stdout, "CUBLAS initialization failed!\n"); + exit(EXIT_FAILURE); + } + + + + if (context == QMCKL_NULL_CONTEXT) { + return QMCKL_INVALID_CONTEXT; + } + + if (cord_num <= 0) { + return QMCKL_INVALID_ARG_2; + } + + if (elec_num <= 0) { + return QMCKL_INVALID_ARG_3; + } + + if (nucl_num <= 0) { + return QMCKL_INVALID_ARG_4; + } + + const double alpha = 1.0; + const double beta = 0.0; + + const int64_t M = elec_num; + const int64_t N = nucl_num*(cord_num + 1); + const int64_t K = elec_num; + + const int64_t LDA = elec_num; + const int64_t LDB = elec_num; + const int64_t LDC = elec_num; + + const int64_t af = elec_num*elec_num; + const int64_t bf = elec_num*nucl_num*(cord_num+1); + const int64_t cf = bf; + + #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num]) + #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c) + { + + + for (int nw=0; nw < walk_num; ++nw) { + for (int i=0; i