mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2024-12-22 12:23:57 +01:00
Transition to GPU
This commit is contained in:
parent
79c0d998ea
commit
f224fd1ca1
2
Makefile
2
Makefile
@ -10,7 +10,7 @@ RANLIB = ranlib
|
|||||||
|
|
||||||
SRC= qmckl_blas_f.f90 qmckl_dgemm.c
|
SRC= qmckl_blas_f.f90 qmckl_dgemm.c
|
||||||
OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o
|
OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o
|
||||||
LIB= -mkl=sequential $(shell pkg-config --libs $(STARPU) )
|
LIB= -mkl=sequential $(shell pkg-config --libs $(STARPU) magma)
|
||||||
|
|
||||||
-include irpf90.make
|
-include irpf90.make
|
||||||
export
|
export
|
||||||
|
@ -25,5 +25,17 @@ module qmckl_blas
|
|||||||
integer (kind=c_int64_t) :: tasks(ntasks)
|
integer (kind=c_int64_t) :: tasks(ntasks)
|
||||||
end subroutine qmckl_tasks_run
|
end subroutine qmckl_tasks_run
|
||||||
end interface
|
end interface
|
||||||
|
|
||||||
end module qmckl_blas
|
end module qmckl_blas
|
||||||
|
|
||||||
|
subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
|
||||||
|
bind(C, name='f_dgemm')
|
||||||
|
use iso_c_binding
|
||||||
|
implicit none
|
||||||
|
character, intent(in), value :: TRANSA, TRANSB
|
||||||
|
integer, intent(in), value :: M,N,K,LDA,LDB,LDC
|
||||||
|
double precision, intent(in), value :: ALPHA, BETA
|
||||||
|
double precision, intent(in) :: A(LDA,*), B(LDB,*)
|
||||||
|
double precision, intent(out) :: C(LDC,*)
|
||||||
|
call dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
|
||||||
|
end subroutine
|
||||||
|
|
||||||
|
@ -2,12 +2,22 @@
|
|||||||
|
|
||||||
#include <starpu.h>
|
#include <starpu.h>
|
||||||
|
|
||||||
#include <mkl_cblas.h>
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
//#define CPU_ENABLED 1
|
||||||
|
|
||||||
|
#ifdef CPU_ENABLED
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <starpu_cublas_v2.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void f_dgemm(const char transa, const char transb, const int m, const int n, const int k,
|
||||||
|
const double alpha, const double* A, const int lda, const double* B,
|
||||||
|
const int ldb, const double beta, double* C, const int ldc);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct dgemm_args {
|
struct dgemm_args {
|
||||||
@ -22,28 +32,38 @@ struct dgemm_args {
|
|||||||
int lda;
|
int lda;
|
||||||
int ldb;
|
int ldb;
|
||||||
int ldc;
|
int ldc;
|
||||||
CBLAS_LAYOUT transa;
|
char transa;
|
||||||
CBLAS_LAYOUT transb;
|
char transb;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C);
|
void dgemm_codelet_cpu(void *buffers[], void* cl_arg)
|
||||||
|
|
||||||
void dgemm_codelet(void *buffers[], void* cl_arg)
|
|
||||||
{
|
{
|
||||||
struct dgemm_args *args = cl_arg;
|
struct dgemm_args *args = cl_arg;
|
||||||
double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]);
|
double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]);
|
||||||
double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]);
|
double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]);
|
||||||
double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]);
|
double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]);
|
||||||
qmckl_dgemm_cl(*args, A, B, C);
|
|
||||||
|
int lda = STARPU_MATRIX_GET_LD(buffers[0]);
|
||||||
|
int ldb = STARPU_MATRIX_GET_LD(buffers[1]);
|
||||||
|
int ldc = STARPU_MATRIX_GET_LD(buffers[2]);
|
||||||
|
|
||||||
|
int m = STARPU_MATRIX_GET_NX(buffers[2]);
|
||||||
|
int n = STARPU_MATRIX_GET_NY(buffers[2]);
|
||||||
|
int k = STARPU_MATRIX_GET_NY(buffers[0]);
|
||||||
|
|
||||||
|
f_dgemm(args->transa, args->transb,
|
||||||
|
m, n, k, args->alpha,
|
||||||
|
A, lda, B, ldb, args->beta, C, ldc);
|
||||||
|
|
||||||
free(args);
|
free(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct starpu_codelet dgemm_cl =
|
struct starpu_codelet dgemm_cl =
|
||||||
{
|
{
|
||||||
.where = STARPU_CPU,
|
.where = STARPU_CPU,
|
||||||
.cpu_funcs = { dgemm_codelet },
|
.cpu_funcs = { dgemm_codelet_cpu },
|
||||||
.cpu_funcs_name = { "dgemm_codelet" },
|
.cpu_funcs_name = { "dgemm_codelet_cpu" },
|
||||||
.nbuffers = 3,
|
.nbuffers = 3,
|
||||||
.max_parallelism = 1,
|
.max_parallelism = 1,
|
||||||
.modes = {STARPU_R, STARPU_R, STARPU_RW},
|
.modes = {STARPU_R, STARPU_R, STARPU_RW},
|
||||||
@ -52,13 +72,6 @@ struct starpu_codelet dgemm_cl =
|
|||||||
|
|
||||||
#include<stdio.h>
|
#include<stdio.h>
|
||||||
|
|
||||||
void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C) {
|
|
||||||
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
|
||||||
args.m, args.n, args.k, args.alpha,
|
|
||||||
A, args.lda, B, args.ldb,
|
|
||||||
args.beta, C, args.ldc);
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
||||||
int m, int n, int k,
|
int m, int n, int k,
|
||||||
double alpha,
|
double alpha,
|
||||||
@ -70,6 +83,25 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
|||||||
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
||||||
assert (args != NULL);
|
assert (args != NULL);
|
||||||
|
|
||||||
|
int dima = (transa == 'T' || transa == 't') ? m : k;
|
||||||
|
int dimb = (transb == 'T' || transb == 't') ? k : n;
|
||||||
|
/*
|
||||||
|
double* A2;
|
||||||
|
double* B2;
|
||||||
|
double* C2;
|
||||||
|
|
||||||
|
starpu_malloc_flags((void **)&A2,
|
||||||
|
lda*dima*sizeof(double),
|
||||||
|
STARPU_MALLOC_PINNED);
|
||||||
|
|
||||||
|
starpu_malloc_flags((void **)&B2,
|
||||||
|
lda*dima*sizeof(double),
|
||||||
|
STARPU_MALLOC_PINNED);
|
||||||
|
|
||||||
|
starpu_malloc_flags((void **)&C2,
|
||||||
|
lda*dima*sizeof(double),
|
||||||
|
STARPU_MALLOC_PINNED);
|
||||||
|
*/
|
||||||
args->alpha = alpha;
|
args->alpha = alpha;
|
||||||
args->beta = beta ;
|
args->beta = beta ;
|
||||||
args->A = A;
|
args->A = A;
|
||||||
@ -81,18 +113,8 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
|||||||
args->lda = lda;
|
args->lda = lda;
|
||||||
args->ldb = ldb;
|
args->ldb = ldb;
|
||||||
args->ldc = ldc;
|
args->ldc = ldc;
|
||||||
|
args->transa = transa;
|
||||||
if (transa == 'T' || transa == 't') {
|
args->transb = transb;
|
||||||
args->transa = CblasTrans;
|
|
||||||
} else {
|
|
||||||
args->transa = CblasNoTrans;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (transa == 'T' || transa == 't') {
|
|
||||||
args->transb = CblasTrans;
|
|
||||||
} else {
|
|
||||||
args->transb = CblasNoTrans;
|
|
||||||
}
|
|
||||||
return args;
|
return args;
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -177,7 +199,6 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
|||||||
int rc = starpu_init(NULL);
|
int rc = starpu_init(NULL);
|
||||||
assert (rc == 0);
|
assert (rc == 0);
|
||||||
|
|
||||||
|
|
||||||
starpu_data_handle_t matrix_handle[ngemms][3];
|
starpu_data_handle_t matrix_handle[ngemms][3];
|
||||||
for (int i=0 ; i<ngemms ; ++i)
|
for (int i=0 ; i<ngemms ; ++i)
|
||||||
{
|
{
|
||||||
@ -224,6 +245,7 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
|||||||
starpu_data_unregister(matrix_handle[i][1]);
|
starpu_data_unregister(matrix_handle[i][1]);
|
||||||
starpu_data_unregister(matrix_handle[i][2]);
|
starpu_data_unregister(matrix_handle[i][2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
starpu_shutdown();
|
starpu_shutdown();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user