1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-07-17 16:33:59 +02:00

Add selection mechanism for offload mode in Jastrow

This system adds an additional field to the QMCkl context to store the
offload mode currently in use for each kernel (in this commit, this has
been implemented for Jastrow as an example). This will be useful to test
different offloading versions that can be easily toggled on/off at
compilation and at runtime.
This commit is contained in:
Aurelien Delval 2022-03-24 16:35:29 +01:00
parent 79d4cf130b
commit 5e3231e7e3
2 changed files with 382 additions and 1 deletions

114
org/ao_grid.f90 Normal file
View File

@ -0,0 +1,114 @@
subroutine qmckl_check_error(rc, message)
use qmckl
implicit none
integer(qmckl_exit_code), intent(in) :: rc
character(len=*) , intent(in) :: message
character(len=128) :: str_buffer
if (rc /= QMCKL_SUCCESS) then
print *, message
call qmckl_string_of_error(rc, str_buffer)
print *, str_buffer
call exit(rc)
end if
end subroutine qmckl_check_error
program ao_grid
use qmckl
implicit none
integer(qmckl_context) :: qmckl_ctx ! QMCkl context
integer(qmckl_exit_code) :: rc ! Exit code of QMCkl functions
character(len=128) :: trexio_filename
character(len=128) :: str_buffer
integer :: ao_id
integer :: point_num_x
integer(c_int64_t) :: nucl_num
double precision, allocatable :: nucl_coord(:,:)
integer(c_int64_t) :: point_num
integer(c_int64_t) :: ao_num
integer(c_int64_t) :: ipoint, i, j, k
double precision :: x, y, z, dr(3)
double precision :: rmin(3), rmax(3)
double precision, allocatable :: points(:,:)
double precision, allocatable :: ao_vgl(:,:,:)
if (iargc() /= 3) then
print *, 'Syntax: ao_grid <trexio_file> <AO_id> <point_num>'
call exit(-1)
end if
call getarg(1, trexio_filename)
call getarg(2, str_buffer)
read(str_buffer, *) ao_id
call getarg(3, str_buffer)
read(str_buffer, *) point_num_x
if (point_num_x < 0 .or. point_num_x > 300) then
print *, 'Error: 0 < point_num < 300'
call exit(-1)
end if
qmckl_ctx = qmckl_context_create()
rc = qmckl_trexio_read(qmckl_ctx, trexio_filename, 1_8*len(trim(trexio_filename)))
call qmckl_check_error(rc, 'Read TREXIO')
rc = qmckl_get_ao_basis_ao_num(qmckl_ctx, ao_num)
call qmckl_check_error(rc, 'Getting ao_num')
if (ao_id < 0 .or. ao_id > ao_num) then
print *, 'Error: 0 < ao_id < ', ao_num
call exit(-1)
end if
rc = qmckl_get_nucleus_num(qmckl_ctx, nucl_num)
call qmckl_check_error(rc, 'Get nucleus num')
allocate( nucl_coord(3, nucl_num) )
rc = qmckl_get_nucleus_coord(qmckl_ctx, 'N', nucl_coord, 3_8*nucl_num)
call qmckl_check_error(rc, 'Get nucleus coord')
rmin(1) = minval( nucl_coord(1,:) ) - 5.d0
rmin(2) = minval( nucl_coord(2,:) ) - 5.d0
rmin(3) = minval( nucl_coord(3,:) ) - 5.d0
rmax(1) = maxval( nucl_coord(1,:) ) + 5.d0
rmax(2) = maxval( nucl_coord(2,:) ) + 5.d0
rmax(3) = maxval( nucl_coord(3,:) ) + 5.d0
dr(1:3) = (rmax(1:3) - rmin(1:3)) / dble(point_num_x-1)
point_num = point_num_x**3
allocate( points(point_num, 3) )
ipoint=0
z = rmin(3)
do k=1,point_num_x
y = rmin(2)
do j=1,point_num_x
x = rmin(1)
do i=1,point_num_x
ipoint = ipoint+1
points(ipoint,1) = x
points(ipoint,2) = y
points(ipoint,3) = z
x = x + dr(1)
end do
y = y + dr(2)
end do
z = z + dr(3)
end do
rc = qmckl_set_point(qmckl_ctx, 'T', points, point_num)
call qmckl_check_error(rc, 'Setting points')
allocate( ao_vgl(ao_num, 5, point_num) )
rc = qmckl_get_ao_basis_ao_vgl(qmckl_ctx, ao_vgl, ao_num*5_8*point_num)
call qmckl_check_error(rc, 'Setting points')
do ipoint=1, point_num
print '(3(F16.10,X),E20.10)', points(ipoint, 1:3), ao_vgl(ao_id,1,ipoint)
end do
deallocate( nucl_coord, points, ao_vgl )
end program ao_grid

View File

@ -327,7 +327,14 @@ kappa_inv = 1.0/kappa
** Data structure ** Data structure
#+begin_src c :comments org :tangle (eval h_private_type) #+begin_src c :comments org :tangle (eval h_type)
typedef enum qmckl_jastrow_offload_type{
OFFLOAD_NONE,
OFFLOAD_OPENMP
} qmckl_jastrow_offload_type;
#+end_src
#+begin_src c :comments org :tangle (eval h_private_type)
typedef struct qmckl_jastrow_struct{ typedef struct qmckl_jastrow_struct{
int32_t uninitialized; int32_t uninitialized;
int64_t aord_num; int64_t aord_num;
@ -372,6 +379,7 @@ typedef struct qmckl_jastrow_struct{
uint64_t een_rescaled_n_deriv_e_date; uint64_t een_rescaled_n_deriv_e_date;
bool provided; bool provided;
char * type; char * type;
qmckl_jastrow_offload_type offload_type;
} qmckl_jastrow_struct; } qmckl_jastrow_struct;
#+end_src #+end_src
@ -416,6 +424,7 @@ qmckl_exit_code qmckl_get_jastrow_type_nucl_vector (qmckl_context context, int
qmckl_exit_code qmckl_get_jastrow_aord_vector (qmckl_context context, double * const aord_vector, const int64_t size_max); qmckl_exit_code qmckl_get_jastrow_aord_vector (qmckl_context context, double * const aord_vector, const int64_t size_max);
qmckl_exit_code qmckl_get_jastrow_bord_vector (qmckl_context context, double * const bord_vector, const int64_t size_max); qmckl_exit_code qmckl_get_jastrow_bord_vector (qmckl_context context, double * const bord_vector, const int64_t size_max);
qmckl_exit_code qmckl_get_jastrow_cord_vector (qmckl_context context, double * const cord_vector, const int64_t size_max); qmckl_exit_code qmckl_get_jastrow_cord_vector (qmckl_context context, double * const cord_vector, const int64_t size_max);
qmckl_exit_code qmckl_get_jastrow_offload_type (qmckl_context context, qmckl_jastrow_offload_type * const offload_type);
#+end_src #+end_src
Along with these core functions, calculation of the jastrow factor Along with these core functions, calculation of the jastrow factor
@ -713,6 +722,32 @@ qmckl_get_jastrow_cord_vector (const qmckl_context context,
return QMCKL_SUCCESS; return QMCKL_SUCCESS;
} }
qmckl_exit_code qmckl_get_jastrow_offload_type (const qmckl_context context, qmckl_jastrow_offload_type* const offload_type) {
if (qmckl_context_check(context) == QMCKL_NULL_CONTEXT) {
return (char) 0;
}
if (offload_type == NULL) {
return qmckl_failwith( context,
QMCKL_INVALID_ARG_2,
"qmckl_get_jastrow_offload_type",
"offload_type is a null pointer");
}
qmckl_context_struct* const ctx = (qmckl_context_struct* const) context;
assert (ctx != NULL);
int32_t mask = 1 << 0;
if ( (ctx->jastrow.uninitialized & mask) != 0) {
return QMCKL_NOT_PROVIDED;
}
*offload_type = ctx->jastrow.offload_type;
return QMCKL_SUCCESS;
}
#+end_src #+end_src
** Initialization functions ** Initialization functions
@ -727,6 +762,7 @@ qmckl_exit_code qmckl_set_jastrow_type_nucl_vector (qmckl_context context, con
qmckl_exit_code qmckl_set_jastrow_aord_vector (qmckl_context context, const double * aord_vector, const int64_t size_max); qmckl_exit_code qmckl_set_jastrow_aord_vector (qmckl_context context, const double * aord_vector, const int64_t size_max);
qmckl_exit_code qmckl_set_jastrow_bord_vector (qmckl_context context, const double * bord_vector, const int64_t size_max); qmckl_exit_code qmckl_set_jastrow_bord_vector (qmckl_context context, const double * bord_vector, const int64_t size_max);
qmckl_exit_code qmckl_set_jastrow_cord_vector (qmckl_context context, const double * cord_vector, const int64_t size_max); qmckl_exit_code qmckl_set_jastrow_cord_vector (qmckl_context context, const double * cord_vector, const int64_t size_max);
qmckl_exit_code qmckl_set_jastrow_offload_type (qmckl_context context, const qmckl_jastrow_offload_type offload_type);
#+end_src #+end_src
#+NAME:pre2 #+NAME:pre2
@ -1063,6 +1099,14 @@ qmckl_set_jastrow_cord_vector(qmckl_context context,
<<post2>> <<post2>>
} }
qmckl_exit_code
qmckl_set_jastrow_offload_type(qmckl_context context, const qmckl_jastrow_offload_type offload_type)
{
<<pre2>>
ctx->jastrow.offload_type = offload_type;
return QMCKL_SUCCESS;
}
#+end_src #+end_src
When the required information is completely entered, other data structures are When the required information is completely entered, other data structures are
@ -6093,6 +6137,30 @@ qmckl_exit_code qmckl_provide_factor_een_deriv_e(qmckl_context context)
ctx->jastrow.factor_een_deriv_e = factor_een_deriv_e; ctx->jastrow.factor_een_deriv_e = factor_een_deriv_e;
} }
/* Choose the correct compute function (depending on offload type) */
bool default_compute = true;
#ifdef HAVE_OPENMP_OFFLOAD
if(ctx->jastrow.offload_type == OFFLOAD_OPENMP) {
qmckl_exit_code rc =
qmckl_compute_factor_een_deriv_e_omp_offload(context,
ctx->electron.walk_num,
ctx->electron.num,
ctx->nucleus.num,
ctx->jastrow.cord_num,
ctx->jastrow.dim_cord_vect,
ctx->jastrow.cord_vect_full,
ctx->jastrow.lkpm_combined_index,
ctx->jastrow.tmp_c,
ctx->jastrow.dtmp_c,
ctx->jastrow.een_rescaled_n,
ctx->jastrow.een_rescaled_n_deriv_e,
ctx->jastrow.factor_een_deriv_e);
default_compute = false;
}
#endif
if(default_compute) {
qmckl_exit_code rc = qmckl_exit_code rc =
qmckl_compute_factor_een_deriv_e(context, qmckl_compute_factor_een_deriv_e(context,
ctx->electron.walk_num, ctx->electron.walk_num,
@ -6107,6 +6175,8 @@ qmckl_exit_code qmckl_provide_factor_een_deriv_e(qmckl_context context)
ctx->jastrow.een_rescaled_n, ctx->jastrow.een_rescaled_n,
ctx->jastrow.een_rescaled_n_deriv_e, ctx->jastrow.een_rescaled_n_deriv_e,
ctx->jastrow.factor_een_deriv_e); ctx->jastrow.factor_een_deriv_e);
}
if (rc != QMCKL_SUCCESS) { if (rc != QMCKL_SUCCESS) {
return rc; return rc;
} }
@ -6507,6 +6577,203 @@ end function qmckl_compute_factor_een_deriv_e_f
end function qmckl_compute_factor_een_deriv_e end function qmckl_compute_factor_een_deriv_e
#+end_src #+end_src
*** Compute (OpenMP offload)...
:PROPERTIES:
:Name: qmckl_compute_factor_een_deriv_e
:CRetType: qmckl_exit_code
:FRetType: qmckl_exit_code
:END:
#+NAME: qmckl_factor_een_deriv_e_omp_offload_args
| Variable | Type | In/Out | Description |
|--------------------------+---------------------------------------------------------------------+--------+------------------------------------------------|
| ~context~ | ~qmckl_context~ | in | Global state |
| ~walk_num~ | ~int64_t~ | in | Number of walkers |
| ~elec_num~ | ~int64_t~ | in | Number of electrons |
| ~nucl_num~ | ~int64_t~ | in | Number of nucleii |
| ~cord_num~ | ~int64_t~ | in | order of polynomials |
| ~dim_cord_vect~ | ~int64_t~ | in | dimension of full coefficient vector |
| ~cord_vect_full~ | ~double[dim_cord_vect][nucl_num]~ | in | full coefficient vector |
| ~lkpm_combined_index~ | ~int64_t[4][dim_cord_vect]~ | in | combined indices |
| ~tmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][elec_num]~ | in | Temporary intermediate tensor |
| ~dtmp_c~ | ~double[walk_num][0:cord_num-1][0:cord_num][nucl_num][4][elec_num]~ | in | vector of non-zero coefficients |
| ~een_rescaled_n~ | ~double[walk_num][0:cord_num][nucl_num][elec_num]~ | in | Electron-nucleus rescaled factor |
| ~een_rescaled_n_deriv_e~ | ~double[walk_num][0:cord_num][nucl_num][4][elec_num]~ | in | Derivative of Electron-nucleus rescaled factor |
| ~factor_een_deriv_e~ | ~double[walk_num][4][elec_num]~ | out | Derivative of Electron-nucleus jastrow |
#+begin_src f90 :comments org :tangle (eval f) :noweb yes
#ifdef HAVE_OPENMP_OFFLOAD
integer function qmckl_compute_factor_een_deriv_e_omp_offload_f(context, walk_num, elec_num, nucl_num, cord_num, dim_cord_vect, &
cord_vect_full, lkpm_combined_index, &
tmp_c, dtmp_c, een_rescaled_n, een_rescaled_n_deriv_e, factor_een_deriv_e) &
result(info)
use qmckl
implicit none
integer(qmckl_context), intent(in) :: context
integer*8 , intent(in) :: walk_num, elec_num, cord_num, nucl_num, dim_cord_vect
integer*8 , intent(in) :: lkpm_combined_index(dim_cord_vect,4)
double precision , intent(in) :: cord_vect_full(nucl_num, dim_cord_vect)
double precision , intent(in) :: tmp_c(elec_num, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
double precision , intent(in) :: dtmp_c(elec_num, 4, nucl_num,0:cord_num, 0:cord_num-1, walk_num)
double precision , intent(in) :: een_rescaled_n(elec_num, nucl_num, 0:cord_num, walk_num)
double precision , intent(in) :: een_rescaled_n_deriv_e(elec_num, 4, nucl_num, 0:cord_num, walk_num)
double precision , intent(out) :: factor_een_deriv_e(elec_num,4,walk_num)
integer*8 :: i, a, j, l, k, p, m, n, nw, ii
double precision :: accu, accu2, cn
info = QMCKL_SUCCESS
if (context == QMCKL_NULL_CONTEXT) then
info = QMCKL_INVALID_CONTEXT
return
endif
if (walk_num <= 0) then
info = QMCKL_INVALID_ARG_2
return
endif
if (elec_num <= 0) then
info = QMCKL_INVALID_ARG_3
return
endif
if (nucl_num <= 0) then
info = QMCKL_INVALID_ARG_4
return
endif
if (cord_num <= 0) then
info = QMCKL_INVALID_ARG_5
return
endif
factor_een_deriv_e = 0.0d0
do nw =1, walk_num
do n = 1, dim_cord_vect
l = lkpm_combined_index(n, 1)
k = lkpm_combined_index(n, 2)
p = lkpm_combined_index(n, 3)
m = lkpm_combined_index(n, 4)
do a = 1, nucl_num
cn = cord_vect_full(a, n)
if(cn == 0.d0) cycle
do ii = 1, 4
do j = 1, elec_num
factor_een_deriv_e(j,ii,nw) = factor_een_deriv_e(j,ii,nw) + (&
tmp_c(j,a,m,k,nw) * een_rescaled_n_deriv_e(j,ii,a,m+l,nw) + &
(dtmp_c(j,ii,a,m,k,nw)) * een_rescaled_n(j,a,m+l,nw) + &
(dtmp_c(j,ii,a,m+l,k,nw)) * een_rescaled_n(j,a,m ,nw) + &
tmp_c(j,a,m+l,k,nw) * een_rescaled_n_deriv_e(j,ii,a,m,nw) &
) * cn
end do
end do
cn = cn + cn
do j = 1, elec_num
factor_een_deriv_e(j,4,nw) = factor_een_deriv_e(j,4,nw) + (&
(dtmp_c(j,1,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,1,a,m+l,nw) + &
(dtmp_c(j,2,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,2,a,m+l,nw) + &
(dtmp_c(j,3,a,m ,k,nw)) * een_rescaled_n_deriv_e(j,3,a,m+l,nw) + &
(dtmp_c(j,1,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,1,a,m ,nw) + &
(dtmp_c(j,2,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,2,a,m ,nw) + &
(dtmp_c(j,3,a,m+l,k,nw)) * een_rescaled_n_deriv_e(j,3,a,m ,nw) &
) * cn
end do
end do
end do
end do
end function qmckl_compute_factor_een_deriv_e_omp_offload_f
#endif
#+end_src
#+CALL: generate_c_header(table=qmckl_factor_een_deriv_e_omp_offload_args,rettyp=get_value("CRetType"),fname=get_value("Name"))
#+RESULTS:
#+begin_src c :tangle (eval h_func) :comments org
#ifdef HAVE_OPENMP_OFFLOAD
qmckl_exit_code qmckl_compute_factor_een_deriv_e_omp_offload (
const qmckl_context context,
const int64_t walk_num,
const int64_t elec_num,
const int64_t nucl_num,
const int64_t cord_num,
const int64_t dim_cord_vect,
const double* cord_vect_full,
const int64_t* lkpm_combined_index,
const double* tmp_c,
const double* dtmp_c,
const double* een_rescaled_n,
const double* een_rescaled_n_deriv_e,
double* const factor_een_deriv_e );
#endif
#+end_src
#+CALL: generate_c_interface(table=qmckl_factor_een_deriv_e_omp_offload_args,rettyp=get_value("CRetType"),fname=get_value("Name"))
#+RESULTS:
#+begin_src f90 :tangle (eval f) :comments org :exports none
#ifdef HAVE_OPENMP_OFFLOAD
integer(c_int32_t) function qmckl_compute_factor_een_deriv_e_omp_offload &
(context, &
walk_num, &
elec_num, &
nucl_num, &
cord_num, &
dim_cord_vect, &
cord_vect_full, &
lkpm_combined_index, &
tmp_c, &
dtmp_c, &
een_rescaled_n, &
een_rescaled_n_deriv_e, &
factor_een_deriv_e) &
bind(C) result(info)
use, intrinsic :: iso_c_binding
implicit none
integer (c_int64_t) , intent(in) , value :: context
integer (c_int64_t) , intent(in) , value :: walk_num
integer (c_int64_t) , intent(in) , value :: elec_num
integer (c_int64_t) , intent(in) , value :: nucl_num
integer (c_int64_t) , intent(in) , value :: cord_num
integer (c_int64_t) , intent(in) , value :: dim_cord_vect
real (c_double ) , intent(in) :: cord_vect_full(nucl_num,dim_cord_vect)
integer (c_int64_t) , intent(in) :: lkpm_combined_index(dim_cord_vect,4)
real (c_double ) , intent(in) :: tmp_c(elec_num,nucl_num,0:cord_num,0:cord_num-1,walk_num)
real (c_double ) , intent(in) :: dtmp_c(elec_num,4,nucl_num,0:cord_num,0:cord_num-1,walk_num)
real (c_double ) , intent(in) :: een_rescaled_n(elec_num,nucl_num,0:cord_num,walk_num)
real (c_double ) , intent(in) :: een_rescaled_n_deriv_e(elec_num,4,nucl_num,0:cord_num,walk_num)
real (c_double ) , intent(out) :: factor_een_deriv_e(elec_num,4,walk_num)
integer(c_int32_t), external :: qmckl_compute_factor_een_deriv_e_omp_offload_f
info = qmckl_compute_factor_een_deriv_e_omp_offload_f &
(context, &
walk_num, &
elec_num, &
nucl_num, &
cord_num, &
dim_cord_vect, &
cord_vect_full, &
lkpm_combined_index, &
tmp_c, &
dtmp_c, &
een_rescaled_n, &
een_rescaled_n_deriv_e, &
factor_een_deriv_e)
end function qmckl_compute_factor_een_deriv_e_omp_offload
#endif
#+end_src
*** Test *** Test
#+begin_src python :results output :exports none :noweb yes #+begin_src python :results output :exports none :noweb yes
import numpy as np import numpy as np