mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2025-01-03 01:56:19 +01:00
GEMM GPU OK
This commit is contained in:
parent
f224fd1ca1
commit
7b6a9c3925
@ -14,6 +14,21 @@
|
|||||||
integer*8 :: tasks(100000), ntasks
|
integer*8 :: tasks(100000), ntasks
|
||||||
|
|
||||||
ntasks = 0_8
|
ntasks = 0_8
|
||||||
|
|
||||||
|
! type(c_ptr) :: ptr_a, ptr_b, ptr_c, ptr_d, ptr_e
|
||||||
|
! double precision, pointer :: A(:,:,:), B(:,:,:), C(:,:,:,:), D(:,:,:,:), E(:,:,:,:,:)
|
||||||
|
|
||||||
|
! call alloc(ptr_a, int(size(rescale_een_e),8))
|
||||||
|
! call c_f_pointer(ptr_a, A, shape(rescale_een_e))
|
||||||
|
! A(:,:,:) = rescale_een_e(:,:,:)
|
||||||
|
!
|
||||||
|
! call alloc(ptr_b, int(size(rescale_een_n),8))
|
||||||
|
! call c_f_pointer(ptr_b, B, shape(rescale_een_n))
|
||||||
|
! B(:,:,:) = rescale_een_n(:,:,:)
|
||||||
|
!
|
||||||
|
! call alloc(ptr_c, int(size(tmp_c),8))
|
||||||
|
! call c_f_pointer(ptr_c, C, shape(tmp_c))
|
||||||
|
|
||||||
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
|
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
|
||||||
do k=0,ncord-1
|
do k=0,ncord-1
|
||||||
do l=0,ncord
|
do l=0,ncord
|
||||||
@ -21,9 +36,20 @@
|
|||||||
rescale_een_e(1,1,k), size(rescale_een_e,1), &
|
rescale_een_e(1,1,k), size(rescale_een_e,1), &
|
||||||
rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, &
|
rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, &
|
||||||
tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks)
|
tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks)
|
||||||
|
! call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, &
|
||||||
|
! A(:,:,k), size(rescale_een_e,1), &
|
||||||
|
! B(:,:,l), size(rescale_een_n,1), 0.d0, &
|
||||||
|
! C(:,:,l,k), size(tmp_c,1), tasks, ntasks)
|
||||||
enddo
|
enddo
|
||||||
enddo
|
enddo
|
||||||
|
|
||||||
|
!call alloc(ptr_d, int(size(rescale_een_e_deriv_e),8))
|
||||||
|
!call c_f_pointer(ptr_d, D, shape(rescale_een_e_deriv_e))
|
||||||
|
!D(:,:,:,:) = rescale_een_e_deriv_e(:,:,:,:)
|
||||||
|
|
||||||
|
!call alloc(ptr_e, int(size(dtmp_c),8))
|
||||||
|
!call c_f_pointer(ptr_e, E, shape(dtmp_c))
|
||||||
|
|
||||||
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
|
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
|
||||||
do k=0,ncord-1
|
do k=0,ncord-1
|
||||||
do l=0,ncord
|
do l=0,ncord
|
||||||
@ -34,11 +60,26 @@
|
|||||||
size(rescale_een_n,1), 0.d0, &
|
size(rescale_een_n,1), 0.d0, &
|
||||||
dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
||||||
tasks, ntasks)
|
tasks, ntasks)
|
||||||
|
|
||||||
|
! call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, &
|
||||||
|
! D(:,:,:,k), &
|
||||||
|
! size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), &
|
||||||
|
! B(:,:,l), &
|
||||||
|
! size(rescale_een_n,1), 0.d0, &
|
||||||
|
! E(:,:,:,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
||||||
|
! tasks, ntasks)
|
||||||
enddo
|
enddo
|
||||||
enddo
|
enddo
|
||||||
|
|
||||||
print *, ntasks, ' tasks'
|
print *, ntasks, ' tasks'
|
||||||
call qmckl_tasks_run(tasks, ntasks)
|
call qmckl_tasks_run(tasks, ntasks)
|
||||||
|
! tmp_c(:,:,:,:) = C(:,:,:,:)
|
||||||
|
! dtmp_c(:,:,:,:,:) = E(:,:,:,:,:)
|
||||||
|
! call free(ptr_a)
|
||||||
|
! call free(ptr_b)
|
||||||
|
! call free(ptr_c)
|
||||||
|
! call free(ptr_d)
|
||||||
|
! call free(ptr_e)
|
||||||
|
|
||||||
END_PROVIDER
|
END_PROVIDER
|
||||||
|
|
||||||
|
@ -25,6 +25,24 @@ module qmckl_blas
|
|||||||
integer (kind=c_int64_t) :: tasks(ntasks)
|
integer (kind=c_int64_t) :: tasks(ntasks)
|
||||||
end subroutine qmckl_tasks_run
|
end subroutine qmckl_tasks_run
|
||||||
end interface
|
end interface
|
||||||
|
|
||||||
|
interface
|
||||||
|
subroutine alloc(A, sze) bind(C)
|
||||||
|
use :: iso_c_binding
|
||||||
|
implicit none
|
||||||
|
type(c_ptr) :: A
|
||||||
|
integer(c_size_t), value :: sze
|
||||||
|
end subroutine
|
||||||
|
end interface
|
||||||
|
|
||||||
|
interface
|
||||||
|
subroutine free(A) bind(C,name='starpu_free')
|
||||||
|
use :: iso_c_binding
|
||||||
|
implicit none
|
||||||
|
type(c_ptr), value :: A
|
||||||
|
end subroutine
|
||||||
|
end interface
|
||||||
|
|
||||||
end module qmckl_blas
|
end module qmckl_blas
|
||||||
|
|
||||||
subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
|
subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
|
||||||
|
@ -7,11 +7,13 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
//#define CPU_ENABLED 1
|
#define GPU_ENABLED 1
|
||||||
|
|
||||||
|
#ifdef GPU_ENABLED
|
||||||
|
|
||||||
#ifdef CPU_ENABLED
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <starpu_cublas_v2.h>
|
#include <starpu_cublas_v2.h>
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void f_dgemm(const char transa, const char transb, const int m, const int n, const int k,
|
void f_dgemm(const char transa, const char transb, const int m, const int n, const int k,
|
||||||
@ -59,15 +61,50 @@ void dgemm_codelet_cpu(void *buffers[], void* cl_arg)
|
|||||||
free(args);
|
free(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef GPU_ENABLED
|
||||||
|
void dgemm_codelet_gpu(void *buffers[], void* cl_arg)
|
||||||
|
{
|
||||||
|
struct dgemm_args *args = cl_arg;
|
||||||
|
double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]);
|
||||||
|
double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]);
|
||||||
|
double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]);
|
||||||
|
|
||||||
|
int lda = STARPU_MATRIX_GET_LD(buffers[0]);
|
||||||
|
int ldb = STARPU_MATRIX_GET_LD(buffers[1]);
|
||||||
|
int ldc = STARPU_MATRIX_GET_LD(buffers[2]);
|
||||||
|
|
||||||
|
int m = STARPU_MATRIX_GET_NX(buffers[2]);
|
||||||
|
int n = STARPU_MATRIX_GET_NY(buffers[2]);
|
||||||
|
int k = STARPU_MATRIX_GET_NY(buffers[0]);
|
||||||
|
|
||||||
|
char transa = (args->transa == 'T' || args->transa == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
|
char transb = (args->transb == 'T' || args->transb == 'T')? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
|
|
||||||
|
cublasStatus_t status = cublasDgemm(starpu_cublas_get_local_handle(),
|
||||||
|
transa, transb, m, n, k, &(args->alpha),
|
||||||
|
A, lda, B, ldb, &(args->beta), C, ldc);
|
||||||
|
|
||||||
|
if (status != CUBLAS_STATUS_SUCCESS)
|
||||||
|
STARPU_CUBLAS_REPORT_ERROR(status);
|
||||||
|
|
||||||
|
free(args);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
struct starpu_codelet dgemm_cl =
|
struct starpu_codelet dgemm_cl =
|
||||||
{
|
{
|
||||||
.where = STARPU_CPU,
|
// .where = STARPU_CPU,
|
||||||
|
// .where = STARPU_CUDA,
|
||||||
|
.where = STARPU_CPU | STARPU_CUDA,
|
||||||
.cpu_funcs = { dgemm_codelet_cpu },
|
.cpu_funcs = { dgemm_codelet_cpu },
|
||||||
.cpu_funcs_name = { "dgemm_codelet_cpu" },
|
.cpu_funcs_name = { "dgemm_codelet_cpu" },
|
||||||
|
#ifdef GPU_ENABLED
|
||||||
|
.cuda_funcs = { dgemm_codelet_gpu },
|
||||||
|
.cuda_flags = {STARPU_CUDA_ASYNC},
|
||||||
|
#endif
|
||||||
.nbuffers = 3,
|
.nbuffers = 3,
|
||||||
.max_parallelism = 1,
|
.max_parallelism = 1,
|
||||||
.modes = {STARPU_R, STARPU_R, STARPU_RW},
|
.modes = {STARPU_R, STARPU_R, STARPU_RW},
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include<stdio.h>
|
#include<stdio.h>
|
||||||
@ -83,25 +120,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
|||||||
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
||||||
assert (args != NULL);
|
assert (args != NULL);
|
||||||
|
|
||||||
int dima = (transa == 'T' || transa == 't') ? m : k;
|
|
||||||
int dimb = (transb == 'T' || transb == 't') ? k : n;
|
|
||||||
/*
|
|
||||||
double* A2;
|
|
||||||
double* B2;
|
|
||||||
double* C2;
|
|
||||||
|
|
||||||
starpu_malloc_flags((void **)&A2,
|
|
||||||
lda*dima*sizeof(double),
|
|
||||||
STARPU_MALLOC_PINNED);
|
|
||||||
|
|
||||||
starpu_malloc_flags((void **)&B2,
|
|
||||||
lda*dima*sizeof(double),
|
|
||||||
STARPU_MALLOC_PINNED);
|
|
||||||
|
|
||||||
starpu_malloc_flags((void **)&C2,
|
|
||||||
lda*dima*sizeof(double),
|
|
||||||
STARPU_MALLOC_PINNED);
|
|
||||||
*/
|
|
||||||
args->alpha = alpha;
|
args->alpha = alpha;
|
||||||
args->beta = beta ;
|
args->beta = beta ;
|
||||||
args->A = A;
|
args->A = A;
|
||||||
@ -119,7 +138,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MIN_SIZE 512
|
#define MIN_SIZE 20480
|
||||||
static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks)
|
static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks)
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -187,9 +206,13 @@ void qmckl_dgemm(char transa, char transb,
|
|||||||
double* C, int ldc,
|
double* C, int ldc,
|
||||||
int64_t* tasks, int64_t* ntasks)
|
int64_t* tasks, int64_t* ntasks)
|
||||||
{
|
{
|
||||||
|
|
||||||
struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb,
|
struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb,
|
||||||
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
|
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||||
|
|
||||||
|
starpu_memory_pin(A, lda*k*sizeof(double));
|
||||||
|
starpu_memory_pin(B, ldb*n*sizeof(double));
|
||||||
|
starpu_memory_pin(C, ldc*n*sizeof(double));
|
||||||
qmckl_dgemm_rec(*args, tasks, ntasks);
|
qmckl_dgemm_rec(*args, tasks, ntasks);
|
||||||
free(args);
|
free(args);
|
||||||
}
|
}
|
||||||
@ -197,6 +220,8 @@ void qmckl_dgemm(char transa, char transb,
|
|||||||
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||||
{
|
{
|
||||||
int rc = starpu_init(NULL);
|
int rc = starpu_init(NULL);
|
||||||
|
starpu_cublas_init();
|
||||||
|
|
||||||
assert (rc == 0);
|
assert (rc == 0);
|
||||||
|
|
||||||
starpu_data_handle_t matrix_handle[ngemms][3];
|
starpu_data_handle_t matrix_handle[ngemms][3];
|
||||||
@ -249,4 +274,8 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
|||||||
starpu_shutdown();
|
starpu_shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void alloc(void** ptr, int64_t size) {
|
||||||
|
printf("size: %ld\n", size);
|
||||||
|
starpu_malloc(ptr, (size_t) size * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user