diff --git a/el_nuc_el_blas.irp.f b/el_nuc_el_blas.irp.f index b757d66..4f2d91f 100644 --- a/el_nuc_el_blas.irp.f +++ b/el_nuc_el_blas.irp.f @@ -14,6 +14,21 @@ integer*8 :: tasks(100000), ntasks ntasks = 0_8 + +! type(c_ptr) :: ptr_a, ptr_b, ptr_c, ptr_d, ptr_e +! double precision, pointer :: A(:,:,:), B(:,:,:), C(:,:,:,:), D(:,:,:,:), E(:,:,:,:,:) + +! call alloc(ptr_a, int(size(rescale_een_e),8)) +! call c_f_pointer(ptr_a, A, shape(rescale_een_e)) +! A(:,:,:) = rescale_een_e(:,:,:) +! +! call alloc(ptr_b, int(size(rescale_een_n),8)) +! call c_f_pointer(ptr_b, B, shape(rescale_een_n)) +! B(:,:,:) = rescale_een_n(:,:,:) +! +! call alloc(ptr_c, int(size(tmp_c),8)) +! call c_f_pointer(ptr_c, C, shape(tmp_c)) + ! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl} do k=0,ncord-1 do l=0,ncord @@ -21,9 +36,20 @@ rescale_een_e(1,1,k), size(rescale_een_e,1), & rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, & tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks) +! call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, & +! A(:,:,k), size(rescale_een_e,1), & +! B(:,:,l), size(rescale_een_n,1), 0.d0, & +! C(:,:,l,k), size(tmp_c,1), tasks, ntasks) enddo enddo +!call alloc(ptr_d, int(size(rescale_een_e_deriv_e),8)) +!call c_f_pointer(ptr_d, D, shape(rescale_een_e_deriv_e)) +!D(:,:,:,:) = rescale_een_e_deriv_e(:,:,:,:) + +!call alloc(ptr_e, int(size(dtmp_c),8)) +!call c_f_pointer(ptr_e, E, shape(dtmp_c)) + ! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl} do k=0,ncord-1 do l=0,ncord @@ -34,11 +60,26 @@ size(rescale_een_n,1), 0.d0, & dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), & tasks, ntasks) + +! call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, & +! D(:,:,:,k), & +! size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), & +! B(:,:,l), & +! size(rescale_een_n,1), 0.d0, & +! E(:,:,:,l,k), size(dtmp_c,1)*size(dtmp_c,2), & +! tasks, ntasks) enddo enddo print *, ntasks, ' tasks' call qmckl_tasks_run(tasks, ntasks) +! tmp_c(:,:,:,:) = C(:,:,:,:) +! dtmp_c(:,:,:,:,:) = E(:,:,:,:,:) +! call free(ptr_a) +! call free(ptr_b) +! call free(ptr_c) +! call free(ptr_d) +! call free(ptr_e) END_PROVIDER diff --git a/qmckl_blas_f.f90 b/qmckl_blas_f.f90 index cf93760..eb05fd6 100644 --- a/qmckl_blas_f.f90 +++ b/qmckl_blas_f.f90 @@ -25,6 +25,24 @@ module qmckl_blas integer (kind=c_int64_t) :: tasks(ntasks) end subroutine qmckl_tasks_run end interface + + interface + subroutine alloc(A, sze) bind(C) + use :: iso_c_binding + implicit none + type(c_ptr) :: A + integer(c_size_t), value :: sze + end subroutine + end interface + + interface + subroutine free(A) bind(C,name='starpu_free') + use :: iso_c_binding + implicit none + type(c_ptr), value :: A + end subroutine + end interface + end module qmckl_blas subroutine f_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) & diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index 50ab1c2..0475269 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -7,11 +7,13 @@ #include #include -//#define CPU_ENABLED 1 +#define GPU_ENABLED 1 + +#ifdef GPU_ENABLED -#ifdef CPU_ENABLED #include #include + #endif void f_dgemm(const char transa, const char transb, const int m, const int n, const int k, @@ -59,15 +61,50 @@ void dgemm_codelet_cpu(void *buffers[], void* cl_arg) free(args); } +#ifdef GPU_ENABLED +void dgemm_codelet_gpu(void *buffers[], void* cl_arg) +{ + struct dgemm_args *args = cl_arg; + double* A = (double*) STARPU_MATRIX_GET_PTR(buffers[0]); + double* B = (double*) STARPU_MATRIX_GET_PTR(buffers[1]); + double* C = (double*) STARPU_MATRIX_GET_PTR(buffers[2]); + + int lda = STARPU_MATRIX_GET_LD(buffers[0]); + int ldb = STARPU_MATRIX_GET_LD(buffers[1]); + int ldc = STARPU_MATRIX_GET_LD(buffers[2]); + + int m = STARPU_MATRIX_GET_NX(buffers[2]); + int n = STARPU_MATRIX_GET_NY(buffers[2]); + int k = STARPU_MATRIX_GET_NY(buffers[0]); + + char transa = (args->transa == 'T' || args->transa == 'T')? CUBLAS_OP_T : CUBLAS_OP_N; + char transb = (args->transb == 'T' || args->transb == 'T')? CUBLAS_OP_T : CUBLAS_OP_N; + + cublasStatus_t status = cublasDgemm(starpu_cublas_get_local_handle(), + transa, transb, m, n, k, &(args->alpha), + A, lda, B, ldb, &(args->beta), C, ldc); + + if (status != CUBLAS_STATUS_SUCCESS) + STARPU_CUBLAS_REPORT_ERROR(status); + + free(args); +} +#endif + struct starpu_codelet dgemm_cl = { - .where = STARPU_CPU, +// .where = STARPU_CPU, +// .where = STARPU_CUDA, + .where = STARPU_CPU | STARPU_CUDA, .cpu_funcs = { dgemm_codelet_cpu }, .cpu_funcs_name = { "dgemm_codelet_cpu" }, +#ifdef GPU_ENABLED + .cuda_funcs = { dgemm_codelet_gpu }, + .cuda_flags = {STARPU_CUDA_ASYNC}, +#endif .nbuffers = 3, .max_parallelism = 1, .modes = {STARPU_R, STARPU_R, STARPU_RW}, - }; #include @@ -83,25 +120,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args)); assert (args != NULL); - int dima = (transa == 'T' || transa == 't') ? m : k; - int dimb = (transb == 'T' || transb == 't') ? k : n; -/* - double* A2; - double* B2; - double* C2; - starpu_malloc_flags((void **)&A2, - lda*dima*sizeof(double), - STARPU_MALLOC_PINNED); - - starpu_malloc_flags((void **)&B2, - lda*dima*sizeof(double), - STARPU_MALLOC_PINNED); - - starpu_malloc_flags((void **)&C2, - lda*dima*sizeof(double), - STARPU_MALLOC_PINNED); -*/ args->alpha = alpha; args->beta = beta ; args->A = A; @@ -119,7 +138,7 @@ static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb, } -#define MIN_SIZE 512 +#define MIN_SIZE 20480 static void qmckl_dgemm_rec(struct dgemm_args args, int64_t* tasks, int64_t* ntasks) { @@ -187,9 +206,13 @@ void qmckl_dgemm(char transa, char transb, double* C, int ldc, int64_t* tasks, int64_t* ntasks) { + struct dgemm_args* args = qmckl_dgemm_to_struct (transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + starpu_memory_pin(A, lda*k*sizeof(double)); + starpu_memory_pin(B, ldb*n*sizeof(double)); + starpu_memory_pin(C, ldc*n*sizeof(double)); qmckl_dgemm_rec(*args, tasks, ntasks); free(args); } @@ -197,6 +220,8 @@ void qmckl_dgemm(char transa, char transb, void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) { int rc = starpu_init(NULL); + starpu_cublas_init(); + assert (rc == 0); starpu_data_handle_t matrix_handle[ngemms][3]; @@ -249,4 +274,8 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms) starpu_shutdown(); } +void alloc(void** ptr, int64_t size) { + printf("size: %ld\n", size); + starpu_malloc(ptr, (size_t) size * sizeof(double)); +}