1
0
mirror of https://github.com/TREX-CoE/irpjast.git synced 2025-01-05 02:49:02 +01:00

Introduced wait_tasks

This commit is contained in:
Anthony Scemama 2021-04-23 23:35:06 +02:00
parent 47823c5bb7
commit a43ef7893e
5 changed files with 144 additions and 74 deletions

View File

@ -7,18 +7,22 @@ NINJA = ninja
ARCHIVE = ar crs ARCHIVE = ar crs
RANLIB = ranlib RANLIB = ranlib
SRC= IRPF90_temp/qmckl_blas_f.f90 IRPF90_temp/qmckl_dgemm.c SRC= qmckl_blas_f.f90 qmckl_dgemm.c
OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o
LIB= -mkl=sequential -lgomp LIB= -mkl=sequential -lgomp
-include irpf90.make -include irpf90.make
export export
#irpf90.make: IRPF90_temp/qmckl_blas_f.o irpf90.make: IRPF90_temp/qmckl_blas_f.o
irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile
$(IRPF90) $(IRPF90)
IRPF90_temp/%.f90: %.f90
IRPF90_temp/%.c: %.c IRPF90_temp/%.c: %.c
IRPF90_temp/%.o: %.f90
$(FC) $(FCFLAGS) -g -c $< -o $@
IRPF90_temp/%.o: %.c IRPF90_temp/%.o: %.c
$(CC) -g -c $< -o $@ $(CC) -g -c $< -o $@

View File

@ -10,26 +10,32 @@
! dtmp_c: ! dtmp_c:
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
END_DOC END_DOC
integer :: k integer :: k, icount
integer*8 :: gemms(2*ncord)
icount = 0
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl} ! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
do k=0,ncord-1 do k=0,ncord-1
icount += 1
call qmckl_dgemm('N','N', nelec, nnuc*(ncord+1), nelec, 1.d0, & call qmckl_dgemm('N','N', nelec, nnuc*(ncord+1), nelec, 1.d0, &
rescale_een_e(1,1,k), size(rescale_een_e,1), & rescale_een_e(1,1,k), size(rescale_een_e,1), &
rescale_een_n(1,1,0), size(rescale_een_n,1), 0.d0, & rescale_een_n(1,1,0), size(rescale_een_n,1), 0.d0, &
tmp_c(1,1,0,k), size(tmp_c,1)) tmp_c(1,1,0,k), size(tmp_c,1), gemms(icount))
enddo enddo
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
do k=0,ncord-1 do k=0,ncord-1
icount += 1
call qmckl_dgemm('N','N', 4*nelec_8, nnuc*(ncord+1), nelec, 1.d0, & call qmckl_dgemm('N','N', 4*nelec_8, nnuc*(ncord+1), nelec, 1.d0, &
rescale_een_e_deriv_e(1,1,1,k), & rescale_een_e_deriv_e(1,1,1,k), &
size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), & size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), &
rescale_een_n(1,1,0), & rescale_een_n(1,1,0), &
size(rescale_een_n,1), 0.d0, & size(rescale_een_n,1), 0.d0, &
dtmp_c(1,1,1,0,k), size(dtmp_c,1)*size(dtmp_c,2)) dtmp_c(1,1,1,0,k), size(dtmp_c,1)*size(dtmp_c,2), &
gemms(icount))
enddo enddo
call qmckl_tasks_run(gemms, icount)
END_PROVIDER END_PROVIDER

View File

@ -5,14 +5,24 @@ module qmckl_blas
interface interface
subroutine qmckl_dgemm(transa, transb, m, n, k, & subroutine qmckl_dgemm(transa, transb, m, n, k, &
alpha, A, lda, B, ldb, beta, C, ldc) bind(C) alpha, A, lda, B, ldb, beta, C, ldc, res) bind(C)
use :: iso_c_binding use :: iso_c_binding
implicit none implicit none
character(kind=c_char ), value :: transa, transb character(kind=c_char ), value :: transa, transb
integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc
real (kind=c_double), value :: alpha, beta real (kind=c_double), value :: alpha, beta
real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*) real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*)
integer (kind=c_int64_t) :: res
end subroutine qmckl_dgemm end subroutine qmckl_dgemm
end interface end interface
interface
subroutine qmckl_tasks_run(gemms, ngemms) bind(C)
use :: iso_c_binding
implicit none
integer (kind=c_int32_t), value :: ngemms
integer (kind=c_int64_t) :: gemms(ngemms)
end subroutine qmckl_tasks_run
end interface
end module qmckl_blas end module qmckl_blas

View File

@ -1,6 +1,11 @@
/* Generated from qmckl_dgemm.org */ /* Generated from qmckl_dgemm.org */
#include <cblas.h> #include <cblas.h>
#include <stdint.h>
#include <assert.h>
#include <stdlib.h>
struct dgemm_args { struct dgemm_args {
double alpha; double alpha;
@ -29,6 +34,7 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
#pragma omp task #pragma omp task
{ {
printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k);
cblas_dgemm(CblasColMajor, args.transa, args.transb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.m, args.n, args.k, args.alpha, args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb, args.A, args.lda, args.B, args.ldb,
@ -90,40 +96,50 @@ void qmckl_dgemm(char transa, char transb,
double* A, int lda, double* A, int lda,
double* B, int ldb, double* B, int ldb,
double beta, double beta,
double* C, int ldc) double* C, int ldc,
int64_t* result)
{ {
struct dgemm_args args; struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
assert (args != NULL);
*result = (int64_t) args;
args.alpha = alpha; args->alpha = alpha;
args.beta = beta ; args->beta = beta ;
args.A = A; args->A = A;
args.B = B; args->B = B;
args.C = C; args->C = C;
args.m = m; args->m = m;
args.n = n; args->n = n;
args.k = k; args->k = k;
args.lda = lda; args->lda = lda;
args.ldb = ldb; args->ldb = ldb;
args.ldc = ldc; args->ldc = ldc;
if (transa == 'T' || transa == 't') { if (transa == 'T' || transa == 't') {
args.transa = CblasTrans; args->transa = CblasTrans;
} else { } else {
args.transa = CblasNoTrans; args->transa = CblasNoTrans;
} }
CBLAS_LAYOUT tb;
if (transa == 'T' || transa == 't') { if (transa == 'T' || transa == 't') {
args.transb = CblasTrans; args->transb = CblasTrans;
} else { } else {
args.transb = CblasNoTrans; args->transb = CblasNoTrans;
} }
}
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
{
#pragma omp parallel #pragma omp parallel
{ {
#pragma omp single #pragma omp single
{ {
qmckl_dgemm_rec(args); for (int i=0 ; i<ngemms ; ++i)
{
qmckl_dgemm_rec(*(gemms[i]));
}
} }
#pragma omp taskwait #pragma omp taskwait
} }

View File

@ -5,6 +5,26 @@
Generated from qmckl_dgemm.org Generated from qmckl_dgemm.org
#+END_SRC #+END_SRC
#+BEGIN_SRC c :noweb yes :tangle qmckl_dgemm.c
/* <<header>> */
#include <cblas.h>
#include <stdint.h>
#include <assert.h>
#include <stdlib.h>
<<tasks_init>>
<<dgemm_args>>
<<dgemm_rec>>
<<dgemm>>
<<tasks_run>>
#+END_SRC
* Fortran interface * Fortran interface
#+BEGIN_SRC f90 :noweb yes :tangle qmckl_blas_f.f90 #+BEGIN_SRC f90 :noweb yes :tangle qmckl_blas_f.f90
@ -15,22 +35,34 @@ module qmckl_blas
interface interface
subroutine qmckl_dgemm(transa, transb, m, n, k, & subroutine qmckl_dgemm(transa, transb, m, n, k, &
alpha, A, lda, B, ldb, beta, C, ldc) bind(C) alpha, A, lda, B, ldb, beta, C, ldc, res) bind(C)
use :: iso_c_binding use :: iso_c_binding
implicit none implicit none
character(kind=c_char ), value :: transa, transb character(kind=c_char ), value :: transa, transb
integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc
real (kind=c_double), value :: alpha, beta real (kind=c_double), value :: alpha, beta
real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*) real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*)
integer (kind=c_int64_t) :: res
end subroutine qmckl_dgemm end subroutine qmckl_dgemm
end interface end interface
interface
subroutine qmckl_tasks_run(gemms, ngemms) bind(C)
use :: iso_c_binding
implicit none
integer (kind=c_int32_t), value :: ngemms
integer (kind=c_int64_t) :: gemms(ngemms)
end subroutine qmckl_tasks_run
end interface
end module qmckl_blas end module qmckl_blas
#+END_SRC #+END_SRC
* TODO C code * TODO C code
To avoid passing too many arguments to recursive subroutines, we put
all the arguments in a struct.
The main function packs the arguments in the struct and returns
the struct as a result.
#+NAME: dgemm_args #+NAME: dgemm_args
#+BEGIN_SRC c #+BEGIN_SRC c
@ -52,9 +84,6 @@ struct dgemm_args {
#+END_SRC #+END_SRC
The driver routine packs the arguments in the struct and calls the
recursive routine.
#+NAME: dgemm #+NAME: dgemm
#+BEGIN_SRC c #+BEGIN_SRC c
void qmckl_dgemm(char transa, char transb, void qmckl_dgemm(char transa, char transb,
@ -63,45 +92,60 @@ void qmckl_dgemm(char transa, char transb,
double* A, int lda, double* A, int lda,
double* B, int ldb, double* B, int ldb,
double beta, double beta,
double* C, int ldc) double* C, int ldc,
int64_t* result)
{ {
struct dgemm_args args; struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
assert (args != NULL);
*result = (int64_t) args;
args.alpha = alpha; args->alpha = alpha;
args.beta = beta ; args->beta = beta ;
args.A = A; args->A = A;
args.B = B; args->B = B;
args.C = C; args->C = C;
args.m = m; args->m = m;
args.n = n; args->n = n;
args.k = k; args->k = k;
args.lda = lda; args->lda = lda;
args.ldb = ldb; args->ldb = ldb;
args.ldc = ldc; args->ldc = ldc;
if (transa == 'T' || transa == 't') { if (transa == 'T' || transa == 't') {
args.transa = CblasTrans; args->transa = CblasTrans;
} else { } else {
args.transa = CblasNoTrans; args->transa = CblasNoTrans;
} }
CBLAS_LAYOUT tb;
if (transa == 'T' || transa == 't') { if (transa == 'T' || transa == 't') {
args.transb = CblasTrans; args->transb = CblasTrans;
} else { } else {
args.transb = CblasNoTrans; args->transb = CblasNoTrans;
} }
}
#+END_SRC
To run the dgemms as tasks, pass all the dgemms to do to the
following function. It call task-based the recursive dgemm defined below.
#+NAME: tasks_run
#+BEGIN_SRC c
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
{
#pragma omp parallel #pragma omp parallel
{ {
#pragma omp single #pragma omp single
{ {
qmckl_dgemm_rec(args); for (int i=0 ; i<ngemms ; ++i)
{
qmckl_dgemm_rec(*(gemms[i]));
}
} }
#pragma omp taskwait #pragma omp taskwait
} }
} }
#+END_SRC #+END_SRC
@ -117,6 +161,7 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
#pragma omp task #pragma omp task
{ {
printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k);
cblas_dgemm(CblasColMajor, args.transa, args.transb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.m, args.n, args.k, args.alpha, args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb, args.A, args.lda, args.B, args.ldb,
@ -173,15 +218,4 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
} }
#+END_SRC #+END_SRC
#+BEGIN_SRC c :noweb yes :tangle qmckl_dgemm.c
/* <<header>> */
#include <cblas.h>
<<dgemm_args>>
<<dgemm_rec>>
<<dgemm>>
#+END_SRC