1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-03 10:06:01 +01:00

add type converters for numerical arrays in read/write Python functions

This commit is contained in:
q-posev 2021-08-26 13:14:46 +03:00
parent 1dcd32ef7d
commit 4c28a4cac8
3 changed files with 84 additions and 15 deletions

View File

@ -9,7 +9,7 @@ import trexio as tr
#=========================================================# #=========================================================#
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT # 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT TEST_TREXIO_BACKEND = tr.TREXIO_HDF5
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir' OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5' OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
@ -49,7 +49,8 @@ tr.write_nucleus_num(test_file, nucleus_num)
# initialize charge arrays as a list and convert it to numpy array # initialize charge arrays as a list and convert it to numpy array
charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.] charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.]
charges_np = np.array(charges, dtype=np.float64) #charges_np = np.array(charges, dtype=np.float32)
charges_np = np.array(charges, dtype=np.int32)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived # function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i # from the size of the list/array by SWIG using typemaps from numpy.i
@ -58,7 +59,7 @@ tr.write_nucleus_charge(test_file, charges_np)
# initialize arrays of nuclear indices as a list and convert it to numpy array # initialize arrays of nuclear indices as a list and convert it to numpy array
indices = [i for i in range(nucleus_num)] indices = [i for i in range(nucleus_num)]
# type cast is important here because by default numpy transforms a list of integers into int64 array # type cast is important here because by default numpy transforms a list of integers into int64 array
indices_np = np.array(indices, dtype=np.int32) indices_np = np.array(indices, dtype=np.int64)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived # function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i # from the size of the list/array by SWIG using typemaps from numpy.i
@ -108,6 +109,7 @@ tr.write_nucleus_label(test_file,labels)
# close TREXIO file # close TREXIO file
tr.close(test_file) tr.close(test_file)
#==========================================================# #==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============# #============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================# #==========================================================#
@ -131,10 +133,21 @@ except Exception:
print("Unsafe call to safe API: checked") print("Unsafe call to safe API: checked")
# safe call to read array of int values (nuclear indices) # safe call to read array of int values (nuclear indices)
rindices_np = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num) rindices_np_16 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int16)
assert rindices_np.dtype is np.dtype(np.int32) assert rindices_np_16.dtype is np.dtype(np.int16)
for i in range(nucleus_num): for i in range(nucleus_num):
assert rindices_np[i]==indices_np[i] assert rindices_np_16[i]==indices_np[i]
rindices_np_32 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int32)
assert rindices_np_32.dtype is np.dtype(np.int32)
for i in range(nucleus_num):
assert rindices_np_32[i]==indices_np[i]
rindices_np_64 = tr.read_basis_nucleus_index(test_file2)
assert rindices_np_64.dtype is np.dtype(np.int64)
assert rindices_np_64.size==nucleus_num
for i in range(nucleus_num):
assert rindices_np_64[i]==indices_np[i]
# read nuclear coordinates without providing optional argument dim # read nuclear coordinates without providing optional argument dim
rcoords_np = tr.read_nucleus_coord(test_file2) rcoords_np = tr.read_nucleus_coord(test_file2)

View File

@ -987,6 +987,7 @@ def close(trexio_file):
| ~$group_dset_f_dtype_single$~ | Single precision datatype of the dataset [Fortran] | ~real(4)/integer(4)~ | | ~$group_dset_f_dtype_single$~ | Single precision datatype of the dataset [Fortran] | ~real(4)/integer(4)~ |
| ~$group_dset_f_dtype_double$~ | Double precision datatype of the dataset [Fortran] | ~real(8)/integer(8)~ | | ~$group_dset_f_dtype_double$~ | Double precision datatype of the dataset [Fortran] | ~real(8)/integer(8)~ |
| ~$group_dset_f_dims$~ | Dimensions in Fortran | ~(:,:)~ | | ~$group_dset_f_dims$~ | Dimensions in Fortran | ~(:,:)~ |
| ~$group_dset_py_dtype$~ | Standard datatype of the dataset [Python] | ~float/int~ |
Note: parent group name is always added to the child objects upon Note: parent group name is always added to the child objects upon
@ -1886,7 +1887,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
TREXIO file handle. TREXIO file handle.
dset_w: list OR numpy.ndarray dset_w: list OR numpy.ndarray
Array of $group_dset$ values to be written. Array of $group_dset$ values to be written. If array data type does not correspond to int64 or float64, the conversion is performed.
Raises: Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error. - Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
@ -1894,8 +1895,31 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
""" """
cutPrecision = False
if not isinstance(dset_w, list):
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
cutPrecision = True
if cutPrecision:
try:
# TODO: we have to make sure that this conversion does not introduce any noise in the data.
dset_64 = np.$group_dset_py_dtype$64(dset_w)
except:
raise
try: try:
rc = trexio_write_safe_$group_dset$(trexio_file, dset_w) if cutPrecision:
rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_64)
else:
rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_w)
assert rc==TREXIO_SUCCESS assert rc==TREXIO_SUCCESS
except AssertionError: except AssertionError:
raise Exception(trexio_string_of_error(rc)) raise Exception(trexio_string_of_error(rc))
@ -1905,7 +1929,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
#+end_src #+end_src
#+begin_src python :tangle read_dset_data_front.py #+begin_src python :tangle read_dset_data_front.py
def read_$group_dset$(trexio_file, dim = None): def read_$group_dset$(trexio_file, dim = None, dtype = None):
"""Read the $group_dset$ array of numbers from the TREXIO file. """Read the $group_dset$ array of numbers from the TREXIO file.
Parameters: Parameters:
@ -1917,8 +1941,11 @@ def read_$group_dset$(trexio_file, dim = None):
Size of the block to be read from the file (i.e. how many items of $group_dset$ will be returned) Size of the block to be read from the file (i.e. how many items of $group_dset$ will be returned)
If None, the function will read all necessary array dimensions from the file. If None, the function will read all necessary array dimensions from the file.
dtype (Optional): type
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
Returns: Returns:
~dset_r~: numpy.ndarray ~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file. 1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.
Raises: Raises:
@ -1938,16 +1965,41 @@ def read_$group_dset$(trexio_file, dim = None):
try: try:
rc, dset_r = trexio_read_safe_$group_dset$(trexio_file, dim) rc, dset_64 = trexio_read_safe_$group_dset$_64(trexio_file, dim)
assert rc==TREXIO_SUCCESS assert rc==TREXIO_SUCCESS
except AssertionError: except AssertionError:
raise Exception(trexio_string_of_error(rc)) raise Exception(trexio_string_of_error(rc))
except: except:
raise raise
isConverted = False
dset_converted = None
if dtype is not None:
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
try:
assert isinstance(dtype, type)
except AssertionError:
raise TypeError("dtype argument has to be an instance of the type class (e.g. np.float32).")
if not dtype==np.int64 or not dtype==np.float64:
try:
dset_converted = np.array(dset_64, dtype=dtype)
except:
raise
isConverted = True
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension # additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
if isConverted:
return dset_r return dset_converted
else:
return dset_64
#+end_src #+end_src
** Sparse data structures ** Sparse data structures

View File

@ -100,7 +100,7 @@ def recursive_populate_file(fname: str, paths: dict, detailed_source: dict) -> N
fname_new = join('populated',f'pop_{fname}') fname_new = join('populated',f'pop_{fname}')
templ_path = get_template_path(fname, paths) templ_path = get_template_path(fname, paths)
triggers = ['group_dset_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index', triggers = ['group_dset_dtype', 'group_dset_py_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index',
'group_dset_f_dtype_default', 'group_dset_f_dtype_double', 'group_dset_f_dtype_single', 'group_dset_f_dtype_default', 'group_dset_f_dtype_double', 'group_dset_f_dtype_single',
'group_dset_dtype_default', 'group_dset_dtype_double', 'group_dset_dtype_single', 'group_dset_dtype_default', 'group_dset_dtype_double', 'group_dset_dtype_single',
'group_dset_rank', 'group_dset_dim_list', 'group_dset_f_dims', 'group_dset_rank', 'group_dset_dim_list', 'group_dset_f_dims',
@ -542,6 +542,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = '64' default_prec = '64'
group_dset_std_dtype_out = '24.16e' group_dset_std_dtype_out = '24.16e'
group_dset_std_dtype_in = 'lf' group_dset_std_dtype_in = 'lf'
group_dset_py_dtype = 'float'
elif v[0] in ['int', 'index']: elif v[0] in ['int', 'index']:
datatype = 'int64_t' datatype = 'int64_t'
group_dset_h5_dtype = 'native_int64' group_dset_h5_dtype = 'native_int64'
@ -554,6 +555,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = '32' default_prec = '32'
group_dset_std_dtype_out = '" PRId64 "' group_dset_std_dtype_out = '" PRId64 "'
group_dset_std_dtype_in = '" SCNd64 "' group_dset_std_dtype_in = '" SCNd64 "'
group_dset_py_dtype = 'int'
elif v[0] == 'str': elif v[0] == 'str':
datatype = 'char*' datatype = 'char*'
group_dset_h5_dtype = '' group_dset_h5_dtype = ''
@ -566,6 +568,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = '' default_prec = ''
group_dset_std_dtype_out = 's' group_dset_std_dtype_out = 's'
group_dset_std_dtype_in = 's' group_dset_std_dtype_in = 's'
group_dset_py_dtype = 'str'
# add the dset name for templates # add the dset name for templates
tmp_dict['group_dset'] = k tmp_dict['group_dset'] = k
@ -587,6 +590,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
tmp_dict['default_prec'] = default_prec tmp_dict['default_prec'] = default_prec
tmp_dict['group_dset_std_dtype_in'] = group_dset_std_dtype_in tmp_dict['group_dset_std_dtype_in'] = group_dset_std_dtype_in
tmp_dict['group_dset_std_dtype_out'] = group_dset_std_dtype_out tmp_dict['group_dset_std_dtype_out'] = group_dset_std_dtype_out
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
# add the rank # add the rank
tmp_dict['rank'] = len(v[1]) tmp_dict['rank'] = len(v[1])
tmp_dict['group_dset_rank'] = str(tmp_dict['rank']) tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])