diff --git a/el_nuc_el_blas.irp.f b/el_nuc_el_blas.irp.f index ee124ad..b757d66 100644 --- a/el_nuc_el_blas.irp.f +++ b/el_nuc_el_blas.irp.f @@ -10,36 +10,35 @@ ! dtmp_c: ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} END_DOC - integer :: k, l, m, icount - integer*8 :: gemms(6*ncord*(ncord+1)) + integer :: k, l, m + integer*8 :: tasks(100000), ntasks - icount = 0 + ntasks = 0_8 ! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl} do k=0,ncord-1 do l=0,ncord - icount += 1 call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, & rescale_een_e(1,1,k), size(rescale_een_e,1), & rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, & - tmp_c(1,1,l,k), size(tmp_c,1), gemms(icount)) + tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks) enddo enddo ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} do k=0,ncord-1 do l=0,ncord - icount += 1 call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, & rescale_een_e_deriv_e(1,1,1,k), & size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), & rescale_een_n(1,1,l), & size(rescale_een_n,1), 0.d0, & dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), & - gemms(icount)) + tasks, ntasks) enddo enddo - call qmckl_tasks_run(gemms, icount) + print *, ntasks, ' tasks' + call qmckl_tasks_run(tasks, ntasks) END_PROVIDER diff --git a/qmckl_blas_f.f90 b/qmckl_blas_f.f90 index 6cc3ca2..fb51309 100644 --- a/qmckl_blas_f.f90 +++ b/qmckl_blas_f.f90 @@ -5,23 +5,24 @@ module qmckl_blas interface subroutine qmckl_dgemm(transa, transb, m, n, k, & - alpha, A, lda, B, ldb, beta, C, ldc, res) bind(C) + alpha, A, lda, B, ldb, beta, C, ldc, tasks, ntasks) bind(C) use :: iso_c_binding implicit none character(kind=c_char ), value :: transa, transb integer (kind=c_int ), value :: m, n, k, lda, ldb, ldc real (kind=c_double), value :: alpha, beta real (kind=c_double) :: A(lda,*), B(ldb,*), C(ldc,*) - integer (kind=c_int64_t) :: res + integer (kind=c_int64_t) :: tasks(*) + integer (kind=c_int64_t) :: ntasks end subroutine qmckl_dgemm end interface interface - subroutine qmckl_tasks_run(gemms, ngemms) bind(C) + subroutine qmckl_tasks_run(tasks, ntasks) bind(C) use :: iso_c_binding implicit none - integer (kind=c_int32_t), value :: ngemms - integer (kind=c_int64_t) :: gemms(ngemms) + integer (kind=c_int64_t), value :: ntasks + integer (kind=c_int64_t) :: tasks(ntasks) end subroutine qmckl_tasks_run end interface diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index 55dae24..90d29c7 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -59,18 +59,16 @@ void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C) { args.beta, C, args.ldc); } -void qmckl_dgemm(char transa, char transb, +static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, int m, int n, int k, double alpha, double* A, int lda, double* B, int ldb, double beta, - double* C, int ldc, - int64_t* result) + double* C, int ldc) { struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); assert (args != NULL); - *result = (int64_t) args; args->alpha = alpha; args->beta = beta ; @@ -95,14 +93,29 @@ void qmckl_dgemm(char transa, char transb, } else { args->transb = CblasNoTrans; } + return args; } +void qmckl_dgemm(char transa, char transb, + int m, int n, int k, + double alpha, + double* A, int lda, + double* B, int ldb, + double beta, + double* C, int ldc, + int64_t* tasks, int64_t* ntasks) +{ + tasks[*ntasks] = (int64_t) qmckl_dgemm_to_struct (transa, transb, + m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + *ntasks += 1L; +} void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) { - starpu_init(NULL); - + int rc = starpu_init(NULL); + assert (rc == 0); + starpu_data_handle_t matrix_handle[ngemms][3]; for (int i=0 ; ihandles[0] = matrix_handle[i][0]; task->handles[1] = matrix_handle[i][1]; task->handles[2] = matrix_handle[i][2]; - starpu_task_submit(task); + rc = starpu_task_submit(task); + assert (rc == 0); } starpu_task_wait_for_all(); @@ -150,3 +164,65 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) } starpu_shutdown(); } + + +#define MIN_SIZE 512 +static void qmckl_dgemm_rec(struct dgemm_args args) { + +// printf("%5d %5d\n", args.m, args.n); + + if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { + #pragma omp task + { + printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k); + cblas_dgemm(CblasColMajor, args.transa, args.transb, + args.m, args.n, args.k, args.alpha, + args.A, args.lda, args.B, args.ldb, + args.beta, args.C, args.ldc); + } + } else { + + int m1 = args.m / 2; + int m2 = args.m - m1; + int n1 = args.n / 2; + int n2 = args.n - n1; + + { + struct dgemm_args args_1 = args; + args_1.m = m1; + args_1.n = n1; + qmckl_dgemm_rec(args_1); + } + + { + // TODO: assuming 'N', 'N' here + struct dgemm_args args_2 = args; + args_2.B = args.B + args.ldb*n1; + args_2.C = args.C + args.ldc*n1; + args_2.m = m1; + args_2.n = n2; + qmckl_dgemm_rec(args_2); + } + + { + struct dgemm_args args_3 = args; + args_3.A = args.A + m1; + args_3.C = args.C + m1; + args_3.m = m2; + args_3.n = n1; + qmckl_dgemm_rec(args_3); + } + + { + struct dgemm_args args_4 = args; + args_4.A = args.A + m1; + args_4.B = args.B + args.ldb*n1; + args_4.C = args.C + m1 + args.ldc*n1; + args_4.m = m2; + args_4.n = n2; + qmckl_dgemm_rec(args_4); + } + } + +} +