mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2025-01-03 01:56:19 +01:00
GEMM GPU OK
This commit is contained in:
parent
f224fd1ca1
commit
7b6a9c3925
@ -14,6 +14,21 @@
|
||||
integer*8 :: tasks(100000), ntasks
|
||||
|
||||
ntasks = 0_8
|
||||
|
||||
! type(c_ptr) :: ptr_a, ptr_b, ptr_c, ptr_d, ptr_e
|
||||
! double precision, pointer :: A(:,:,:), B(:,:,:), C(:,:,:,:), D(:,:,:,:), E(:,:,:,:,:)
|
||||
|
||||
! call alloc(ptr_a, int(size(rescale_een_e),8))
|
||||
! call c_f_pointer(ptr_a, A, shape(rescale_een_e))
|
||||
! A(:,:,:) = rescale_een_e(:,:,:)
|
||||
!
|
||||
! call alloc(ptr_b, int(size(rescale_een_n),8))
|
||||
! call c_f_pointer(ptr_b, B, shape(rescale_een_n))
|
||||
! B(:,:,:) = rescale_een_n(:,:,:)
|
||||
!
|
||||
! call alloc(ptr_c, int(size(tmp_c),8))
|
||||
! call c_f_pointer(ptr_c, C, shape(tmp_c))
|
||||
|
||||
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
|
||||
do k=0,ncord-1
|
||||
do l=0,ncord
|
||||
@ -21,9 +36,20 @@
|
||||
rescale_een_e(1,1,k), size(rescale_een_e,1), &
|
||||
rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, &
|
||||
tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks)
|
||||
! call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, &
|
||||
! A(:,:,k), size(rescale_een_e,1), &
|
||||
! B(:,:,l), size(rescale_een_n,1), 0.d0, &
|
||||
! C(:,:,l,k), size(tmp_c,1), tasks, ntasks)
|
||||
enddo
|
||||
enddo
|
||||
|
||||
!call alloc(ptr_d, int(size(rescale_een_e_deriv_e),8))
|
||||
!call c_f_pointer(ptr_d, D, shape(rescale_een_e_deriv_e))
|
||||
!D(:,:,:,:) = rescale_een_e_deriv_e(:,:,:,:)
|
||||
|
||||
!call alloc(ptr_e, int(size(dtmp_c),8))
|
||||
!call c_f_pointer(ptr_e, E, shape(dtmp_c))
|
||||
|
||||
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
|
||||
do k=0,ncord-1
|
||||
do l=0,ncord
|
||||
@ -34,11 +60,26 @@
|
||||
size(rescale_een_n,1), 0.d0, &
|
||||
dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
||||
tasks, ntasks)
|
||||
|
||||
! call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, &
|
||||
! D(:,:,:,k), &
|
||||
! size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), &
|
||||
! B(:,:,l), &
|
||||
! size(rescale_een_n,1), 0.d0, &
|
||||
! E(:,:,:,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
||||
! tasks, ntasks)
|
||||
enddo
|
||||
enddo
|
||||
|
||||
print *, ntasks, ' tasks'
|
||||
call qmckl_tasks_run(tasks, ntasks)
|
||||
! tmp_c(:,:,:,:) = C(:,:,:,:)
|
||||
! dtmp_c(:,:,:,:,:) = E(:,:,:,:,:)
|
||||
! call free(ptr_a)
|
||||
! call free(ptr_b)
|
||||
! call free(ptr_c)
|
||||
! call free(ptr_d)
|
||||
! call free(ptr_e)
|
||||
|
||||
END_PROVIDER
|
||||
|
||||
|
@ -25,6 +25,24 @@ module qmckl_blas
|
||||
integer (kind=c_int64_t) :: tasks(ntasks)
|
||||
end subroutine qmckl_tasks_run
|
||||
end interface
|
||||
|
||||
interface
|
||||
subroutine alloc(A, sze) bind(C)
|
||||
use :: iso_c_binding
|
||||
implicit none
|
||||
type(c_ptr) :: A
|
||||
integer(c_size_t), value :: sze
|
||||
end subroutine
|
||||
end interface
|
||||
|
||||
interface
|
||||
subroutine free(A) bind(C,name='starpu_free')
|
||||
use :: iso_c_binding
|
||||
implicit none
|
||||
type(c_ptr), value :: A
|
||||
end subroutine
|
||||
end interface
|
||||
|
||||
end module qmckl_blas
|
||||
|
||||
subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
|
||||
|
@ -7,11 +7,13 @@
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
|
||||
//#define CPU_ENABLED 1
|
||||
#define GPU_ENABLED 1
|
||||
|
||||
#ifdef GPU_ENABLED
|
||||
|
||||
#ifdef CPU_ENABLED
|
||||
#include <cuda.h>
|
||||
#include <starpu_cublas_v2.h>
|
||||
|
||||
#endif
|
||||
|
||||
void f_dgemm(const char transa, const char transb, const int m, const int n, const int k,
|
||||
@ -59,15 +61,50 @@ void dgemm_codelet_cpu(void *buffers[], void* cl_arg)
|
||||
free(args);
|
||||
}
|
||||
|
||||
#ifdef GPU_ENABLED
|
||||
void dgemm_codelet_gpu(void *buffers[], void* cl_arg)
|
||||
{
|
||||
struct dgemm_args *args = cl_arg;
|
||||
double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]);
|
||||
double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]);
|
||||
double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]);
|
||||
|
||||
int lda = STARPU_MATRIX_GET_LD(buffers[0]);
|
||||
int ldb = STARPU_MATRIX_GET_LD(buffers[1]);
|
||||
int ldc = STARPU_MATRIX_GET_LD(buffers[2]);
|
||||
|
||||
int m = STARPU_MATRIX_GET_NX(buffers[2]);
|
||||
int n = STARPU_MATRIX_GET_NY(buffers[2]);
|
||||
int k = STARPU_MATRIX_GET_NY(buffers[0]);
|
||||
|
||||
char transa = (args->transa == 'T' || args->transa == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
char transb = (args->transb == 'T' || args->transb == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
|
||||
cublasStatus_t status = cublasDgemm(starpu_cublas_get_local_handle(),
|
||||
transa, transb, m, n, k, &(args->alpha),
|
||||
A, lda, B, ldb, &(args->beta), C, ldc);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
STARPU_CUBLAS_REPORT_ERROR(status);
|
||||
|
||||
free(args);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct starpu_codelet dgemm_cl =
|
||||
{
|
||||
.where = STARPU_CPU,
|
||||
// .where = STARPU_CPU,
|
||||
// .where = STARPU_CUDA,
|
||||
.where = STARPU_CPU | STARPU_CUDA,
|
||||
.cpu_funcs = { dgemm_codelet_cpu },
|
||||
.cpu_funcs_name = { "dgemm_codelet_cpu" },
|
||||
#ifdef GPU_ENABLED
|
||||
.cuda_funcs = { dgemm_codelet_gpu },
|
||||
.cuda_flags = {STARPU_CUDA_ASYNC},
|
||||
#endif
|
||||
.nbuffers = 3,
|
||||
.max_parallelism = 1,
|
||||
.modes = {STARPU_R, STARPU_R, STARPU_RW},
|
||||
|
||||
};
|
||||
|
||||
#include<stdio.h>
|
||||
@ -83,25 +120,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
||||
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
||||
assert (args != NULL);
|
||||
|
||||
int dima = (transa == 'T' || transa == 't') ? m : k;
|
||||
int dimb = (transb == 'T' || transb == 't') ? k : n;
|
||||
/*
|
||||
double* A2;
|
||||
double* B2;
|
||||
double* C2;
|
||||
|
||||
starpu_malloc_flags((void **)&A2,
|
||||
lda*dima*sizeof(double),
|
||||
STARPU_MALLOC_PINNED);
|
||||
|
||||
starpu_malloc_flags((void **)&B2,
|
||||
lda*dima*sizeof(double),
|
||||
STARPU_MALLOC_PINNED);
|
||||
|
||||
starpu_malloc_flags((void **)&C2,
|
||||
lda*dima*sizeof(double),
|
||||
STARPU_MALLOC_PINNED);
|
||||
*/
|
||||
args->alpha = alpha;
|
||||
args->beta = beta ;
|
||||
args->A = A;
|
||||
@ -119,7 +138,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
||||
|
||||
}
|
||||
|
||||
#define MIN_SIZE 512
|
||||
#define MIN_SIZE 20480
|
||||
static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks)
|
||||
{
|
||||
|
||||
@ -187,9 +206,13 @@ void qmckl_dgemm(char transa, char transb,
|
||||
double* C, int ldc,
|
||||
int64_t* tasks, int64_t* ntasks)
|
||||
{
|
||||
|
||||
struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb,
|
||||
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
|
||||
starpu_memory_pin(A, lda*k*sizeof(double));
|
||||
starpu_memory_pin(B, ldb*n*sizeof(double));
|
||||
starpu_memory_pin(C, ldc*n*sizeof(double));
|
||||
qmckl_dgemm_rec(*args, tasks, ntasks);
|
||||
free(args);
|
||||
}
|
||||
@ -197,6 +220,8 @@ void qmckl_dgemm(char transa, char transb,
|
||||
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||
{
|
||||
int rc = starpu_init(NULL);
|
||||
starpu_cublas_init();
|
||||
|
||||
assert (rc == 0);
|
||||
|
||||
starpu_data_handle_t matrix_handle[ngemms][3];
|
||||
@ -249,4 +274,8 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||
starpu_shutdown();
|
||||
}
|
||||
|
||||
void alloc(void** ptr, int64_t size) {
|
||||
printf("size: %ld\n", size);
|
||||
starpu_malloc(ptr, (size_t) size * sizeof(double));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user