1
0
mirror of https://github.com/TREX-CoE/irpjast.git synced 2024-07-22 10:47:45 +02:00

GEMM GPU OK

This commit is contained in:
Anthony Scemama 2021-04-28 03:10:42 +02:00
parent f224fd1ca1
commit 7b6a9c3925
3 changed files with 111 additions and 23 deletions

View File

@ -14,6 +14,21 @@
integer*8 :: tasks(100000), ntasks
ntasks = 0_8
! type(c_ptr) :: ptr_a, ptr_b, ptr_c, ptr_d, ptr_e
! double precision, pointer :: A(:,:,:), B(:,:,:), C(:,:,:,:), D(:,:,:,:), E(:,:,:,:,:)
! call alloc(ptr_a, int(size(rescale_een_e),8))
! call c_f_pointer(ptr_a, A, shape(rescale_een_e))
! A(:,:,:) = rescale_een_e(:,:,:)
!
! call alloc(ptr_b, int(size(rescale_een_n),8))
! call c_f_pointer(ptr_b, B, shape(rescale_een_n))
! B(:,:,:) = rescale_een_n(:,:,:)
!
! call alloc(ptr_c, int(size(tmp_c),8))
! call c_f_pointer(ptr_c, C, shape(tmp_c))
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
do k=0,ncord-1
do l=0,ncord
@ -21,9 +36,20 @@
rescale_een_e(1,1,k), size(rescale_een_e,1), &
rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, &
tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks)
! call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, &
! A(:,:,k), size(rescale_een_e,1), &
! B(:,:,l), size(rescale_een_n,1), 0.d0, &
! C(:,:,l,k), size(tmp_c,1), tasks, ntasks)
enddo
enddo
!call alloc(ptr_d, int(size(rescale_een_e_deriv_e),8))
!call c_f_pointer(ptr_d, D, shape(rescale_een_e_deriv_e))
!D(:,:,:,:) = rescale_een_e_deriv_e(:,:,:,:)
!call alloc(ptr_e, int(size(dtmp_c),8))
!call c_f_pointer(ptr_e, E, shape(dtmp_c))
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
do k=0,ncord-1
do l=0,ncord
@ -34,11 +60,26 @@
size(rescale_een_n,1), 0.d0, &
dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
tasks, ntasks)
! call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, &
! D(:,:,:,k), &
! size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), &
! B(:,:,l), &
! size(rescale_een_n,1), 0.d0, &
! E(:,:,:,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
! tasks, ntasks)
enddo
enddo
print *, ntasks, ' tasks'
call qmckl_tasks_run(tasks, ntasks)
! tmp_c(:,:,:,:) = C(:,:,:,:)
! dtmp_c(:,:,:,:,:) = E(:,:,:,:,:)
! call free(ptr_a)
! call free(ptr_b)
! call free(ptr_c)
! call free(ptr_d)
! call free(ptr_e)
END_PROVIDER

View File

@ -25,6 +25,24 @@ module qmckl_blas
integer (kind=c_int64_t) :: tasks(ntasks)
end subroutine qmckl_tasks_run
end interface
interface
subroutine alloc(A, sze) bind(C)
use :: iso_c_binding
implicit none
type(c_ptr) :: A
integer(c_size_t), value :: sze
end subroutine
end interface
interface
subroutine free(A) bind(C,name='starpu_free')
use :: iso_c_binding
implicit none
type(c_ptr), value :: A
end subroutine
end interface
end module qmckl_blas
subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &

View File

@ -7,11 +7,13 @@
#include <stdlib.h>
#include <stdio.h>
//#define CPU_ENABLED 1
#define GPU_ENABLED 1
#ifdef GPU_ENABLED
#ifdef CPU_ENABLED
#include <cuda.h>
#include <starpu_cublas_v2.h>
#endif
void f_dgemm(const char transa, const char transb, const int m, const int n, const int k,
@ -59,15 +61,50 @@ void dgemm_codelet_cpu(void *buffers[], void* cl_arg)
free(args);
}
#ifdef GPU_ENABLED
void dgemm_codelet_gpu(void *buffers[], void* cl_arg)
{
struct dgemm_args *args = cl_arg;
double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]);
double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]);
double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]);
int lda = STARPU_MATRIX_GET_LD(buffers[0]);
int ldb = STARPU_MATRIX_GET_LD(buffers[1]);
int ldc = STARPU_MATRIX_GET_LD(buffers[2]);
int m = STARPU_MATRIX_GET_NX(buffers[2]);
int n = STARPU_MATRIX_GET_NY(buffers[2]);
int k = STARPU_MATRIX_GET_NY(buffers[0]);
char transa = (args->transa == 'T' || args->transa == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
char transb = (args->transb == 'T' || args->transb == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
cublasStatus_t status = cublasDgemm(starpu_cublas_get_local_handle(),
transa, transb, m, n, k, &(args->alpha),
A, lda, B, ldb, &(args->beta), C, ldc);
if (status != CUBLAS_STATUS_SUCCESS)
STARPU_CUBLAS_REPORT_ERROR(status);
free(args);
}
#endif
struct starpu_codelet dgemm_cl =
{
.where = STARPU_CPU,
// .where = STARPU_CPU,
// .where = STARPU_CUDA,
.where = STARPU_CPU | STARPU_CUDA,
.cpu_funcs = { dgemm_codelet_cpu },
.cpu_funcs_name = { "dgemm_codelet_cpu" },
#ifdef GPU_ENABLED
.cuda_funcs = { dgemm_codelet_gpu },
.cuda_flags = {STARPU_CUDA_ASYNC},
#endif
.nbuffers = 3,
.max_parallelism = 1,
.modes = {STARPU_R, STARPU_R, STARPU_RW},
};
#include<stdio.h>
@ -83,25 +120,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
assert (args != NULL);
int dima = (transa == 'T' || transa == 't') ? m : k;
int dimb = (transb == 'T' || transb == 't') ? k : n;
/*
double* A2;
double* B2;
double* C2;
starpu_malloc_flags((void **)&A2,
lda*dima*sizeof(double),
STARPU_MALLOC_PINNED);
starpu_malloc_flags((void **)&B2,
lda*dima*sizeof(double),
STARPU_MALLOC_PINNED);
starpu_malloc_flags((void **)&C2,
lda*dima*sizeof(double),
STARPU_MALLOC_PINNED);
*/
args->alpha = alpha;
args->beta = beta ;
args->A = A;
@ -119,7 +138,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
}
#define MIN_SIZE 512
#define MIN_SIZE 20480
static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks)
{
@ -187,9 +206,13 @@ void qmckl_dgemm(char transa, char transb,
double* C, int ldc,
int64_t* tasks, int64_t* ntasks)
{
struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb,
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
starpu_memory_pin(A, lda*k*sizeof(double));
starpu_memory_pin(B, ldb*n*sizeof(double));
starpu_memory_pin(C, ldc*n*sizeof(double));
qmckl_dgemm_rec(*args, tasks, ntasks);
free(args);
}
@ -197,6 +220,8 @@ void qmckl_dgemm(char transa, char transb,
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
{
int rc = starpu_init(NULL);
starpu_cublas_init();
assert (rc == 0);
starpu_data_handle_t matrix_handle[ngemms][3];
@ -249,4 +274,8 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
starpu_shutdown();
}
void alloc(void** ptr, int64_t size) {
printf("size: %ld\n", size);
starpu_malloc(ptr, (size_t) size * sizeof(double));
}