1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-03 10:06:01 +01:00

add type converters for numerical arrays in read/write Python functions

This commit is contained in:
q-posev 2021-08-26 13:14:46 +03:00
parent 1dcd32ef7d
commit 4c28a4cac8
3 changed files with 84 additions and 15 deletions

View File

@ -9,7 +9,7 @@ import trexio as tr
#=========================================================#
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_HDF5
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
@ -49,7 +49,8 @@ tr.write_nucleus_num(test_file, nucleus_num)
# initialize charge arrays as a list and convert it to numpy array
charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.]
charges_np = np.array(charges, dtype=np.float64)
#charges_np = np.array(charges, dtype=np.float32)
charges_np = np.array(charges, dtype=np.int32)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i
@ -58,7 +59,7 @@ tr.write_nucleus_charge(test_file, charges_np)
# initialize arrays of nuclear indices as a list and convert it to numpy array
indices = [i for i in range(nucleus_num)]
# type cast is important here because by default numpy transforms a list of integers into int64 array
indices_np = np.array(indices, dtype=np.int32)
indices_np = np.array(indices, dtype=np.int64)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i
@ -108,6 +109,7 @@ tr.write_nucleus_label(test_file,labels)
# close TREXIO file
tr.close(test_file)
#==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================#
@ -131,10 +133,21 @@ except Exception:
print("Unsafe call to safe API: checked")
# safe call to read array of int values (nuclear indices)
rindices_np = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num)
assert rindices_np.dtype is np.dtype(np.int32)
rindices_np_16 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int16)
assert rindices_np_16.dtype is np.dtype(np.int16)
for i in range(nucleus_num):
assert rindices_np[i]==indices_np[i]
assert rindices_np_16[i]==indices_np[i]
rindices_np_32 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int32)
assert rindices_np_32.dtype is np.dtype(np.int32)
for i in range(nucleus_num):
assert rindices_np_32[i]==indices_np[i]
rindices_np_64 = tr.read_basis_nucleus_index(test_file2)
assert rindices_np_64.dtype is np.dtype(np.int64)
assert rindices_np_64.size==nucleus_num
for i in range(nucleus_num):
assert rindices_np_64[i]==indices_np[i]
# read nuclear coordinates without providing optional argument dim
rcoords_np = tr.read_nucleus_coord(test_file2)

View File

@ -987,6 +987,7 @@ def close(trexio_file):
| ~$group_dset_f_dtype_single$~ | Single precision datatype of the dataset [Fortran] | ~real(4)/integer(4)~ |
| ~$group_dset_f_dtype_double$~ | Double precision datatype of the dataset [Fortran] | ~real(8)/integer(8)~ |
| ~$group_dset_f_dims$~ | Dimensions in Fortran | ~(:,:)~ |
| ~$group_dset_py_dtype$~ | Standard datatype of the dataset [Python] | ~float/int~ |
Note: parent group name is always added to the child objects upon
@ -1886,7 +1887,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
TREXIO file handle.
dset_w: list OR numpy.ndarray
Array of $group_dset$ values to be written.
Array of $group_dset$ values to be written. If array data type does not correspond to int64 or float64, the conversion is performed.
Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
@ -1894,8 +1895,31 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
"""
cutPrecision = False
if not isinstance(dset_w, list):
try:
rc = trexio_write_safe_$group_dset$(trexio_file, dset_w)
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
cutPrecision = True
if cutPrecision:
try:
# TODO: we have to make sure that this conversion does not introduce any noise in the data.
dset_64 = np.$group_dset_py_dtype$64(dset_w)
except:
raise
try:
if cutPrecision:
rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_64)
else:
rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_w)
assert rc==TREXIO_SUCCESS
except AssertionError:
raise Exception(trexio_string_of_error(rc))
@ -1905,7 +1929,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
#+end_src
#+begin_src python :tangle read_dset_data_front.py
def read_$group_dset$(trexio_file, dim = None):
def read_$group_dset$(trexio_file, dim = None, dtype = None):
"""Read the $group_dset$ array of numbers from the TREXIO file.
Parameters:
@ -1917,8 +1941,11 @@ def read_$group_dset$(trexio_file, dim = None):
Size of the block to be read from the file (i.e. how many items of $group_dset$ will be returned)
If None, the function will read all necessary array dimensions from the file.
dtype (Optional): type
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
Returns:
~dset_r~: numpy.ndarray
~dset_64~ (if dtype is None) or ~dset_converted~ (otherwise): numpy.ndarray
1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.
Raises:
@ -1938,16 +1965,41 @@ def read_$group_dset$(trexio_file, dim = None):
try:
rc, dset_r = trexio_read_safe_$group_dset$(trexio_file, dim)
rc, dset_64 = trexio_read_safe_$group_dset$_64(trexio_file, dim)
assert rc==TREXIO_SUCCESS
except AssertionError:
raise Exception(trexio_string_of_error(rc))
except:
raise
# additional assert can be added here to check that the read_safe function returns a numpy array of proper dimension
return dset_r
isConverted = False
dset_converted = None
if dtype is not None:
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
try:
assert isinstance(dtype, type)
except AssertionError:
raise TypeError("dtype argument has to be an instance of the type class (e.g. np.float32).")
if not dtype==np.int64 or not dtype==np.float64:
try:
dset_converted = np.array(dset_64, dtype=dtype)
except:
raise
isConverted = True
# additional assert can be added here to check that the read_safe function returns a numpy array of proper dimension
if isConverted:
return dset_converted
else:
return dset_64
#+end_src
** Sparse data structures

View File

@ -100,7 +100,7 @@ def recursive_populate_file(fname: str, paths: dict, detailed_source: dict) -> N
fname_new = join('populated',f'pop_{fname}')
templ_path = get_template_path(fname, paths)
triggers = ['group_dset_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index',
triggers = ['group_dset_dtype', 'group_dset_py_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index',
'group_dset_f_dtype_default', 'group_dset_f_dtype_double', 'group_dset_f_dtype_single',
'group_dset_dtype_default', 'group_dset_dtype_double', 'group_dset_dtype_single',
'group_dset_rank', 'group_dset_dim_list', 'group_dset_f_dims',
@ -542,6 +542,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = '64'
group_dset_std_dtype_out = '24.16e'
group_dset_std_dtype_in = 'lf'
group_dset_py_dtype = 'float'
elif v[0] in ['int', 'index']:
datatype = 'int64_t'
group_dset_h5_dtype = 'native_int64'
@ -554,6 +555,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = '32'
group_dset_std_dtype_out = '" PRId64 "'
group_dset_std_dtype_in = '" SCNd64 "'
group_dset_py_dtype = 'int'
elif v[0] == 'str':
datatype = 'char*'
group_dset_h5_dtype = ''
@ -566,6 +568,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
default_prec = ''
group_dset_std_dtype_out = 's'
group_dset_std_dtype_in = 's'
group_dset_py_dtype = 'str'
# add the dset name for templates
tmp_dict['group_dset'] = k
@ -587,6 +590,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:
tmp_dict['default_prec'] = default_prec
tmp_dict['group_dset_std_dtype_in'] = group_dset_std_dtype_in
tmp_dict['group_dset_std_dtype_out'] = group_dset_std_dtype_out
tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
# add the rank
tmp_dict['rank'] = len(v[1])
tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])