diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index 90d29c7..c8e24fb 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -97,6 +97,65 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, } +#define MIN_SIZE 512 +static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks) +{ + + + if ( args.m * args.n <= MIN_SIZE*MIN_SIZE) { + +// printf("%5d %5d\n", args.m, args.n); + struct dgemm_args* args_new = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); + memcpy(args_new, &args, sizeof(args)); + tasks[*ntasks] = (int64_t) args_new; + *ntasks += 1L; + + } else { + + int m1 = args.m / 2; + int m2 = args.m - m1; + int n1 = args.n / 2; + int n2 = args.n - n1; + + { + struct dgemm_args args_1 = args; + args_1.m = m1; + args_1.n = n1; + qmckl_dgemm_rec(args_1, tasks, ntasks); + } + + { + // TODO: assuming 'N', 'N' here + struct dgemm_args args_2 = args; + args_2.B = args.B + args.ldb*n1; + args_2.C = args.C + args.ldc*n1; + args_2.m = m1; + args_2.n = n2; + qmckl_dgemm_rec(args_2, tasks, ntasks); + } + + { + struct dgemm_args args_3 = args; + args_3.A = args.A + m1; + args_3.C = args.C + m1; + args_3.m = m2; + args_3.n = n1; + qmckl_dgemm_rec(args_3, tasks, ntasks); + } + + { + struct dgemm_args args_4 = args; + args_4.A = args.A + m1; + args_4.B = args.B + args.ldb*n1; + args_4.C = args.C + m1 + args.ldc*n1; + args_4.m = m2; + args_4.n = n2; + qmckl_dgemm_rec(args_4, tasks, ntasks); + } + } + +} + void qmckl_dgemm(char transa, char transb, int m, int n, int k, double alpha, @@ -106,9 +165,11 @@ void qmckl_dgemm(char transa, char transb, double* C, int ldc, int64_t* tasks, int64_t* ntasks) { - tasks[*ntasks] = (int64_t) qmckl_dgemm_to_struct (transa, transb, + struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); - *ntasks += 1L; + + qmckl_dgemm_rec(*args, tasks, ntasks); + free(args); } void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) @@ -166,63 +227,3 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) } -#define MIN_SIZE 512 -static void qmckl_dgemm_rec(struct dgemm_args args) { - -// printf("%5d %5d\n", args.m, args.n); - - if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { - #pragma omp task - { - printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k); - cblas_dgemm(CblasColMajor, args.transa, args.transb, - args.m, args.n, args.k, args.alpha, - args.A, args.lda, args.B, args.ldb, - args.beta, args.C, args.ldc); - } - } else { - - int m1 = args.m / 2; - int m2 = args.m - m1; - int n1 = args.n / 2; - int n2 = args.n - n1; - - { - struct dgemm_args args_1 = args; - args_1.m = m1; - args_1.n = n1; - qmckl_dgemm_rec(args_1); - } - - { - // TODO: assuming 'N', 'N' here - struct dgemm_args args_2 = args; - args_2.B = args.B + args.ldb*n1; - args_2.C = args.C + args.ldc*n1; - args_2.m = m1; - args_2.n = n2; - qmckl_dgemm_rec(args_2); - } - - { - struct dgemm_args args_3 = args; - args_3.A = args.A + m1; - args_3.C = args.C + m1; - args_3.m = m2; - args_3.n = n1; - qmckl_dgemm_rec(args_3); - } - - { - struct dgemm_args args_4 = args; - args_4.A = args.A + m1; - args_4.B = args.B + args.ldb*n1; - args_4.C = args.C + m1 + args.ldc*n1; - args_4.m = m2; - args_4.n = n2; - qmckl_dgemm_rec(args_4); - } - } - -} -