mirror of
https://github.com/TREX-CoE/qmckl.git
synced 2025-01-03 10:06:09 +01:00
Start implementing cublas
This commit is contained in:
parent
6fb261d635
commit
39bcc569e0
@ -108,6 +108,12 @@ int main() {
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
|
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include "cublas_v2.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "qmckl.h"
|
#include "qmckl.h"
|
||||||
@ -4857,7 +4863,7 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
|
|||||||
}
|
}
|
||||||
ctx->jastrow.tmp_c = tmp_c;
|
ctx->jastrow.tmp_c = tmp_c;
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
qmckl_exit_code rc =
|
qmckl_exit_code rc =
|
||||||
qmckl_compute_tmp_c(context,
|
qmckl_compute_tmp_c(context,
|
||||||
ctx->jastrow.cord_num,
|
ctx->jastrow.cord_num,
|
||||||
@ -4870,6 +4876,20 @@ qmckl_exit_code qmckl_provide_tmp_c(qmckl_context context)
|
|||||||
if (rc != QMCKL_SUCCESS) {
|
if (rc != QMCKL_SUCCESS) {
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
,*/
|
||||||
|
qmckl_exit_code rc =
|
||||||
|
qmckl_compute_tmp_c_cuBlas(context,
|
||||||
|
ctx->jastrow.cord_num,
|
||||||
|
ctx->electron.num,
|
||||||
|
ctx->nucleus.num,
|
||||||
|
ctx->electron.walk_num,
|
||||||
|
ctx->jastrow.een_rescaled_e,
|
||||||
|
ctx->jastrow.een_rescaled_n,
|
||||||
|
ctx->jastrow.tmp_c);
|
||||||
|
if (rc != QMCKL_SUCCESS) {
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
ctx->jastrow.tmp_c_date = ctx->date;
|
ctx->jastrow.tmp_c_date = ctx->date;
|
||||||
}
|
}
|
||||||
@ -4899,7 +4919,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
|
|||||||
|
|
||||||
qmckl_memory_info_struct mem_info = qmckl_memory_info_struct_zero;
|
qmckl_memory_info_struct mem_info = qmckl_memory_info_struct_zero;
|
||||||
mem_info.size = (ctx->jastrow.cord_num) * (ctx->jastrow.cord_num + 1)
|
mem_info.size = (ctx->jastrow.cord_num) * (ctx->jastrow.cord_num + 1)
|
||||||
* 4 * ctx->electron.num * ctx->nucleus.num * ctx->electron.walk_num * sizeof(double);
|
,* 4 * ctx->electron.num * ctx->nucleus.num * ctx->electron.walk_num * sizeof(double);
|
||||||
double* dtmp_c = (double*) qmckl_malloc(context, mem_info);
|
double* dtmp_c = (double*) qmckl_malloc(context, mem_info);
|
||||||
|
|
||||||
if (dtmp_c == NULL) {
|
if (dtmp_c == NULL) {
|
||||||
@ -4910,7 +4930,6 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
|
|||||||
}
|
}
|
||||||
ctx->jastrow.dtmp_c = dtmp_c;
|
ctx->jastrow.dtmp_c = dtmp_c;
|
||||||
}
|
}
|
||||||
|
|
||||||
qmckl_exit_code rc =
|
qmckl_exit_code rc =
|
||||||
qmckl_compute_dtmp_c(context,
|
qmckl_compute_dtmp_c(context,
|
||||||
ctx->jastrow.cord_num,
|
ctx->jastrow.cord_num,
|
||||||
@ -4924,6 +4943,7 @@ qmckl_exit_code qmckl_provide_dtmp_c(qmckl_context context)
|
|||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ctx->jastrow.dtmp_c_date = ctx->date;
|
ctx->jastrow.dtmp_c_date = ctx->date;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5453,6 +5473,105 @@ qmckl_exit_code qmckl_compute_tmp_c_hpc (
|
|||||||
}
|
}
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
|
#+begin_src c :comments org :tangle (eval c) :noweb yes
|
||||||
|
qmckl_exit_code qmckl_compute_tmp_c_cuBlas (
|
||||||
|
const qmckl_context context,
|
||||||
|
const int64_t cord_num,
|
||||||
|
const int64_t elec_num,
|
||||||
|
const int64_t nucl_num,
|
||||||
|
const int64_t walk_num,
|
||||||
|
const double* een_rescaled_e,
|
||||||
|
const double* een_rescaled_n,
|
||||||
|
double* const tmp_c ) {
|
||||||
|
|
||||||
|
qmckl_exit_code info;
|
||||||
|
|
||||||
|
//Initialisation of cublas
|
||||||
|
|
||||||
|
cublasHandle_t handle;
|
||||||
|
if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||||
|
{
|
||||||
|
fprintf(stdout, "CUBLAS initialization failed!\n");
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if (context == QMCKL_NULL_CONTEXT) {
|
||||||
|
return QMCKL_INVALID_CONTEXT;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cord_num <= 0) {
|
||||||
|
return QMCKL_INVALID_ARG_2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (elec_num <= 0) {
|
||||||
|
return QMCKL_INVALID_ARG_3;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nucl_num <= 0) {
|
||||||
|
return QMCKL_INVALID_ARG_4;
|
||||||
|
}
|
||||||
|
|
||||||
|
const double alpha = 1.0;
|
||||||
|
const double beta = 0.0;
|
||||||
|
|
||||||
|
const int64_t M = elec_num;
|
||||||
|
const int64_t N = nucl_num*(cord_num + 1);
|
||||||
|
const int64_t K = elec_num;
|
||||||
|
|
||||||
|
const int64_t LDA = elec_num;
|
||||||
|
const int64_t LDB = elec_num;
|
||||||
|
const int64_t LDC = elec_num;
|
||||||
|
|
||||||
|
const int64_t af = elec_num*elec_num;
|
||||||
|
const int64_t bf = elec_num*nucl_num*(cord_num+1);
|
||||||
|
const int64_t cf = bf;
|
||||||
|
|
||||||
|
const double* tmp_c_gpu = malloc(sizeof(tmp_c));
|
||||||
|
|
||||||
|
#pragma omp target enter data map(alloc:een_rescaled_e[0:elec_num*elec_num*(cord_num+1)*walk_num],een_rescaled_n[0:M*N*K],tmp_c_gpu[0:sizeof(tmp_c_gpu)/sizeof(double)])
|
||||||
|
#pragma omp target data use_device_ptr(een_rescaled_e,een_rescaled_n,tmp_c)
|
||||||
|
{
|
||||||
|
for (int nw=0; nw < walk_num; ++nw) {
|
||||||
|
for (int i=0; i<cord_num; ++i){
|
||||||
|
|
||||||
|
//CuBlas implementation
|
||||||
|
int cublasError = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
|
||||||
|
&(een_rescaled_e[af*(i+nw*(cord_num+1))]), \
|
||||||
|
LDA, \
|
||||||
|
&(een_rescaled_n[bf*nw]), \
|
||||||
|
LDB, \
|
||||||
|
&beta, \
|
||||||
|
&(tmp_c[cf*(i+nw*cord_num)]), \
|
||||||
|
LDC);
|
||||||
|
|
||||||
|
//Manage cublas ERROR
|
||||||
|
if(cublasError != CUBLAS_STATUS_SUCCESS){
|
||||||
|
printf("CUBLAS ERROR %d", cublasError);
|
||||||
|
info = QMCKL_FAILURE;
|
||||||
|
return info;
|
||||||
|
}else{
|
||||||
|
info = QMCKL_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cudaDeviceSynchronize();
|
||||||
|
cublasDestroy(handle);
|
||||||
|
|
||||||
|
|
||||||
|
#pragma omp target exit data map(from:tmp_c_gpu[0:elec_num*nucl_num*(cord_num+1)*cord_num*walk_num])
|
||||||
|
|
||||||
|
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
|
#+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c")
|
||||||
|
|
||||||
#+RESULTS:
|
#+RESULTS:
|
||||||
@ -5468,6 +5587,18 @@ qmckl_exit_code qmckl_compute_tmp_c (
|
|||||||
double* const tmp_c );
|
double* const tmp_c );
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
|
#+begin_src c :tangle (eval h_func) :comments org
|
||||||
|
qmckl_exit_code qmckl_compute_tmp_c_cuBlas (
|
||||||
|
const qmckl_context context,
|
||||||
|
const int64_t cord_num,
|
||||||
|
const int64_t elec_num,
|
||||||
|
const int64_t nucl_num,
|
||||||
|
const int64_t walk_num,
|
||||||
|
const double* een_rescaled_e,
|
||||||
|
const double* een_rescaled_n,
|
||||||
|
double* const tmp_c );
|
||||||
|
|
||||||
|
#+end_src
|
||||||
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc")
|
# #+CALL: generate_c_header(table=qmckl_factor_tmp_c_args,rettyp=get_value("CRetType"),fname="qmckl_compute_tmp_c_doc")
|
||||||
|
|
||||||
#+RESULTS:
|
#+RESULTS:
|
||||||
|
Loading…
Reference in New Issue
Block a user