mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2024-12-25 05:43:56 +01:00
Recursive GEMM in StarPU
This commit is contained in:
parent
412dba6b92
commit
569893eb25
125
qmckl_dgemm.c
125
qmckl_dgemm.c
@ -97,6 +97,65 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define MIN_SIZE 512
|
||||||
|
static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks)
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
if ( args.m * args.n <= MIN_SIZE*MIN_SIZE) {
|
||||||
|
|
||||||
|
// printf("%5d %5d\n", args.m, args.n);
|
||||||
|
struct dgemm_args* args_new = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
||||||
|
memcpy(args_new, &args, sizeof(args));
|
||||||
|
tasks[*ntasks] = (int64_t) args_new;
|
||||||
|
*ntasks += 1L;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
int m1 = args.m / 2;
|
||||||
|
int m2 = args.m - m1;
|
||||||
|
int n1 = args.n / 2;
|
||||||
|
int n2 = args.n - n1;
|
||||||
|
|
||||||
|
{
|
||||||
|
struct dgemm_args args_1 = args;
|
||||||
|
args_1.m = m1;
|
||||||
|
args_1.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_1, tasks, ntasks);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// TODO: assuming 'N', 'N' here
|
||||||
|
struct dgemm_args args_2 = args;
|
||||||
|
args_2.B = args.B + args.ldb*n1;
|
||||||
|
args_2.C = args.C + args.ldc*n1;
|
||||||
|
args_2.m = m1;
|
||||||
|
args_2.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_2, tasks, ntasks);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
struct dgemm_args args_3 = args;
|
||||||
|
args_3.A = args.A + m1;
|
||||||
|
args_3.C = args.C + m1;
|
||||||
|
args_3.m = m2;
|
||||||
|
args_3.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_3, tasks, ntasks);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
struct dgemm_args args_4 = args;
|
||||||
|
args_4.A = args.A + m1;
|
||||||
|
args_4.B = args.B + args.ldb*n1;
|
||||||
|
args_4.C = args.C + m1 + args.ldc*n1;
|
||||||
|
args_4.m = m2;
|
||||||
|
args_4.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_4, tasks, ntasks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
void qmckl_dgemm(char transa, char transb,
|
void qmckl_dgemm(char transa, char transb,
|
||||||
int m, int n, int k,
|
int m, int n, int k,
|
||||||
double alpha,
|
double alpha,
|
||||||
@ -106,9 +165,11 @@ void qmckl_dgemm(char transa, char transb,
|
|||||||
double* C, int ldc,
|
double* C, int ldc,
|
||||||
int64_t* tasks, int64_t* ntasks)
|
int64_t* tasks, int64_t* ntasks)
|
||||||
{
|
{
|
||||||
tasks[*ntasks] = (int64_t) qmckl_dgemm_to_struct (transa, transb,
|
struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb,
|
||||||
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
|
m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||||
*ntasks += 1L;
|
|
||||||
|
qmckl_dgemm_rec(*args, tasks, ntasks);
|
||||||
|
free(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||||
@ -166,63 +227,3 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#define MIN_SIZE 512
|
|
||||||
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
|
||||||
|
|
||||||
// printf("%5d %5d\n", args.m, args.n);
|
|
||||||
|
|
||||||
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
|
|
||||||
#pragma omp task
|
|
||||||
{
|
|
||||||
printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k);
|
|
||||||
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
|
||||||
args.m, args.n, args.k, args.alpha,
|
|
||||||
args.A, args.lda, args.B, args.ldb,
|
|
||||||
args.beta, args.C, args.ldc);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
int m1 = args.m / 2;
|
|
||||||
int m2 = args.m - m1;
|
|
||||||
int n1 = args.n / 2;
|
|
||||||
int n2 = args.n - n1;
|
|
||||||
|
|
||||||
{
|
|
||||||
struct dgemm_args args_1 = args;
|
|
||||||
args_1.m = m1;
|
|
||||||
args_1.n = n1;
|
|
||||||
qmckl_dgemm_rec(args_1);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// TODO: assuming 'N', 'N' here
|
|
||||||
struct dgemm_args args_2 = args;
|
|
||||||
args_2.B = args.B + args.ldb*n1;
|
|
||||||
args_2.C = args.C + args.ldc*n1;
|
|
||||||
args_2.m = m1;
|
|
||||||
args_2.n = n2;
|
|
||||||
qmckl_dgemm_rec(args_2);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
struct dgemm_args args_3 = args;
|
|
||||||
args_3.A = args.A + m1;
|
|
||||||
args_3.C = args.C + m1;
|
|
||||||
args_3.m = m2;
|
|
||||||
args_3.n = n1;
|
|
||||||
qmckl_dgemm_rec(args_3);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
struct dgemm_args args_4 = args;
|
|
||||||
args_4.A = args.A + m1;
|
|
||||||
args_4.B = args.B + args.ldb*n1;
|
|
||||||
args_4.C = args.C + m1 + args.ldc*n1;
|
|
||||||
args_4.m = m2;
|
|
||||||
args_4.n = n2;
|
|
||||||
qmckl_dgemm_rec(args_4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user