1
0
mirror of https://github.com/TREX-CoE/irpjast.git synced 2025-01-10 13:08:29 +01:00

OpenMP tasks

This commit is contained in:
Anthony Scemama 2021-04-23 14:30:34 +02:00
parent b3f287b8fb
commit 47823c5bb7
3 changed files with 112 additions and 67 deletions

View File

@ -1,6 +1,7 @@
IRPF90 = irpf90/bin/irpf90 --codelet=factor_een:2 --align=4096 # -s nelec_8:504 -s nnuc:100 -s ncord:5 #-a -d IRPF90 = irpf90/bin/irpf90 --codelet=factor_een:2 --align=4096 # -s nelec_8:504 -s nnuc:100 -s ncord:5 #-a -d
#FC = ifort -xCORE-AVX512 -g -mkl=sequential -qopt-zmm-usage=high #FC = ifort -xCORE-AVX512 -g -mkl=sequential -qopt-zmm-usage=high
FC = ifort -xCORE-AVX2 -g FC = ifort -xCORE-AVX2 -g
CC = gcc -fopenmp
FCFLAGS= -O3 -I . FCFLAGS= -O3 -I .
NINJA = ninja NINJA = ninja
ARCHIVE = ar crs ARCHIVE = ar crs
@ -8,7 +9,7 @@ RANLIB = ranlib
SRC= IRPF90_temp/qmckl_blas_f.f90 IRPF90_temp/qmckl_dgemm.c SRC= IRPF90_temp/qmckl_blas_f.f90 IRPF90_temp/qmckl_dgemm.c
OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o OBJ= IRPF90_temp/qmckl_blas_f.o IRPF90_temp/qmckl_dgemm.o
LIB= -mkl=sequential LIB= -mkl=sequential -lgomp
-include irpf90.make -include irpf90.make
export export

View File

@ -27,10 +27,13 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
// printf("%5d %5d\n", args.m, args.n); // printf("%5d %5d\n", args.m, args.n);
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
cblas_dgemm(CblasColMajor, args.transa, args.transb, #pragma omp task
args.m, args.n, args.k, args.alpha, {
args.A, args.lda, args.B, args.ldb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.beta, args.C, args.ldc); args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb,
args.beta, args.C, args.ldc);
}
} else { } else {
int m1 = args.m / 2; int m1 = args.m / 2;
@ -38,33 +41,45 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
int n1 = args.n / 2; int n1 = args.n / 2;
int n2 = args.n - n1; int n2 = args.n - n1;
struct dgemm_args args_1 = args; #pragma omp task
args_1.m = m1; {
args_1.n = n1; struct dgemm_args args_1 = args;
qmckl_dgemm_rec(args_1); args_1.m = m1;
args_1.n = n1;
qmckl_dgemm_rec(args_1);
}
// TODO: assuming 'N', 'N' here #pragma omp task
struct dgemm_args args_2 = args; {
args_2.B = args.B + args.ldb*n1; // TODO: assuming 'N', 'N' here
args_2.C = args.C + args.ldc*n1; struct dgemm_args args_2 = args;
args_2.m = m1; args_2.B = args.B + args.ldb*n1;
args_2.n = n2; args_2.C = args.C + args.ldc*n1;
qmckl_dgemm_rec(args_2); args_2.m = m1;
args_2.n = n2;
qmckl_dgemm_rec(args_2);
}
struct dgemm_args args_3 = args; #pragma omp task
args_3.A = args.A + m1; {
args_3.C = args.C + m1; struct dgemm_args args_3 = args;
args_3.m = m2; args_3.A = args.A + m1;
args_3.n = n1; args_3.C = args.C + m1;
qmckl_dgemm_rec(args_3); args_3.m = m2;
args_3.n = n1;
qmckl_dgemm_rec(args_3);
}
struct dgemm_args args_4 = args; #pragma omp task
args_4.A = args.A + m1; {
args_4.B = args.B + args.ldb*n1; struct dgemm_args args_4 = args;
args_4.C = args.C + m1 + args.ldc*n1; args_4.A = args.A + m1;
args_4.m = m2; args_4.B = args.B + args.ldb*n1;
args_4.n = n2; args_4.C = args.C + m1 + args.ldc*n1;
qmckl_dgemm_rec(args_4); args_4.m = m2;
args_4.n = n2;
qmckl_dgemm_rec(args_4);
}
} }
} }
@ -104,5 +119,12 @@ void qmckl_dgemm(char transa, char transb,
args.transb = CblasNoTrans; args.transb = CblasNoTrans;
} }
qmckl_dgemm_rec(args); #pragma omp parallel
{
#pragma omp single
{
qmckl_dgemm_rec(args);
}
#pragma omp taskwait
}
} }

View File

@ -92,7 +92,14 @@ void qmckl_dgemm(char transa, char transb,
args.transb = CblasNoTrans; args.transb = CblasNoTrans;
} }
qmckl_dgemm_rec(args); #pragma omp parallel
{
#pragma omp single
{
qmckl_dgemm_rec(args);
}
#pragma omp taskwait
}
} }
#+END_SRC #+END_SRC
@ -108,10 +115,13 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
// printf("%5d %5d\n", args.m, args.n); // printf("%5d %5d\n", args.m, args.n);
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
cblas_dgemm(CblasColMajor, args.transa, args.transb, #pragma omp task
args.m, args.n, args.k, args.alpha, {
args.A, args.lda, args.B, args.ldb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.beta, args.C, args.ldc); args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb,
args.beta, args.C, args.ldc);
}
} else { } else {
int m1 = args.m / 2; int m1 = args.m / 2;
@ -119,33 +129,45 @@ static void qmckl_dgemm_rec(struct dgemm_args args) {
int n1 = args.n / 2; int n1 = args.n / 2;
int n2 = args.n - n1; int n2 = args.n - n1;
struct dgemm_args args_1 = args; #pragma omp task
args_1.m = m1; {
args_1.n = n1; struct dgemm_args args_1 = args;
qmckl_dgemm_rec(args_1); args_1.m = m1;
args_1.n = n1;
qmckl_dgemm_rec(args_1);
}
// TODO: assuming 'N', 'N' here #pragma omp task
struct dgemm_args args_2 = args; {
args_2.B = args.B + args.ldb*n1; // TODO: assuming 'N', 'N' here
args_2.C = args.C + args.ldc*n1; struct dgemm_args args_2 = args;
args_2.m = m1; args_2.B = args.B + args.ldb*n1;
args_2.n = n2; args_2.C = args.C + args.ldc*n1;
qmckl_dgemm_rec(args_2); args_2.m = m1;
args_2.n = n2;
qmckl_dgemm_rec(args_2);
}
struct dgemm_args args_3 = args; #pragma omp task
args_3.A = args.A + m1; {
args_3.C = args.C + m1; struct dgemm_args args_3 = args;
args_3.m = m2; args_3.A = args.A + m1;
args_3.n = n1; args_3.C = args.C + m1;
qmckl_dgemm_rec(args_3); args_3.m = m2;
args_3.n = n1;
qmckl_dgemm_rec(args_3);
}
struct dgemm_args args_4 = args; #pragma omp task
args_4.A = args.A + m1; {
args_4.B = args.B + args.ldb*n1; struct dgemm_args args_4 = args;
args_4.C = args.C + m1 + args.ldc*n1; args_4.A = args.A + m1;
args_4.m = m2; args_4.B = args.B + args.ldb*n1;
args_4.n = n2; args_4.C = args.C + m1 + args.ldc*n1;
qmckl_dgemm_rec(args_4); args_4.m = m2;
args_4.n = n2;
qmckl_dgemm_rec(args_4);
}
} }
} }