1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-08 20:33:40 +01:00

Add cublas batch Dgemm

This commit is contained in:
hoffer 2022-04-08 10:44:48 +02:00
parent 69b9e0fb89
commit d4f0ccee3b

View File

@ -109,11 +109,6 @@ int main() {
#include <math.h> #include <math.h>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <stdio.h> #include <stdio.h>
#include "qmckl.h" #include "qmckl.h"
@ -122,6 +117,13 @@ int main() {
#include "qmckl_memory_private_func.h" #include "qmckl_memory_private_func.h"
#include "qmckl_jastrow_private_func.h" #include "qmckl_jastrow_private_func.h"
#include "qmckl_jastrow_private_type.h" #include "qmckl_jastrow_private_type.h"
#ifdef HAVE_CUBLAS_OFFLOAD
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#endif
#+end_src #+end_src
* Context * Context
@ -2030,7 +2032,6 @@ qmckl_exit_code qmckl_provide_factor_ee_deriv_e(qmckl_context context)
ctx->jastrow.bord_vector, ctx->jastrow.bord_vector,
ctx->electron.ee_distance_rescaled, ctx->electron.ee_distance_rescaled,
ctx->electron.ee_distance_rescaled_deriv_e, ctx->electron.ee_distance_rescaled_deriv_e,
ctx->jastrow.asymp_jasb,
ctx->jastrow.factor_ee_deriv_e); ctx->jastrow.factor_ee_deriv_e);
if (rc != QMCKL_SUCCESS) { if (rc != QMCKL_SUCCESS) {
return rc; return rc;
@ -2061,14 +2062,13 @@ qmckl_exit_code qmckl_provide_factor_ee_deriv_e(qmckl_context context)
| ~bord_vector~ | ~double[bord_num+1]~ | in | List of coefficients | | ~bord_vector~ | ~double[bord_num+1]~ | in | List of coefficients |
| ~ee_distance_rescaled~ | ~double[walk_num][elec_num][elec_num]~ | in | Electron-electron distances | | ~ee_distance_rescaled~ | ~double[walk_num][elec_num][elec_num]~ | in | Electron-electron distances |
| ~ee_distance_rescaled_deriv_e~ | ~double[walk_num][4][elec_num][elec_num]~ | in | Electron-electron distances | | ~ee_distance_rescaled_deriv_e~ | ~double[walk_num][4][elec_num][elec_num]~ | in | Electron-electron distances |
| ~asymp_jasb~ | ~double[2]~ | in | Electron-electron distances |
| ~factor_ee_deriv_e~ | ~double[walk_num][4][elec_num]~ | out | Electron-electron distances | | ~factor_ee_deriv_e~ | ~double[walk_num][4][elec_num]~ | out | Electron-electron distances |
#+begin_src f90 :comments org :tangle (eval f) :noweb yes #+begin_src f90 :comments org :tangle (eval f) :noweb yes
integer function qmckl_compute_factor_ee_deriv_e_f( & integer function qmckl_compute_factor_ee_deriv_e_doc_f( &
context, walk_num, elec_num, up_num, bord_num, & context, walk_num, elec_num, up_num, bord_num, &
bord_vector, ee_distance_rescaled, ee_distance_rescaled_deriv_e, & bord_vector, ee_distance_rescaled, ee_distance_rescaled_deriv_e, &
asymp_jasb, factor_ee_deriv_e) & factor_ee_deriv_e) &
result(info) result(info)
use qmckl use qmckl
implicit none implicit none
@ -2077,10 +2077,9 @@ integer function qmckl_compute_factor_ee_deriv_e_f( &
double precision , intent(in) :: bord_vector(bord_num + 1) double precision , intent(in) :: bord_vector(bord_num + 1)
double precision , intent(in) :: ee_distance_rescaled(elec_num, elec_num,walk_num) double precision , intent(in) :: ee_distance_rescaled(elec_num, elec_num,walk_num)
double precision , intent(in) :: ee_distance_rescaled_deriv_e(4,elec_num, elec_num,walk_num) !TODO double precision , intent(in) :: ee_distance_rescaled_deriv_e(4,elec_num, elec_num,walk_num) !TODO
double precision , intent(in) :: asymp_jasb(2)
double precision , intent(out) :: factor_ee_deriv_e(elec_num,4,walk_num) double precision , intent(out) :: factor_ee_deriv_e(elec_num,4,walk_num)
integer*8 :: i, j, p, ipar, nw, ii integer*8 :: i, j, p, nw, ii
double precision :: x, spin_fact, y double precision :: x, spin_fact, y
double precision :: den, invden, invden2, invden3, xinv double precision :: den, invden, invden2, invden3, xinv
double precision :: lap1, lap2, lap3, third double precision :: lap1, lap2, lap3, third
@ -2124,7 +2123,6 @@ integer function qmckl_compute_factor_ee_deriv_e_f( &
invden2 = invden * invden invden2 = invden * invden
invden3 = invden2 * invden invden3 = invden2 * invden
xinv = 1.0d0 / (x + 1.0d-18) xinv = 1.0d0 / (x + 1.0d-18)
ipar = 1
dx(1) = ee_distance_rescaled_deriv_e(1, i, j, nw) dx(1) = ee_distance_rescaled_deriv_e(1, i, j, nw)
dx(2) = ee_distance_rescaled_deriv_e(2, i, j, nw) dx(2) = ee_distance_rescaled_deriv_e(2, i, j, nw)
@ -2166,7 +2164,120 @@ integer function qmckl_compute_factor_ee_deriv_e_f( &
end do end do
end do end do
end function qmckl_compute_factor_ee_deriv_e_f end function qmckl_compute_factor_ee_deriv_e_doc_f
#+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_factor_ee_deriv_e_hpc(
const qmckl_context context,
const int64_t walk_num,
const int64_t elec_num,
const int64_t up_num,
const int64_t bord_num,
const double* bord_vector,
const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e,
double* const factor_ee_deriv_e ) {
int64_t ii;
double pow_ser_g[3];
double dx[4];
double x, spin_fact, y;
double den, invden, invden2, invden3, xinv;
double lap1, lap2, lap3, third;
if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT;
}
if (walk_num <= 0) {
return QMCKL_INVALID_ARG_2;
}
if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3;
}
if (bord_num <= 0) {
return QMCKL_INVALID_ARG_4;
}
for (int nw = 0; nw < walk_num; ++nw) {
for (int ii = 0; ii < 4; ++ii) {
for (int j = 0; j < elec_num; ++j) {
factor_ee_deriv_e[j + ii * elec_num + nw * elec_num * 4] = 0.0;
}
}
}
third = 1.0 / 3.0;
for (int nw = 0; nw < walk_num; ++nw) {
for (int i = 0; i < elec_num; ++i) {
for (int j = 0; j < elec_num; ++j) {
x = ee_distance_rescaled[j + i * elec_num + nw * elec_num * elec_num];
if (fabs(x) < 1.0e-18) continue;
for (int ii = 0; ii < 3; ++ii){
pow_ser_g[ii] = 0.0;
}
spin_fact = 1.0;
den = 1.0 + bord_vector[1] * x;
invden = 1.0 / den;
invden2 = invden * invden;
invden3 = invden2 * invden;
xinv = 1.0 / (x + 1.0e-18);
dx[0] = ee_distance_rescaled_deriv_e[0 \
+ j * 4 + i * 4 * elec_num \
+ nw * 4 * elec_num * elec_num];
dx[1] = ee_distance_rescaled_deriv_e[1 \
+ j * 4 + i * 4 * elec_num \
+ nw * 4 * elec_num * elec_num];
dx[2] = ee_distance_rescaled_deriv_e[2 \
+ j * 4 + i * 4 * elec_num \
+ nw * 4 * elec_num * elec_num];
dx[3] = ee_distance_rescaled_deriv_e[3 \
+ j * 4 + i * 4 * elec_num \
+ nw * 4 * elec_num * elec_num];
if((i <= (up_num-1) && j <= (up_num-1) ) || (i > (up_num-1) && j > (up_num-1))) {
spin_fact = 0.5;
}
lap1 = 0.0;
lap2 = 0.0;
lap3 = 0.0;
for (int ii = 0; ii < 3; ++ii) {
x = ee_distance_rescaled[j + i * elec_num + nw * elec_num * elec_num];
if (fabs(x) < 1.0e-18) continue;
for (int p = 2; p < bord_num+1; ++p) {
y = p * bord_vector[(p-1) + 1] * x;
pow_ser_g[ii] = pow_ser_g[ii] + y * dx[ii];
lap1 = lap1 + (p - 1) * y * xinv * dx[ii] * dx[ii];
lap2 = lap2 + y;
x = x * ee_distance_rescaled[j + i * elec_num + nw * elec_num * elec_num];
}
lap3 = lap3 - 2.0 * bord_vector[1] * dx[ii] * dx[ii];
factor_ee_deriv_e[i + ii * elec_num + nw * elec_num * 4 ] += \
+ spin_fact * bord_vector[0] * dx[ii] * invden2 \
+ pow_ser_g[ii] ;
}
ii = 3;
lap2 = lap2 * dx[ii] * third;
lap3 = lap3 + den * dx[ii];
lap3 = lap3 * (spin_fact * bord_vector[0] * invden3);
factor_ee_deriv_e[i + ii*elec_num + nw * elec_num * 4] += lap1 + lap2 + lap3;
}
}
}
return QMCKL_SUCCESS;
}
#+end_src #+end_src
# #+CALL: generate_c_header(table=qmckl_factor_ee_deriv_e_args,rettyp=get_value("CRetType"),fname=get_value("Name")) # #+CALL: generate_c_header(table=qmckl_factor_ee_deriv_e_args,rettyp=get_value("CRetType"),fname=get_value("Name"))
@ -2182,17 +2293,16 @@ end function qmckl_compute_factor_ee_deriv_e_f
const double* bord_vector, const double* bord_vector,
const double* ee_distance_rescaled, const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e, const double* ee_distance_rescaled_deriv_e,
const double* asymp_jasb,
double* const factor_ee_deriv_e ); double* const factor_ee_deriv_e );
#+end_src #+end_src
#+CALL: generate_c_interface(table=qmckl_factor_ee_deriv_e_args,rettyp=get_value("CRetType"),fname=get_value("Name")) #+CALL: generate_c_interface(table=qmckl_factor_ee_deriv_e_args,rettyp=get_value("CRetType"),fname="qmckl_compute_factor_ee_deriv_e_doc")
#+RESULTS: #+RESULTS:
#+begin_src f90 :tangle (eval f) :comments org :exports none #+begin_src f90 :tangle (eval f) :comments org :exports none
integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e & integer(c_int32_t) function qmckl_compute_factor_ee_deriv_e_doc &
(context, & (context, &
walk_num, & walk_num, &
elec_num, & elec_num, &
up_num, & up_num, &
@ -2200,7 +2310,6 @@ end function qmckl_compute_factor_ee_deriv_e_f
bord_vector, & bord_vector, &
ee_distance_rescaled, & ee_distance_rescaled, &
ee_distance_rescaled_deriv_e, & ee_distance_rescaled_deriv_e, &
asymp_jasb, &
factor_ee_deriv_e) & factor_ee_deriv_e) &
bind(C) result(info) bind(C) result(info)
@ -2215,12 +2324,11 @@ end function qmckl_compute_factor_ee_deriv_e_f
real (c_double ) , intent(in) :: bord_vector(bord_num+1) real (c_double ) , intent(in) :: bord_vector(bord_num+1)
real (c_double ) , intent(in) :: ee_distance_rescaled(elec_num,elec_num,walk_num) real (c_double ) , intent(in) :: ee_distance_rescaled(elec_num,elec_num,walk_num)
real (c_double ) , intent(in) :: ee_distance_rescaled_deriv_e(elec_num,elec_num,4,walk_num) real (c_double ) , intent(in) :: ee_distance_rescaled_deriv_e(elec_num,elec_num,4,walk_num)
real (c_double ) , intent(in) :: asymp_jasb(2)
real (c_double ) , intent(out) :: factor_ee_deriv_e(elec_num,4,walk_num) real (c_double ) , intent(out) :: factor_ee_deriv_e(elec_num,4,walk_num)
integer(c_int32_t), external :: qmckl_compute_factor_ee_deriv_e_f integer(c_int32_t), external :: qmckl_compute_factor_ee_deriv_e_doc_f
info = qmckl_compute_factor_ee_deriv_e_f & info = qmckl_compute_factor_ee_deriv_e_doc_f &
(context, & (context, &
walk_num, & walk_num, &
elec_num, & elec_num, &
up_num, & up_num, &
@ -2228,12 +2336,61 @@ end function qmckl_compute_factor_ee_deriv_e_f
bord_vector, & bord_vector, &
ee_distance_rescaled, & ee_distance_rescaled, &
ee_distance_rescaled_deriv_e, & ee_distance_rescaled_deriv_e, &
asymp_jasb, &
factor_ee_deriv_e) factor_ee_deriv_e)
end function qmckl_compute_factor_ee_deriv_e end function qmckl_compute_factor_ee_deriv_e_doc
#+end_src #+end_src
#+begin_src c :tangle (eval h_private_func) :comments org
qmckl_exit_code qmckl_compute_factor_ee_deriv_e_hpc (
const qmckl_context context,
const int64_t walk_num,
const int64_t elec_num,
const int64_t up_num,
const int64_t bord_num,
const double* bord_vector,
const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e,
double* const factor_ee_deriv_e );
#+end_src
#+begin_src c :tangle (eval h_private_func) :comments org
qmckl_exit_code qmckl_compute_factor_ee_deriv_e_doc (
const qmckl_context context,
const int64_t walk_num,
const int64_t elec_num,
const int64_t up_num,
const int64_t bord_num,
const double* bord_vector,
const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e,
double* const factor_ee_deriv_e );
#+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_factor_ee_deriv_e (
const qmckl_context context,
const int64_t walk_num,
const int64_t elec_num,
const int64_t up_num,
const int64_t bord_num,
const double* bord_vector,
const double* ee_distance_rescaled,
const double* ee_distance_rescaled_deriv_e,
double* const factor_ee_deriv_e ) {
#ifdef HAVE_HPC
return qmckl_compute_factor_ee_deriv_e_hpc(context, walk_num, elec_num, up_num, bord_num, bord_vector, ee_distance_rescaled, ee_distance_rescaled_deriv_e, factor_ee_deriv_e );
#else
return qmckl_compute_factor_ee_deriv_e_doc(context, walk_num, elec_num, up_num, bord_num, bord_vector, ee_distance_rescaled, ee_distance_rescaled_deriv_e, factor_ee_deriv_e );
#endif
}
#+end_src
*** Test *** Test
#+begin_src python :results output :exports none :noweb yes #+begin_src python :results output :exports none :noweb yes
import numpy as np import numpy as np
@ -2351,7 +2508,6 @@ assert(fabs(factor_ee_deriv_e[0][0][0]-0.16364894652107934) < 1.e-12);
assert(fabs(factor_ee_deriv_e[0][1][0]+0.6927548119830084 ) < 1.e-12); assert(fabs(factor_ee_deriv_e[0][1][0]+0.6927548119830084 ) < 1.e-12);
assert(fabs(factor_ee_deriv_e[0][2][0]-0.073267755223968 ) < 1.e-12); assert(fabs(factor_ee_deriv_e[0][2][0]-0.073267755223968 ) < 1.e-12);
assert(fabs(factor_ee_deriv_e[0][3][0]-1.5111672803213185 ) < 1.e-12); assert(fabs(factor_ee_deriv_e[0][3][0]-1.5111672803213185 ) < 1.e-12);
#+end_src #+end_src
** Electron-nucleus component \(f_{en}\) ** Electron-nucleus component \(f_{en}\)
@ -5035,7 +5191,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
if (gpu_offload) { if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_tmp_c_cuBlas(context, rc = qmckl_compute_tmp_c_cublas_offload(context,
ctx->jastrow.cord_num, ctx->jastrow.cord_num,
ctx->electron.num, ctx->electron.num,
ctx->nucleus.num, ctx->nucleus.num,
@ -5124,7 +5280,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
if (gpu_offload) { if (gpu_offload) {
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
rc = qmckl_compute_dtmp_c_cuBlas(context, rc = qmckl_compute_dtmp_c_cublas_offload(context,
ctx->jastrow.cord_num, ctx->jastrow.cord_num,
ctx->electron.num, ctx->electron.num,
ctx->nucleus.num, ctx->nucleus.num,
@ -5818,191 +5974,6 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (
#+end_src #+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_tmp_c_cuBlas (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c ) {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT;
}
if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2;
}
if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3;
}
if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4;
}
const double alpha = 1.0;
const double beta = 0.0;
const int64_t M = elec_num;
const int64_t N = nucl_num*(cord_num + 1);
const int64_t K = elec_num;
const int64_t LDA = elec_num;
const int64_t LDB = elec_num;
const int64_t LDC = elec_num;
const int64_t af = elec_num*elec_num;
const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf;
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
{
for (int nw=0; nw < walk_num; ++nw) {
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
&(een_rescaled_e[nw*(cord_num+1)]), \
LDA, af, \
&(een_rescaled_n[bf*nw]), \
LDB, 0, \
&beta, \
&(tmp_c[nw*cord_num]), \
LDC, cf, cord_num);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
}else{
info = QMCKL_SUCCESS;
}
}
}
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info;
}
#+end_src
#+begin_src c :comments org :tangle (eval c) :noweb yes
qmckl_exit_code qmckl_compute_tmp_c_cuBlas_batched (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c ) {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT;
}
if (cord_num <= 0) {
return QMCKL_INVALID_ARG_2;
}
if (elec_num <= 0) {
return QMCKL_INVALID_ARG_3;
}
if (nucl_num <= 0) {
return QMCKL_INVALID_ARG_4;
}
const double alpha = 1.0;
const double beta = 0.0;
const int64_t M = elec_num;
const int64_t N = nucl_num*(cord_num + 1);
const int64_t K = elec_num;
const int64_t LDA = elec_num;
const int64_t LDB = elec_num;
const int64_t LDC = elec_num;
const int64_t af = elec_num*elec_num;
const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf;
#pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
{
for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){
//CuBlas implementation
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
&beta, \
&(tmp_c[cf*(i+nw*cord_num)]), \
LDC);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
return info;
}else{
info = QMCKL_SUCCESS;
}
}
}
}
cudaDeviceSynchronize();
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info;
}
#+end_src
#+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c") #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
@ -6019,18 +5990,6 @@ qmckl_exit_code qmckl_compute_tmp_c (
double* const tmp_c ); double* const tmp_c );
#+end_src #+end_src
#+begin_src c :tangle (eval h_func) :comments org
qmckl_exit_code qmckl_compute_tmp_c_cuBlas (
const qmckl_context context,
const int64_t cord_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t walk_num,
const double* een_rescaled_e,
const double* een_rescaled_n,
double* const tmp_c );
#+end_src
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc") # #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc")
#+RESULTS: #+RESULTS:
@ -6065,14 +6024,15 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (const qmckl_context context,
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_OPENACC_OFFLOAD #ifdef HAVE_OPENACC_OFFLOAD
qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context, qmckl_exit_code
const int64_t cord_num, qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
const int64_t elec_num, const int64_t cord_num,
const int64_t nucl_num, const int64_t elec_num,
const int64_t walk_num, const int64_t nucl_num,
const double* een_rescaled_e, const int64_t walk_num,
const double* een_rescaled_n, const double* een_rescaled_e,
double* const tmp_c ) const double* een_rescaled_n,
double* const tmp_c )
{ {
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
@ -6110,12 +6070,8 @@ qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
#pragma acc parallel copyout(tmp_c [0:size_tmp_c]) copyin(een_rescaled_e[0:size_e], een_rescaled_n[0:size_n]) #pragma acc parallel copyout(tmp_c [0:size_tmp_c]) copyin(een_rescaled_e[0:size_e], een_rescaled_n[0:size_n])
{ {
#pragma acc loop independent gang worker vector
for (int64_t i=0 ; i<size_tmp_c ; ++i)
tmp_c[i] = 0.;
#pragma acc loop independent gang worker vector collapse(5) #pragma acc loop independent gang worker vector collapse(5)
for (int nw=0; nw < walk_num; ++nw) { for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){ for (int i=0; i<cord_num; ++i){
@ -6126,20 +6082,19 @@ qmckl_exit_code qmckl_compute_tmp_c_acc_offload (const qmckl_context context,
for (int l=0; l<elec_num; l++) { for (int l=0; l<elec_num; l++) {
// Single reduction // Single reduction
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0; tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.;
for (int m=0; m<elec_num; m++) { for (int m=0; m<elec_num; m++) {
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] =
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] + tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] +
een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] * een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] *
een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n]; een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n];
}
} }
} }
} }
} }
} }
} }
}
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
} }
@ -6210,13 +6165,11 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
#pragma omp parallel copyout(tmp_c [0:size_tmp_c]) copyin(een_rescaled_e[0:size_e], een_rescaled_n[0:size_n])
{
#pragma omp loop independent gang worker vector
for (int64_t i=0 ; i<size_tmp_c ; ++i)
tmp_c[i] = 0.;
#pragma omp loop independent gang worker vector collapse(5) // WARNING This implementation seems unomptimized
#pragma omp target map(from:tmp_c[0:size_tmp_c]) map(to:een_rescaled_e[0:size_e], een_rescaled_n[0:size_n])
{
#pragma omp teams distribute parallel for collapse(5)
for (int nw=0; nw < walk_num; ++nw) { for (int nw=0; nw < walk_num; ++nw) {
for (int i=0; i<cord_num; ++i){ for (int i=0; i<cord_num; ++i){
@ -6226,14 +6179,13 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
for (int l=0; l<elec_num; l++) { for (int l=0; l<elec_num; l++) {
// Single reduction // Single reduction
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0; tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = 0.;
for (int m=0; m<elec_num; m++) { for (int m=0; m<elec_num; m++) {
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] = tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] =
tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] + tmp_c[l + k*stride_k_c + j*stride_j_c + i*stride_i_c + nw*stride_nw_c] +
een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] * een_rescaled_e[l + m*stride_m_e + i*stride_i_e + nw*stride_nw_e] *
een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n]; een_rescaled_n[m + k*stride_k_n + j*stride_j_n + nw*stride_nw_n];
} }
} }
} }
} }
@ -6262,7 +6214,7 @@ qmckl_compute_tmp_c_omp_offload (const qmckl_context context,
**** cuBLAS offload :noexport: **** cuBLAS offload :noexport:
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_exit_code
qmckl_compute_tmp_c_cublas_offload (const qmckl_context context, qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
@ -6274,6 +6226,18 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const double* een_rescaled_n, const double* een_rescaled_n,
double* const tmp_c ) double* const tmp_c )
{ {
qmckl_exit_code info;
//Initialisation of cublas
cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stdout, "CUBLAS initialization failed!\n");
exit(EXIT_FAILURE);
}
if (context == QMCKL_NULL_CONTEXT) { if (context == QMCKL_NULL_CONTEXT) {
return QMCKL_INVALID_CONTEXT; return QMCKL_INVALID_CONTEXT;
@ -6291,14 +6255,6 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
return QMCKL_INVALID_ARG_4; return QMCKL_INVALID_ARG_4;
} }
if (walk_num <= 0) {
return QMCKL_INVALID_ARG_5;
}
qmckl_exit_code info = QMCKL_SUCCESS;
const char TransA = 'N';
const char TransB = 'N';
const double alpha = 1.0; const double alpha = 1.0;
const double beta = 0.0; const double beta = 0.0;
@ -6314,23 +6270,45 @@ qmckl_compute_tmp_c_cublas_offload (const qmckl_context context,
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = bf; const int64_t cf = bf;
for (int64_t nw=0; nw < walk_num; ++nw) { #pragma omp target enter data map(to:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*walk_num],tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
for (int64_t i=0; i<cord_num; ++i){ #pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, {
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), LDA, for (int nw=0; nw < walk_num; ++nw) {
&(een_rescaled_n[bf*nw]), LDB, beta,
&(tmp_c[cf*(i+nw*cord_num)]), LDC); int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
} &(een_rescaled_e[nw*(cord_num+1)]), \
LDA, af, \
&(een_rescaled_n[bf*nw]), \
LDB, 0, \
&beta, \
&(tmp_c[nw*cord_num]), \
LDC, cf, cord_num);
//Manage cublas ERROR
if(cublasError != CUBLAS_STATUS_SUCCESS){
printf("CUBLAS ERROR %d", cublasError);
info = QMCKL_FAILURE;
}else{
info = QMCKL_SUCCESS;
}
} }
}
cublasDestroy(handle);
#pragma omp target exit data map(from:tmp_c[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
return info; return info;
} }
#endif #endif
#+end_src
#+end_src
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_tmp_c_cublas_offload ( qmckl_exit_code
qmckl_compute_tmp_c_cublas_offload (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t elec_num, const int64_t elec_num,
@ -6555,8 +6533,6 @@ qmckl_compute_dtmp_c_hpc (const qmckl_context context,
const int64_t bf = elec_num*nucl_num*(cord_num+1); const int64_t bf = elec_num*nucl_num*(cord_num+1);
const int64_t cf = elec_num*4*nucl_num*(cord_num+1); const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
printf("COUCOU\n");
#ifdef HAVE_OPENMP #ifdef HAVE_OPENMP
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
#endif #endif
@ -6589,7 +6565,8 @@ qmckl_exit_code qmckl_compute_dtmp_c_hpc (
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_OPENACC_OFFLOAD #ifdef HAVE_OPENACC_OFFLOAD
qmckl_exit_code qmckl_compute_dtmp_c_acc_offload ( qmckl_exit_code
qmckl_compute_dtmp_c_acc_offload (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t elec_num, const int64_t elec_num,
@ -6632,41 +6609,36 @@ qmckl_exit_code qmckl_compute_dtmp_c_acc_offload (
const int64_t stride_j_n = stride_k_n * nucl_num; const int64_t stride_j_n = stride_k_n * nucl_num;
const int64_t stride_nw_n = stride_j_n * (cord_num+1); const int64_t stride_nw_n = stride_j_n * (cord_num+1);
const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num; const int64_t size_dtmp_c = walk_num*cord_num*(cord_num+1)*nucl_num*4*elec_num;
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num;
#pragma acc parallel copyout(dtmp_c [0:size_dtmp_c]) copyin(een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n]) #pragma acc parallel copyout(dtmp_c [0:size_dtmp_c]) copyin(een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n])
{ {
#pragma acc loop independent gang worker vector #pragma acc loop independent gang worker vector collapse(6)
for (int64_t i=0 ; i<size_dtmp_c ; ++i) for (int nw=0; nw < walk_num; nw++) {
dtmp_c[i] = 0.; for (int i=0; i < cord_num; i++) {
#pragma loop independent gang worker vector collapse(6) // Single DGEMM
for (int nw=0; nw < walk_num; nw++) { for(int j=0; j<cord_num+1; j++) {
for (int i=0; i < cord_num; i++) { for(int k=0; k<nucl_num; k++) {
for(int l=0; l<4; l++) {
for(int m=0; m<elec_num; m++) {
// Single DGEMM // Single reduction
for(int j=0; j<cord_num+1; j++) { dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0.;
for(int k=0; k<nucl_num; k++) { for(int n=0; n<elec_num; n++){
for(int l=0; l<4; l++) {
for(int m=0; m<elec_num; m++) {
// Single reduction
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0;
for(int n=0; n<elec_num; n++){
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] =
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] + dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] +
een_rescaled_e_deriv_e[m + l * stride_l_e + n * stride_n_e + i * stride_i_e + nw * stride_nw_e] * een_rescaled_e_deriv_e[m + l * stride_l_e + n * stride_n_e + i * stride_i_e + nw * stride_nw_e] *
een_rescaled_n[n + k * stride_k_n + j * stride_j_n + nw * stride_nw_n]; een_rescaled_n[n + k * stride_k_n + j * stride_j_n + nw * stride_nw_n];
}
}
} }
} }
} }
} }
} }
}
}
} }
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
@ -6740,36 +6712,34 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload (
const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num; const int64_t size_n = walk_num*(cord_num+1)*nucl_num*elec_num;
const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num; const int64_t size_e = walk_num*(cord_num+1)*elec_num*4*elec_num;
#pragma omp parallel copyout(dtmp_c [0:size_dtmp_c]) copyin(een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n]) // WARNING This implementation seems unomptimized
#pragma omp target map(from:dtmp_c[0:size_dtmp_c]) map(to:een_rescaled_e_deriv_e[0:size_e], een_rescaled_n[0:size_n])
{ {
#pragma omp target
for (int64_t i=0 ; i<size_dtmp_c ; ++i)
dtmp_c[i] = 0.;
#pragma loop independent gang worker vector collapse(6) #pragma omp teams distribute parallel for collapse(6)
for (int nw=0; nw < walk_num; nw++) { for (int nw=0; nw < walk_num; nw++) {
for (int i=0; i < cord_num; i++) { for (int i=0; i < cord_num; i++) {
// Single DGEMM // Single DGEMM
for(int j=0; j<cord_num+1; j++) { for(int j=0; j<cord_num+1; j++) {
for(int k=0; k<nucl_num; k++) { for(int k=0; k<nucl_num; k++) {
for(int l=0; l<4; l++) { for(int l=0; l<4; l++) {
for(int m=0; m<elec_num; m++) { for(int m=0; m<elec_num; m++) {
// Single reduction // Single reduction
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0; dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = 0;
for(int n=0; n<elec_num; n++){ for(int n=0; n<elec_num; n++){
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] = dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] =
dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] + dtmp_c[m + l * stride_l_d + k * stride_k_d + j * stride_j_d + i * stride_i_d + nw * stride_nw_d] +
een_rescaled_e_deriv_e[m + l * stride_l_e + n * stride_n_e + i * stride_i_e + nw * stride_nw_e] * een_rescaled_e_deriv_e[m + l * stride_l_e + n * stride_n_e + i * stride_i_e + nw * stride_nw_e] *
een_rescaled_n[n + k * stride_k_n + j * stride_j_n + nw * stride_nw_n]; een_rescaled_n[n + k * stride_k_n + j * stride_j_n + nw * stride_nw_n];
}
}
} }
} }
} }
} }
} }
}
}
} }
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
@ -6795,15 +6765,16 @@ qmckl_exit_code qmckl_compute_dtmp_c_omp_offload (
#+begin_src c :comments org :tangle (eval c) :noweb yes #+begin_src c :comments org :tangle (eval c) :noweb yes
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (const qmckl_context context, qmckl_exit_code
const int64_t cord_num, qmckl_compute_dtmp_c_cublas_offload (
const int64_t elec_num, const qmckl_context context,
const int64_t nucl_num, const int64_t cord_num,
const int64_t walk_num, const int64_t elec_num,
const double* een_rescaled_e_deriv_e, const int64_t nucl_num,
const double* een_rescaled_n, const int64_t walk_num,
double* const dtmp_c ) const double* een_rescaled_e_deriv_e,
{ const double* een_rescaled_n,
double* const dtmp_c ) {
cublasHandle_t handle; cublasHandle_t handle;
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
@ -6855,18 +6826,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (const qmckl_context context,
#pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c) #pragma omp target data use_device_ptr(een_rescaled_e_deriv_e, een_rescaled_n, dtmp_c)
{ {
for (int64_t nw=0; nw < walk_num; ++nw) { for (int64_t nw=0; nw < walk_num; ++nw) {
/* //Manage CUBLAS ERRORS
for (int64_t i=0; i < cord_num; ++i) {
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \
LDA, \
&(een_rescaled_n[bf*nw]), \
LDB, \
&beta, \
&(dtmp_c[cf*(i+nw*cord_num)]), \
LDC);
,*/
//Manage CUBLAS ERRORS
int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \ int cublasError = cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, \
&(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \ &(een_rescaled_e_deriv_e[(nw*(cord_num+1))]), \
@ -6902,7 +6862,7 @@ qmckl_exit_code qmckl_compute_dtmp_c_cuBlas (const qmckl_context context,
#+RESULTS: #+RESULTS:
#+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none #+begin_src c :comments org :tangle (eval h_private_func) :noweb yes :exports none
#ifdef HAVE_CUBLAS_OFFLOAD #ifdef HAVE_CUBLAS_OFFLOAD
qmckl_exit_code qmckl_compute_dtmp_c_cuBlas ( qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
const qmckl_context context, const qmckl_context context,
const int64_t cord_num, const int64_t cord_num,
const int64_t elec_num, const int64_t elec_num,