diff --git a/Makefile b/Makefile index 6a2be11..8911dc1 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ RANLIB = ranlib SRC= qmckl_blas_f.f90 qmckl_dgemm.c OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o -LIB= -mkl=sequential $(shell pkg-config --libs $(STARPU) ) +LIB= -mkl=sequential $(shell pkg-config --libs $(STARPU) magma) -include irpf90.make export diff --git a/qmckl_blas_f.f90 b/qmckl_blas_f.f90 index fb51309..cf93760 100644 --- a/qmckl_blas_f.f90 +++ b/qmckl_blas_f.f90 @@ -25,5 +25,17 @@ module qmckl_blas integer (kind=c_int64_t) :: tasks(ntasks) end subroutine qmckl_tasks_run end interface - end module qmckl_blas + +subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) & + bind(C, name='f_dgemm') + use iso_c_binding + implicit none + character, intent(in), value :: TRANSA, TRANSB + integer, intent(in), value :: M,N,K,LDA,LDB,LDC + double precision, intent(in), value :: ALPHA, BETA + double precision, intent(in) :: A(LDA,*), B(LDB,*) + double precision, intent(out) :: C(LDC,*) + call dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) +end subroutine + diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index cd22551..50ab1c2 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -2,12 +2,22 @@ #include -#include #include #include #include #include +//#define CPU_ENABLED 1 + +#ifdef CPU_ENABLED +#include +#include +#endif + +void f_dgemm(const char transa, const char transb, const int m, const int n, const int k, + const double alpha, const double* A, const int lda, const double* B, + const int ldb, const double beta, double* C, const int ldc); + struct dgemm_args { @@ -22,28 +32,38 @@ struct dgemm_args { int lda; int ldb; int ldc; - CBLAS_LAYOUT transa; - CBLAS_LAYOUT transb; + char transa; + char transb; }; -void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C); - -void dgemm_codelet(void *buffers[], void* cl_arg) +void dgemm_codelet_cpu(void *buffers[], void* cl_arg) { struct dgemm_args *args = cl_arg; double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]); double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]); double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]); - qmckl_dgemm_cl(*args, A, B, C); + + int lda = STARPU_MATRIX_GET_LD(buffers[0]); + int ldb = STARPU_MATRIX_GET_LD(buffers[1]); + int ldc = STARPU_MATRIX_GET_LD(buffers[2]); + + int m = STARPU_MATRIX_GET_NX(buffers[2]); + int n = STARPU_MATRIX_GET_NY(buffers[2]); + int k = STARPU_MATRIX_GET_NY(buffers[0]); + + f_dgemm(args->transa, args->transb, + m, n, k, args->alpha, + A, lda, B, ldb, args->beta, C, ldc); + free(args); } struct starpu_codelet dgemm_cl = { .where = STARPU_CPU, - .cpu_funcs = { dgemm_codelet }, - .cpu_funcs_name = { "dgemm_codelet" }, + .cpu_funcs = { dgemm_codelet_cpu }, + .cpu_funcs_name = { "dgemm_codelet_cpu" }, .nbuffers = 3, .max_parallelism = 1, .modes = {STARPU_R, STARPU_R, STARPU_RW}, @@ -52,13 +72,6 @@ struct starpu_codelet dgemm_cl = #include -void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C) { - cblas_dgemm(CblasColMajor, args.transa, args.transb, - args.m, args.n, args.k, args.alpha, - A, args.lda, B, args.ldb, - args.beta, C, args.ldc); -} - static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, int m, int n, int k, double alpha, @@ -70,6 +83,25 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); assert (args != NULL); + int dima = (transa == 'T' || transa == 't') ? m : k; + int dimb = (transb == 'T' || transb == 't') ? k : n; +/* + double* A2; + double* B2; + double* C2; + + starpu_malloc_flags((void **)&A2, + lda*dima*sizeof(double), + STARPU_MALLOC_PINNED); + + starpu_malloc_flags((void **)&B2, + lda*dima*sizeof(double), + STARPU_MALLOC_PINNED); + + starpu_malloc_flags((void **)&C2, + lda*dima*sizeof(double), + STARPU_MALLOC_PINNED); +*/ args->alpha = alpha; args->beta = beta ; args->A = A; @@ -81,18 +113,8 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, args->lda = lda; args->ldb = ldb; args->ldc = ldc; - - if (transa == 'T' || transa == 't') { - args->transa = CblasTrans; - } else { - args->transa = CblasNoTrans; - } - - if (transa == 'T' || transa == 't') { - args->transb = CblasTrans; - } else { - args->transb = CblasNoTrans; - } + args->transa = transa; + args->transb = transb; return args; } @@ -177,7 +199,6 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) int rc = starpu_init(NULL); assert (rc == 0); - starpu_data_handle_t matrix_handle[ngemms][3]; for (int i=0 ; i