mirror of
https://github.com/TREX-CoE/qmckl.git
synced 2024-12-31 16:46:03 +01:00
Replace placeholder cuBLAS kernels with new C HPC implementation
This commit is contained in:
parent
f8e6d5f06b
commit
63c7f8ea72
@ -5783,17 +5783,6 @@ qmckl_exit_code qmckl_compute_tmp_c_cublas_offload (
|
||||
const double* een_rescaled_n,
|
||||
double* const tmp_c ) {
|
||||
|
||||
qmckl_exit_code info;
|
||||
int i, j, a, l, kk, p, lmax, nw;
|
||||
char TransA, TransB;
|
||||
double alpha, beta;
|
||||
int M, N, K, LDA, LDB, LDC;
|
||||
|
||||
TransA = 'N';
|
||||
TransB = 'N';
|
||||
alpha = 1.0;
|
||||
beta = 0.0;
|
||||
|
||||
if (context == QMCKL_NULL_CONTEXT) {
|
||||
return QMCKL_INVALID_CONTEXT;
|
||||
}
|
||||
@ -5810,29 +5799,40 @@ qmckl_exit_code qmckl_compute_tmp_c_cublas_offload (
|
||||
return QMCKL_INVALID_ARG_4;
|
||||
}
|
||||
|
||||
M = elec_num;
|
||||
N = nucl_num*(cord_num + 1);
|
||||
K = elec_num;
|
||||
if (walk_num <= 0) {
|
||||
return QMCKL_INVALID_ARG_5;
|
||||
}
|
||||
|
||||
LDA = sizeof(een_rescaled_e)/sizeof(double);
|
||||
LDB = sizeof(een_rescaled_n)/sizeof(double);
|
||||
LDC = sizeof(tmp_c)/sizeof(double);
|
||||
qmckl_exit_code info = QMCKL_SUCCESS;
|
||||
|
||||
// TODO Replace with cuBLAS calls
|
||||
for (int nw=0; nw < walk_num; ++nw) {
|
||||
for (int i=0; i<cord_num; ++i){
|
||||
info = qmckl_dgemm(context,TransA, TransB, M, N, K, alpha, \
|
||||
// &een_rescaled_e[0+0*elec_num+i*elec_num*elec_num+nw*elec_num*elec_num*(cord_num+1)],
|
||||
&een_rescaled_e[ i*elec_num*elec_num+nw*elec_num*elec_num*(cord_num+1)], \
|
||||
const char TransA = 'N';
|
||||
const char TransB = 'N';
|
||||
const double alpha = 1.0;
|
||||
const double beta = 0.0;
|
||||
|
||||
const int64_t M = elec_num;
|
||||
const int64_t N = nucl_num*(cord_num + 1);
|
||||
const int64_t K = elec_num;
|
||||
|
||||
const int64_t LDA = elec_num;
|
||||
const int64_t LDB = elec_num;
|
||||
const int64_t LDC = elec_num;
|
||||
|
||||
const int64_t af = elec_num*elec_num;
|
||||
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
||||
const int64_t cf = bf;
|
||||
|
||||
// TODO Replace with calls to cuBLAS
|
||||
for (int64_t nw=0; nw < walk_num; ++nw) {
|
||||
for (int64_t i=0; i<cord_num; ++i){
|
||||
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \
|
||||
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
|
||||
LDA, \
|
||||
// &een_rescaled_n[0+0*elec_num+0*elec_num*nucl_num+nw*elec_num*nucl_num*(cord_num+1)],
|
||||
&een_rescaled_n[ nw*elec_num*nucl_num*(cord_num+1)], \
|
||||
&(een_rescaled_n[bf*nw]), \
|
||||
LDB, \
|
||||
beta, \
|
||||
// &tmp_c[0+0*elec_num+0*elec_num*nucl_num+i*elec_num*nucl_num*(cord_num+1)+nw*elec_num*nucl_num*(cord_num+1)*cord_num],
|
||||
&tmp_c[ i*elec_num*nucl_num*(cord_num+1)+nw*elec_num*nucl_num*(cord_num+1)*cord_num], \
|
||||
&(tmp_c[cf*(i+nw*cord_num)]), \
|
||||
LDC);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -6244,16 +6244,6 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
|
||||
const double* een_rescaled_n,
|
||||
double* const dtmp_c ) {
|
||||
|
||||
qmckl_exit_code info;
|
||||
char TransA, TransB;
|
||||
double alpha, beta;
|
||||
int M, N, K, LDA, LDB, LDC;
|
||||
|
||||
TransA = 'N';
|
||||
TransB = 'N';
|
||||
alpha = 1.0;
|
||||
beta = 0.0;
|
||||
|
||||
if (context == QMCKL_NULL_CONTEXT) {
|
||||
return QMCKL_INVALID_CONTEXT;
|
||||
}
|
||||
@ -6270,27 +6260,39 @@ qmckl_exit_code qmckl_compute_dtmp_c_cublas_offload (
|
||||
return QMCKL_INVALID_ARG_4;
|
||||
}
|
||||
|
||||
M = 4*elec_num;
|
||||
N = nucl_num*(cord_num + 1);
|
||||
K = elec_num;
|
||||
if (walk_num <= 0) {
|
||||
return QMCKL_INVALID_ARG_5;
|
||||
}
|
||||
|
||||
LDA = 4*sizeof(een_rescaled_e_deriv_e)/sizeof(double);
|
||||
LDB = sizeof(een_rescaled_n)/sizeof(double);
|
||||
LDC = 4*sizeof(dtmp_c)/sizeof(double);
|
||||
qmckl_exit_code info = QMCKL_SUCCESS;
|
||||
|
||||
// TODO Replace with cuBLAS calls
|
||||
for (int nw=0; nw < walk_num; ++nw) {
|
||||
for (int i=0; nw < cord_num; ++i) {
|
||||
info = qmckl_dgemm(context,TransA, TransB, M, N, K, alpha, \
|
||||
//&een_rescaled_e_deriv_e[0+0*elec_num+0*elec_num*4+i*elec_num*4*elec_num+nw*elec_num*4*elec_num*(cord_num+1)],
|
||||
&een_rescaled_e_deriv_e[i*elec_num*4*elec_num+nw*elec_num*4*elec_num*(cord_num+1)], \
|
||||
const char TransA = 'N';
|
||||
const char TransB = 'N';
|
||||
const double alpha = 1.0;
|
||||
const double beta = 0.0;
|
||||
|
||||
const int64_t M = 4*elec_num;
|
||||
const int64_t N = nucl_num*(cord_num + 1);
|
||||
const int64_t K = elec_num;
|
||||
|
||||
const int64_t LDA = 4*elec_num;
|
||||
const int64_t LDB = elec_num;
|
||||
const int64_t LDC = 4*elec_num;
|
||||
|
||||
const int64_t af = elec_num*elec_num*4;
|
||||
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
||||
const int64_t cf = elec_num*4*nucl_num*(cord_num+1);
|
||||
|
||||
// TODO Replace with calls to cuBLAS
|
||||
for (int64_t nw=0; nw < walk_num; ++nw) {
|
||||
for (int64_t i=0; i < cord_num; ++i) {
|
||||
info = qmckl_dgemm(context, TransA, TransB, M, N, K, alpha, \
|
||||
&(een_rescaled_e_deriv_e[af*(i+nw*(cord_num+1))]), \
|
||||
LDA, \
|
||||
//&een_rescaled_n[0+0*elec_num+0*elec_num*nucl_num+nw*elec_num*nucl_num*(cord_num+1)],
|
||||
&een_rescaled_n[nw*elec_num*nucl_num*(cord_num+1)], \
|
||||
&(een_rescaled_n[bf*nw]), \
|
||||
LDB, \
|
||||
beta, \
|
||||
//&dtmp_c[0+0*elec_num+0*elec_num*4+0*elec_num*4*nucl_num+i*elec_num*4*nucl_num*(cord_num+1)+nw*elec_num*4*nucl_num*(cord_num+1)*cord_num],
|
||||
&dtmp_c[i*elec_num*4*nucl_num*(cord_num+1)+nw*elec_num*4*nucl_num*(cord_num+1)*cord_num], \
|
||||
&(dtmp_c[cf*(i+nw*cord_num)]), \
|
||||
LDC);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user