diff --git a/Makefile b/Makefile index d65fa91..103e8aa 100644 --- a/Makefile +++ b/Makefile @@ -7,18 +7,22 @@ NINJA = ninja ARCHIVE = ar crs RANLIB = ranlib -SRC= IRPF90_temp/qmckl_blas_f.f90 IRPF90_temp/qmckl_dgemm.c +SRC= qmckl_blas_f.f90 qmckl_dgemm.c OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o LIB= -mkl=sequential -lgomp -include irpf90.make export -#irpf90.make: IRPF90_temp/qmckl_blas_f.o +irpf90.make: IRPF90_temp/qmckl_blas_f.o irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile $(IRPF90) +IRPF90_temp/%.f90: %.f90 IRPF90_temp/%.c: %.c +IRPF90_temp/%.o: %.f90 + $(FC) $(FCFLAGS) -g -c $< -o $@ + IRPF90_temp/%.o: %.c $(CC) -g -c $< -o $@ diff --git a/el_nuc_el_blas.irp.f b/el_nuc_el_blas.irp.f index c369ba2..88da0e4 100644 --- a/el_nuc_el_blas.irp.f +++ b/el_nuc_el_blas.irp.f @@ -10,26 +10,32 @@ ! dtmp_c: ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} END_DOC - integer :: k + integer :: k, icount + integer*8 :: gemms(2*ncord) + icount = 0 ! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl} do k=0,ncord-1 + icount += 1 call qmckl_dgemm('N','N', nelec, nnuc*(ncord+1), nelec, 1.d0, & rescale_een_e(1,1,k), size(rescale_een_e,1), & rescale_een_n(1,1,0), size(rescale_een_n,1), 0.d0, & - tmp_c(1,1,0,k), size(tmp_c,1)) + tmp_c(1,1,0,k), size(tmp_c,1), gemms(icount)) enddo ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} do k=0,ncord-1 + icount += 1 call qmckl_dgemm('N','N', 4*nelec_8, nnuc*(ncord+1), nelec, 1.d0, & rescale_een_e_deriv_e(1,1,1,k), & size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), & rescale_een_n(1,1,0), & size(rescale_een_n,1), 0.d0, & - dtmp_c(1,1,1,0,k), size(dtmp_c,1)*size(dtmp_c,2)) + dtmp_c(1,1,1,0,k), size(dtmp_c,1)*size(dtmp_c,2), & + gemms(icount)) enddo + call qmckl_tasks_run(gemms, icount) END_PROVIDER diff --git a/qmckl_blas_f.f90 b/qmckl_blas_f.f90 index 621bc89..6cc3ca2 100644 --- a/qmckl_blas_f.f90 +++ b/qmckl_blas_f.f90 @@ -5,14 +5,24 @@ module qmckl_blas interface subroutine qmckl_dgemm(transa, transb, m, n, k, & - alpha, A, lda, B, ldb, beta, C, ldc) bind(C) + alpha, A, lda, B, ldb, beta, C, ldc, res) bind(C) use :: iso_c_binding implicit none character(kind=c_char ), value :: transa, transb integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc real (kind=c_double), value :: alpha, beta real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*) + integer (kind=c_int64_t) :: res end subroutine qmckl_dgemm end interface + interface + subroutine qmckl_tasks_run(gemms, ngemms) bind(C) + use :: iso_c_binding + implicit none + integer (kind=c_int32_t), value :: ngemms + integer (kind=c_int64_t) :: gemms(ngemms) + end subroutine qmckl_tasks_run + end interface + end module qmckl_blas diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index 6c7bea3..d372d22 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -1,6 +1,11 @@ /* Generated from qmckl_dgemm.org */ #include +#include +#include +#include + + struct dgemm_args { double alpha; @@ -19,7 +24,7 @@ struct dgemm_args { }; -#define MIN_SIZE 512 +#define MIN_SIZE 512 #include static void qmckl_dgemm_rec(struct dgemm_args args) { @@ -29,6 +34,7 @@ static void qmckl_dgemm_rec(struct dgemm_args args) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { #pragma omp task { + printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k); cblas_dgemm(CblasColMajor, args.transa, args.transb, args.m, args.n, args.k, args.alpha, args.A, args.lda, args.B, args.ldb, @@ -90,40 +96,50 @@ void qmckl_dgemm(char transa, char transb, double* A, int lda, double* B, int ldb, double beta, - double* C, int ldc) + double* C, int ldc, + int64_t* result) { - struct dgemm_args args; + struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); + assert (args != NULL); + *result = (int64_t) args; - args.alpha = alpha; - args.beta = beta ; - args.A = A; - args.B = B; - args.C = C; - args.m = m; - args.n = n; - args.k = k; - args.lda = lda; - args.ldb = ldb; - args.ldc = ldc; + args->alpha = alpha; + args->beta = beta ; + args->A = A; + args->B = B; + args->C = C; + args->m = m; + args->n = n; + args->k = k; + args->lda = lda; + args->ldb = ldb; + args->ldc = ldc; if (transa == 'T' || transa == 't') { - args.transa = CblasTrans; + args->transa = CblasTrans; } else { - args.transa = CblasNoTrans; + args->transa = CblasNoTrans; } - CBLAS_LAYOUT tb; if (transa == 'T' || transa == 't') { - args.transb = CblasTrans; + args->transb = CblasTrans; } else { - args.transb = CblasNoTrans; + args->transb = CblasNoTrans; } +} + + +void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) +{ #pragma omp parallel { #pragma omp single { - qmckl_dgemm_rec(args); + for (int i=0 ; i> */ + +#include +#include +#include +#include + +<> + +<> + +<> + +<> + +<> + + #+END_SRC + * Fortran interface #+BEGIN_SRC f90 :noweb yes :tangle qmckl_blas_f.f90 @@ -15,25 +35,37 @@ module qmckl_blas interface subroutine qmckl_dgemm(transa, transb, m, n, k, & - alpha, A, lda, B, ldb, beta, C, ldc) bind(C) + alpha, A, lda, B, ldb, beta, C, ldc, res) bind(C) use :: iso_c_binding implicit none character(kind=c_char ), value :: transa, transb integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc real (kind=c_double), value :: alpha, beta real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*) + integer (kind=c_int64_t) :: res end subroutine qmckl_dgemm end interface + interface + subroutine qmckl_tasks_run(gemms, ngemms) bind(C) + use :: iso_c_binding + implicit none + integer (kind=c_int32_t), value :: ngemms + integer (kind=c_int64_t) :: gemms(ngemms) + end subroutine qmckl_tasks_run + end interface + end module qmckl_blas #+END_SRC * TODO C code - To avoid passing too many arguments to recursive subroutines, we put - all the arguments in a struct. - #+NAME: dgemm_args - #+BEGIN_SRC c + + The main function packs the arguments in the struct and returns + the struct as a result. + + #+NAME: dgemm_args + #+BEGIN_SRC c struct dgemm_args { double alpha; double beta; @@ -50,64 +82,76 @@ struct dgemm_args { CBLAS_LAYOUT transb; }; - #+END_SRC + #+END_SRC - The driver routine packs the arguments in the struct and calls the - recursive routine. - - #+NAME: dgemm - #+BEGIN_SRC c + #+NAME: dgemm + #+BEGIN_SRC c void qmckl_dgemm(char transa, char transb, int m, int n, int k, double alpha, double* A, int lda, double* B, int ldb, double beta, - double* C, int ldc) + double* C, int ldc, + int64_t* result) { - struct dgemm_args args; + struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); + assert (args != NULL); + *result = (int64_t) args; - args.alpha = alpha; - args.beta = beta ; - args.A = A; - args.B = B; - args.C = C; - args.m = m; - args.n = n; - args.k = k; - args.lda = lda; - args.ldb = ldb; - args.ldc = ldc; + args->alpha = alpha; + args->beta = beta ; + args->A = A; + args->B = B; + args->C = C; + args->m = m; + args->n = n; + args->k = k; + args->lda = lda; + args->ldb = ldb; + args->ldc = ldc; if (transa == 'T' || transa == 't') { - args.transa = CblasTrans; + args->transa = CblasTrans; } else { - args.transa = CblasNoTrans; + args->transa = CblasNoTrans; } - CBLAS_LAYOUT tb; if (transa == 'T' || transa == 't') { - args.transb = CblasTrans; + args->transb = CblasTrans; } else { - args.transb = CblasNoTrans; + args->transb = CblasNoTrans; } +} + + #+END_SRC + + To run the dgemms as tasks, pass all the dgemms to do to the + following function. It call task-based the recursive dgemm defined below. + + #+NAME: tasks_run + #+BEGIN_SRC c +void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) +{ #pragma omp parallel { #pragma omp single { - qmckl_dgemm_rec(args); + for (int i=0 ; i static void qmckl_dgemm_rec(struct dgemm_args args) { @@ -117,6 +161,7 @@ static void qmckl_dgemm_rec(struct dgemm_args args) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { #pragma omp task { + printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k); cblas_dgemm(CblasColMajor, args.transa, args.transb, args.m, args.n, args.k, args.alpha, args.A, args.lda, args.B, args.ldb, @@ -171,17 +216,6 @@ static void qmckl_dgemm_rec(struct dgemm_args args) { } } - #+END_SRC + #+END_SRC - #+BEGIN_SRC c :noweb yes :tangle qmckl_dgemm.c -/* <
> */ - -#include - -<> - -<> - -<> - #+END_SRC