From 5e3231e7e39fd3f07bc68e7e9d8ad7875aa8dd47 Mon Sep 17 00:00:00 2001 From: Aurelien Delval Date: Thu, 24 Mar 2022 16:35:29 +0100 Subject: [PATCH] Add selection mechanism for offload mode in Jastrow This system adds an additional field to the QMCkl context to store the offload mode currently in use for each kernel (in this commit, this has been implemented for Jastrow as an example). This will be useful to test different offloading versions that can be easily toggled on/off at compilation and at runtime. --- org/ao_grid.f90 | 114 ++++++++++++++++++ org/qmckl_jastrow.org | 269 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 382 insertions(+), 1 deletion(-) create mode 100644 org/ao_grid.f90 diff --git a/org/ao_grid.f90 b/org/ao_grid.f90 new file mode 100644 index 0000000..685313f --- /dev/null +++ b/org/ao_grid.f90 @@ -0,0 +1,114 @@ +subroutine qmckl_check_error(rc, message) + use qmckl + implicit none + integer(qmckl_exit_code), intent(in) :: rc + character(len=*) , intent(in) :: message + character(len=128) :: str_buffer + if (rc /= QMCKL_SUCCESS) then + print *, message + call qmckl_string_of_error(rc, str_buffer) + print *, str_buffer + call exit(rc) + end if +end subroutine qmckl_check_error + +program ao_grid + use qmckl + implicit none + + integer(qmckl_context) :: qmckl_ctx ! QMCkl context + integer(qmckl_exit_code) :: rc ! Exit code of QMCkl functions + + character(len=128) :: trexio_filename + character(len=128) :: str_buffer + integer :: ao_id + integer :: point_num_x + + integer(c_int64_t) :: nucl_num + double precision, allocatable :: nucl_coord(:,:) + + integer(c_int64_t) :: point_num + integer(c_int64_t) :: ao_num + integer(c_int64_t) :: ipoint, i, j, k + double precision :: x, y, z, dr(3) + double precision :: rmin(3), rmax(3) + double precision, allocatable :: points(:,:) + double precision, allocatable :: ao_vgl(:,:,:) + +if (iargc() /= 3) then + print *, 'Syntax: ao_grid ' + call exit(-1) +end if +call getarg(1, trexio_filename) +call getarg(2, str_buffer) +read(str_buffer, *) ao_id +call getarg(3, str_buffer) +read(str_buffer, *) point_num_x + +if (point_num_x < 0 .or. point_num_x > 300) then + print *, 'Error: 0 < point_num < 300' + call exit(-1) +end if + +qmckl_ctx = qmckl_context_create() +rc = qmckl_trexio_read(qmckl_ctx, trexio_filename, 1_8*len(trim(trexio_filename))) +call qmckl_check_error(rc, 'Read TREXIO') + +rc = qmckl_get_ao_basis_ao_num(qmckl_ctx, ao_num) +call qmckl_check_error(rc, 'Getting ao_num') + +if (ao_id < 0 .or. ao_id > ao_num) then + print *, 'Error: 0 < ao_id < ', ao_num + call exit(-1) +end if + +rc = qmckl_get_nucleus_num(qmckl_ctx, nucl_num) +call qmckl_check_error(rc, 'Get nucleus num') + +allocate( nucl_coord(3, nucl_num) ) +rc = qmckl_get_nucleus_coord(qmckl_ctx, 'N', nucl_coord, 3_8*nucl_num) +call qmckl_check_error(rc, 'Get nucleus coord') + +rmin(1) = minval( nucl_coord(1,:) ) - 5.d0 +rmin(2) = minval( nucl_coord(2,:) ) - 5.d0 +rmin(3) = minval( nucl_coord(3,:) ) - 5.d0 + +rmax(1) = maxval( nucl_coord(1,:) ) + 5.d0 +rmax(2) = maxval( nucl_coord(2,:) ) + 5.d0 +rmax(3) = maxval( nucl_coord(3,:) ) + 5.d0 + +dr(1:3) = (rmax(1:3) - rmin(1:3)) / dble(point_num_x-1) + +point_num = point_num_x**3 +allocate( points(point_num, 3) ) +ipoint=0 +z = rmin(3) +do k=1,point_num_x + y = rmin(2) + do j=1,point_num_x + x = rmin(1) + do i=1,point_num_x + ipoint = ipoint+1 + points(ipoint,1) = x + points(ipoint,2) = y + points(ipoint,3) = z + x = x + dr(1) + end do + y = y + dr(2) + end do + z = z + dr(3) +end do + +rc = qmckl_set_point(qmckl_ctx, 'T', points, point_num) +call qmckl_check_error(rc, 'Setting points') + +allocate( ao_vgl(ao_num, 5, point_num) ) +rc = qmckl_get_ao_basis_ao_vgl(qmckl_ctx, ao_vgl, ao_num*5_8*point_num) +call qmckl_check_error(rc, 'Setting points') + +do ipoint=1, point_num + print '(3(F16.10,X),E20.10)', points(ipoint, 1:3), ao_vgl(ao_id,1,ipoint) +end do + +deallocate( nucl_coord, points, ao_vgl ) +end program ao_grid diff --git a/org/qmckl_jastrow.org b/org/qmckl_jastrow.org index 61062af..6a2c2a2 100644 --- a/org/qmckl_jastrow.org +++ b/org/qmckl_jastrow.org @@ -327,7 +327,14 @@ kappa_inv = 1.0/kappa ** Data structure - #+begin_src c :comments org :tangle (eval h_private_type) +#+begin_src c :comments org :tangle (eval h_type) +typedef enum qmckl_jastrow_offload_type{ + OFFLOAD_NONE, + OFFLOAD_OPENMP +} qmckl_jastrow_offload_type; +#+end_src + +#+begin_src c :comments org :tangle (eval h_private_type) typedef struct qmckl_jastrow_struct{ int32_t uninitialized; int64_t aord_num; @@ -372,6 +379,7 @@ typedef struct qmckl_jastrow_struct{ uint64_t een_rescaled_n_deriv_e_date; bool provided; char * type; + qmckl_jastrow_offload_type offload_type; } qmckl_jastrow_struct; #+end_src @@ -416,6 +424,7 @@ qmckl_exit_code qmckl_get_jastrow_type_nucl_vector (qmckl_context context, int qmckl_exit_code qmckl_get_jastrow_aord_vector (qmckl_context context, double * const aord_vector, const int64_t size_max); qmckl_exit_code qmckl_get_jastrow_bord_vector (qmckl_context context, double * const bord_vector, const int64_t size_max); qmckl_exit_code qmckl_get_jastrow_cord_vector (qmckl_context context, double * const cord_vector, const int64_t size_max); +qmckl_exit_code qmckl_get_jastrow_offload_type (qmckl_context context, qmckl_jastrow_offload_type * const offload_type); #+end_src Along with these core functions, calculation of the jastrow factor @@ -713,6 +722,32 @@ qmckl_get_jastrow_cord_vector (const qmckl_context context, return QMCKL_SUCCESS; } +qmckl_exit_code qmckl_get_jastrow_offload_type (const qmckl_context context, qmckl_jastrow_offload_type* const offload_type) { + + if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) { + return (char) 0; + } + + if (offload_type == NULL) { + return qmckl_failwith( context, + QMCKL_INVALID_ARG_2, + "qmckl_get_jastrow_offload_type", + "offload_type is a null pointer"); + } + + qmckl_context_struct* const ctx = (qmckl_context_struct* const) context; + assert (ctx != NULL); + + int32_t mask = 1 << 0; + + if ( (ctx->jastrow.uninitialized & mask) != 0) { + return QMCKL_NOT_PROVIDED; + } + + *offload_type = ctx->jastrow.offload_type; + return QMCKL_SUCCESS; +} + #+end_src ** Initialization functions @@ -727,6 +762,7 @@ qmckl_exit_code qmckl_set_jastrow_type_nucl_vector (qmckl_context context, con qmckl_exit_code qmckl_set_jastrow_aord_vector (qmckl_context context, const double * aord_vector, const int64_t size_max); qmckl_exit_code qmckl_set_jastrow_bord_vector (qmckl_context context, const double * bord_vector, const int64_t size_max); qmckl_exit_code qmckl_set_jastrow_cord_vector (qmckl_context context, const double * cord_vector, const int64_t size_max); +qmckl_exit_code qmckl_set_jastrow_offload_type (qmckl_context context, const qmckl_jastrow_offload_type offload_type); #+end_src #+NAME:pre2 @@ -1063,6 +1099,14 @@ qmckl_set_jastrow_cord_vector(qmckl_context context, <> } +qmckl_exit_code +qmckl_set_jastrow_offload_type(qmckl_context context, const qmckl_jastrow_offload_type offload_type) +{ +<> + ctx->jastrow.offload_type = offload_type; + return QMCKL_SUCCESS; +} + #+end_src When the required information is completely entered, other data structures are @@ -6093,6 +6137,30 @@ qmckl_exit_code qmckl_provide_factor_een_deriv_e(qmckl_context context) ctx->jastrow.factor_een_deriv_e = factor_een_deriv_e; } + /* Choose the correct compute function (depending on offload type) */ + bool default_compute = true; + +#ifdef HAVE_OPENMP_OFFLOAD + if(ctx->jastrow.offload_type == OFFLOAD_OPENMP) { + qmckl_exit_code rc = + qmckl_compute_factor_een_deriv_e_omp_offload(context, + ctx->electron.walk_num, + ctx->electron.num, + ctx->nucleus.num, + ctx->jastrow.cord_num, + ctx->jastrow.dim_cord_vect, + ctx->jastrow.cord_vect_full, + ctx->jastrow.lkpm_combined_index, + ctx->jastrow.tmp_c, + ctx->jastrow.dtmp_c, + ctx->jastrow.een_rescaled_n, + ctx->jastrow.een_rescaled_n_deriv_e, + ctx->jastrow.factor_een_deriv_e); + default_compute = false; + } +#endif + + if(default_compute) { qmckl_exit_code rc = qmckl_compute_factor_een_deriv_e(context, ctx->electron.walk_num, @@ -6107,6 +6175,8 @@ qmckl_exit_code qmckl_provide_factor_een_deriv_e(qmckl_context context) ctx->jastrow.een_rescaled_n, ctx->jastrow.een_rescaled_n_deriv_e, ctx->jastrow.factor_een_deriv_e); + } + if (rc != QMCKL_SUCCESS) { return rc; } @@ -6507,6 +6577,203 @@ end function qmckl_compute_factor_een_deriv_e_f end function qmckl_compute_factor_een_deriv_e #+end_src +*** Compute (OpenMP offload)... + :PROPERTIES: + :Name: qmckl_compute_factor_een_deriv_e + :CRetType: qmckl_exit_code + :FRetType: qmckl_exit_code + :END: + + #+NAME: qmckl_factor_een_deriv_e_omp_offload_args + | Variable | Type | In/Out | Description | + |--------------------------+---------------------------------------------------------------------+--------+------------------------------------------------| + | ~context~ | ~qmckl_context~ | in | Global state | + | ~walk_num~ | ~int64_t~ | in | Number of walkers | + | ~elec_num~ | ~int64_t~ | in | Number of electrons | + | ~nucl_num~ | ~int64_t~ | in | Number of nucleii | + | ~cord_num~ | ~int64_t~ | in | order of polynomials | + | ~dim_cord_vect~ | ~int64_t~ | in | dimension of full coefficient vector | + | ~cord_vect_full~ | ~double[dim_cord_vect][nucl_num]~ | in | full coefficient vector | + | ~lkpm_combined_index~ | ~int64_t[4][dim_cord_vect]~ | in | combined indices | + | ~tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Temporary intermediate tensor | + | ~dtmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][4][elec_num]~ | in | vector of non-zero coefficients | + | ~een_rescaled_n~ | ~double[walk_num][0:cord_num][nucl_num][elec_num]~ | in | Electron-nucleus rescaled factor | + | ~een_rescaled_n_deriv_e~ | ~double[walk_num][0:cord_num][nucl_num][4][elec_num]~ | in | Derivative of Electron-nucleus rescaled factor | + | ~factor_een_deriv_e~ | ~double[walk_num][4][elec_num]~ | out | Derivative of Electron-nucleus jastrow | + + + #+begin_src f90 :comments org :tangle (eval f) :noweb yes +#ifdef HAVE_OPENMP_OFFLOAD +integer function qmckl_compute_factor_een_deriv_e_omp_offload_f(context, walk_num, elec_num, nucl_num, cord_num, dim_cord_vect, & + cord_vect_full, lkpm_combined_index, & + tmp_c, dtmp_c, een_rescaled_n, een_rescaled_n_deriv_e, factor_een_deriv_e) & + result(info) + use qmckl + implicit none + integer(qmckl_context), intent(in) :: context + integer*8 , intent(in) :: walk_num, elec_num, cord_num, nucl_num, dim_cord_vect + integer*8 , intent(in) :: lkpm_combined_index(dim_cord_vect,4) + double precision , intent(in) :: cord_vect_full(nucl_num, dim_cord_vect) + double precision , intent(in) :: tmp_c(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num) + double precision , intent(in) :: dtmp_c(elec_num, 4, nucl_num,0:cord_num, 0:cord_num-1, walk_num) + double precision , intent(in) :: een_rescaled_n(elec_num, nucl_num, 0:cord_num, walk_num) + double precision , intent(in) :: een_rescaled_n_deriv_e(elec_num, 4, nucl_num, 0:cord_num, walk_num) + double precision , intent(out) :: factor_een_deriv_e(elec_num,4,walk_num) + + integer*8 :: i, a, j, l, k, p, m, n, nw, ii + double precision :: accu, accu2, cn + + info = QMCKL_SUCCESS + + if (context == QMCKL_NULL_CONTEXT) then + info = QMCKL_INVALID_CONTEXT + return + endif + + if (walk_num <= 0) then + info = QMCKL_INVALID_ARG_2 + return + endif + + if (elec_num <= 0) then + info = QMCKL_INVALID_ARG_3 + return + endif + + if (nucl_num <= 0) then + info = QMCKL_INVALID_ARG_4 + return + endif + + if (cord_num <= 0) then + info = QMCKL_INVALID_ARG_5 + return + endif + + factor_een_deriv_e = 0.0d0 + + do nw =1, walk_num + do n = 1, dim_cord_vect + l = lkpm_combined_index(n, 1) + k = lkpm_combined_index(n, 2) + p = lkpm_combined_index(n, 3) + m = lkpm_combined_index(n, 4) + + do a = 1, nucl_num + cn = cord_vect_full(a, n) + if(cn == 0.d0) cycle + + do ii = 1, 4 + do j = 1, elec_num + factor_een_deriv_e(j,ii,nw) = factor_een_deriv_e(j,ii,nw) + (& + tmp_c(j,a,m,k,nw) * een_rescaled_n_deriv_e(j,ii,a,m+l,nw) + & + (dtmp_c(j,ii,a,m,k,nw)) * een_rescaled_n(j,a,m+l,nw) + & + (dtmp_c(j,ii,a,m+l,k,nw)) * een_rescaled_n(j,a,m ,nw) + & + tmp_c(j,a,m+l,k,nw) * een_rescaled_n_deriv_e(j,ii,a,m,nw) & + ) * cn + end do + end do + + cn = cn + cn + do j = 1, elec_num + factor_een_deriv_e(j,4,nw) = factor_een_deriv_e(j,4,nw) + (& + (dtmp_c(j,1,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,1,a,m+l,nw) + & + (dtmp_c(j,2,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,2,a,m+l,nw) + & + (dtmp_c(j,3,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,3,a,m+l,nw) + & + (dtmp_c(j,1,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,1,a,m ,nw) + & + (dtmp_c(j,2,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,2,a,m ,nw) + & + (dtmp_c(j,3,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,3,a,m ,nw) & + ) * cn + end do + end do + end do + end do + +end function qmckl_compute_factor_een_deriv_e_omp_offload_f +#endif + #+end_src + + #+CALL: generate_c_header(table=qmckl_factor_een_deriv_e_omp_offload_args,rettyp=get_value("CRetType"),fname=get_value("Name")) + + #+RESULTS: + #+begin_src c :tangle (eval h_func) :comments org +#ifdef HAVE_OPENMP_OFFLOAD + qmckl_exit_code qmckl_compute_factor_een_deriv_e_omp_offload ( + const qmckl_context context, + const int64_t walk_num, + const int64_t elec_num, + const int64_t nucl_num, + const int64_t cord_num, + const int64_t dim_cord_vect, + const double* cord_vect_full, + const int64_t* lkpm_combined_index, + const double* tmp_c, + const double* dtmp_c, + const double* een_rescaled_n, + const double* een_rescaled_n_deriv_e, + double* const factor_een_deriv_e ); +#endif + #+end_src + + + #+CALL: generate_c_interface(table=qmckl_factor_een_deriv_e_omp_offload_args,rettyp=get_value("CRetType"),fname=get_value("Name")) + + #+RESULTS: + #+begin_src f90 :tangle (eval f) :comments org :exports none +#ifdef HAVE_OPENMP_OFFLOAD + integer(c_int32_t) function qmckl_compute_factor_een_deriv_e_omp_offload & + (context, & + walk_num, & + elec_num, & + nucl_num, & + cord_num, & + dim_cord_vect, & + cord_vect_full, & + lkpm_combined_index, & + tmp_c, & + dtmp_c, & + een_rescaled_n, & + een_rescaled_n_deriv_e, & + factor_een_deriv_e) & + bind(C) result(info) + + use, intrinsic :: iso_c_binding + implicit none + + integer (c_int64_t) , intent(in) , value :: context + integer (c_int64_t) , intent(in) , value :: walk_num + integer (c_int64_t) , intent(in) , value :: elec_num + integer (c_int64_t) , intent(in) , value :: nucl_num + integer (c_int64_t) , intent(in) , value :: cord_num + integer (c_int64_t) , intent(in) , value :: dim_cord_vect + real (c_double ) , intent(in) :: cord_vect_full(nucl_num,dim_cord_vect) + integer (c_int64_t) , intent(in) :: lkpm_combined_index(dim_cord_vect,4) + real (c_double ) , intent(in) :: tmp_c(elec_num,nucl_num,0:cord_num,0:cord_num-1,walk_num) + real (c_double ) , intent(in) :: dtmp_c(elec_num,4,nucl_num,0:cord_num,0:cord_num-1,walk_num) + real (c_double ) , intent(in) :: een_rescaled_n(elec_num,nucl_num,0:cord_num,walk_num) + real (c_double ) , intent(in) :: een_rescaled_n_deriv_e(elec_num,4,nucl_num,0:cord_num,walk_num) + real (c_double ) , intent(out) :: factor_een_deriv_e(elec_num,4,walk_num) + + integer(c_int32_t), external :: qmckl_compute_factor_een_deriv_e_omp_offload_f + info = qmckl_compute_factor_een_deriv_e_omp_offload_f & + (context, & + walk_num, & + elec_num, & + nucl_num, & + cord_num, & + dim_cord_vect, & + cord_vect_full, & + lkpm_combined_index, & + tmp_c, & + dtmp_c, & + een_rescaled_n, & + een_rescaled_n_deriv_e, & + factor_een_deriv_e) + + end function qmckl_compute_factor_een_deriv_e_omp_offload +#endif + #+end_src + *** Test #+begin_src python :results output :exports none :noweb yes import numpy as np