Mirror of https://github.com/TREX-CoE/trexio.git (synced 2025-01-03 10:06:01 +01:00)

add type converters for numerical arrays in read/write Python functions

Commit 4c28a4cac8, parent 1dcd32ef7d
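The conversion behavior this commit introduces can be sketched with plain NumPy, independent of TREXIO itself. The helper names `to_64bit` and `from_64bit` below are purely illustrative and are not part of the TREXIO API; they only mimic what the generated write/read wrappers shown further down do (promote non-64-bit input before writing, down-convert the 64-bit data returned on read when a dtype is requested). Note also that the real wrappers convert to the dataset's declared 64-bit type, while this sketch simply picks the target from the kind of the input array.

import numpy as np

def to_64bit(arr):
    # Illustrative mirror of the write-side conversion (hypothetical helper, not TREXIO code).
    a = np.asarray(arr)                                   # plain lists become 64-bit arrays by default
    if a.dtype == np.int64 or a.dtype == np.float64:
        return a                                          # already 64-bit: nothing to do
    target = np.int64 if np.issubdtype(a.dtype, np.integer) else np.float64
    return a.astype(target)                               # widen 16/32-bit input to 64 bits

def from_64bit(arr64, dtype=None):
    # Illustrative mirror of the read-side conversion (hypothetical helper, not TREXIO code).
    return np.asarray(arr64) if dtype is None else np.array(arr64, dtype=dtype)

print(np.array([0, 1, 2]).dtype)                          # int64 on most 64-bit platforms,
                                                          # hence the explicit casts in the test below
charges_f32 = np.array([6., 6., 1., 1.], dtype=np.float32)
print(to_64bit(charges_f32).dtype)                        # float64
indices_i32 = np.array(range(12), dtype=np.int32)
print(from_64bit(to_64bit(indices_i32), dtype=np.int16).dtype)  # int16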
@ -9,7 +9,7 @@ import trexio as tr

#=========================================================#

# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_HDF5
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
@ -49,7 +49,8 @@ tr.write_nucleus_num(test_file, nucleus_num)

# initialize charge arrays as a list and convert it to numpy array
charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.]
charges_np = np.array(charges, dtype=np.float64)
#charges_np = np.array(charges, dtype=np.float32)
charges_np = np.array(charges, dtype=np.int32)

# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i
@ -58,7 +59,7 @@ tr.write_nucleus_charge(test_file, charges_np)

# initialize arrays of nuclear indices as a list and convert it to numpy array
indices = [i for i in range(nucleus_num)]
# type cast is important here because by default numpy transforms a list of integers into int64 array
indices_np = np.array(indices, dtype=np.int32)
indices_np = np.array(indices, dtype=np.int64)

# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i
@ -108,6 +109,7 @@ tr.write_nucleus_label(test_file,labels)

# close TREXIO file
tr.close(test_file)

#==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================#
@ -131,10 +133,21 @@ except Exception:

print("Unsafe call to safe API: checked")

# safe call to read array of int values (nuclear indices)
rindices_np = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num)
assert rindices_np.dtype is np.dtype(np.int32)
rindices_np_16 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int16)
assert rindices_np_16.dtype is np.dtype(np.int16)
for i in range(nucleus_num):
    assert rindices_np[i]==indices_np[i]
    assert rindices_np_16[i]==indices_np[i]

rindices_np_32 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int32)
assert rindices_np_32.dtype is np.dtype(np.int32)
for i in range(nucleus_num):
    assert rindices_np_32[i]==indices_np[i]

rindices_np_64 = tr.read_basis_nucleus_index(test_file2)
assert rindices_np_64.dtype is np.dtype(np.int64)
assert rindices_np_64.size==nucleus_num
for i in range(nucleus_num):
    assert rindices_np_64[i]==indices_np[i]

# read nuclear coordinates without providing optional argument dim
rcoords_np = tr.read_nucleus_coord(test_file2)
@ -987,6 +987,7 @@ def close(trexio_file):

| ~$group_dset_f_dtype_single$~ | Single precision datatype of the dataset [Fortran] | ~real(4)/integer(4)~ |
| ~$group_dset_f_dtype_double$~ | Double precision datatype of the dataset [Fortran] | ~real(8)/integer(8)~ |
| ~$group_dset_f_dims$~ | Dimensions in Fortran | ~(:,:)~ |
| ~$group_dset_py_dtype$~ | Standard datatype of the dataset [Python] | ~float/int~ |

Note: parent group name is always added to the child objects upon
@ -1886,7 +1887,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:

        TREXIO file handle.

    dset_w: list OR numpy.ndarray
        Array of $group_dset$ values to be written.
        Array of $group_dset$ values to be written. If the array data type is not int64 or float64, the conversion is performed.

    Raises:
        - Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
@ -1894,8 +1895,31 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
    """

    cutPrecision = False
    if not isinstance(dset_w, list):
        try:
            rc = trexio_write_safe_$group_dset$(trexio_file, dset_w)
            import numpy as np
        except ImportError:
            raise Exception("NumPy cannot be imported.")

        if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
            cutPrecision = True

    if cutPrecision:
        try:
            # TODO: we have to make sure that this conversion does not introduce any noise in the data.
            dset_64 = np.$group_dset_py_dtype$64(dset_w)
        except:
            raise

    try:
        if cutPrecision:
            rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_64)
        else:
            rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_w)

        assert rc==TREXIO_SUCCESS
    except AssertionError:
        raise Exception(trexio_string_of_error(rc))
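The TODO in the hunk above asks whether the promotion to 64 bits can introduce noise in the data. For the widening casts involved here (e.g. float32 to float64, int32 to int64) the conversion is value-preserving, which the short standalone NumPy check below illustrates. This is not part of the generated code; the sample values are made up for the example.

import numpy as np

# Widening 32-bit data to 64 bits and narrowing it back reproduces the input
# exactly, so the promotion itself adds no noise.
x32 = np.array([6.0, 1.0, 0.1], dtype=np.float32)    # 0.1 is already rounded at 32-bit precision
x64 = x32.astype(np.float64)                          # same effect as the template's 64-bit promotion
assert np.array_equal(x64.astype(np.float32), x32)    # lossless round trip

i32 = np.array([0, 1, 2**31 - 1], dtype=np.int32)
assert np.array_equal(i32.astype(np.int64).astype(np.int32), i32)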
@ -1905,7 +1929,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:

#+end_src

#+begin_src python :tangle read_dset_data_front.py
def read_$group_dset$(trexio_file, dim = None):
def read_$group_dset$(trexio_file, dim = None, dtype = None):
    """Read the $group_dset$ array of numbers from the TREXIO file.

    Parameters:
@ -1917,8 +1941,11 @@ def read_$group_dset$(trexio_file, dim = None):

        Size of the block to be read from the file (i.e. how many items of $group_dset$ will be returned)
        If None, the function will read all necessary array dimensions from the file.

    dtype (Optional): type
        NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.

    Returns:
        ~dset_r~: numpy.ndarray
        ~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
            1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.

    Raises:
@ -1938,16 +1965,41 @@ def read_$group_dset$(trexio_file, dim = None):

    try:
        rc, dset_r = trexio_read_safe_$group_dset$(trexio_file, dim)
        rc, dset_64 = trexio_read_safe_$group_dset$_64(trexio_file, dim)
        assert rc==TREXIO_SUCCESS
    except AssertionError:
        raise Exception(trexio_string_of_error(rc))
    except:
        raise

    # additional assert can be added here to check that the read_safe function returns a numpy array of proper dimension

    return dset_r
    isConverted = False
    dset_converted = None
    if dtype is not None:
        try:
            import numpy as np
        except ImportError:
            raise Exception("NumPy cannot be imported.")

        try:
            assert isinstance(dtype, type)
        except AssertionError:
            raise TypeError("dtype argument has to be an instance of the type class (e.g. np.float32).")

        if not dtype==np.int64 or not dtype==np.float64:
            try:
                dset_converted = np.array(dset_64, dtype=dtype)
            except:
                raise

            isConverted = True

    # additional assert can be added here to check that the read_safe function returns a numpy array of proper dimension
    if isConverted:
        return dset_converted
    else:
        return dset_64
#+end_src
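The optional dtype argument in the block above applies np.array(dset_64, dtype=dtype) to the 64-bit data coming from the file. A standalone NumPy illustration (no TREXIO calls, made-up sample values) of what that narrowing can and cannot preserve; the requested dtype should be wide enough for the stored values.

import numpy as np

dset_64 = np.array([0.0, 1.0, 123456.789, 1.0e-9])          # stands in for data read from the file
dset_f32 = np.array(dset_64, dtype=np.float32)               # same conversion as in read_$group_dset$
print(dset_f32.dtype)                                        # float32
print(np.abs(dset_f32.astype(np.float64) - dset_64).max())   # small but nonzero rounding error

idx_64 = np.arange(12, dtype=np.int64)                       # index-like data
idx_16 = np.array(idx_64, dtype=np.int16)                    # safe only because all values fit in int16
assert np.array_equal(idx_16, idx_64)                        # values are preserved despite the narrower dtype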
** Sparse data structures
@ -100,7 +100,7 @@ def recursive_populate_file(fname: str, paths: dict, detailed_source: dict) -> None:

    fname_new = join('populated',f'pop_{fname}')
    templ_path = get_template_path(fname, paths)

    triggers = ['group_dset_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index',
    triggers = ['group_dset_dtype', 'group_dset_py_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index',
                'group_dset_f_dtype_default', 'group_dset_f_dtype_double', 'group_dset_f_dtype_single',
                'group_dset_dtype_default', 'group_dset_dtype_double', 'group_dset_dtype_single',
                'group_dset_rank', 'group_dset_dim_list', 'group_dset_f_dims',
@ -542,6 +542,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:

            default_prec = '64'
            group_dset_std_dtype_out = '24.16e'
            group_dset_std_dtype_in = 'lf'
            group_dset_py_dtype = 'float'
        elif v[0] in ['int', 'index']:
            datatype = 'int64_t'
            group_dset_h5_dtype = 'native_int64'
@ -554,6 +555,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:

            default_prec = '32'
            group_dset_std_dtype_out = '" PRId64 "'
            group_dset_std_dtype_in = '" SCNd64 "'
            group_dset_py_dtype = 'int'
        elif v[0] == 'str':
            datatype = 'char*'
            group_dset_h5_dtype = ''
@ -566,6 +568,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:

            default_prec = ''
            group_dset_std_dtype_out = 's'
            group_dset_std_dtype_in = 's'
            group_dset_py_dtype = 'str'

        # add the dset name for templates
        tmp_dict['group_dset'] = k
@ -587,6 +590,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple:

        tmp_dict['default_prec'] = default_prec
        tmp_dict['group_dset_std_dtype_in'] = group_dset_std_dtype_in
        tmp_dict['group_dset_std_dtype_out'] = group_dset_std_dtype_out
        tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype
        # add the rank
        tmp_dict['rank'] = len(v[1])
        tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])
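On the generator side, the new group_dset_py_dtype entry is just another trigger that gets substituted into the template sources ('float' for real datasets, 'int' for integer/index datasets, 'str' for strings), so that a template line such as dset_64 = np.$group_dset_py_dtype$64(dset_w) expands to np.float64(...) or np.int64(...). A minimal, self-contained sketch of that kind of substitution; populate_line is a hypothetical helper, not the generator's actual implementation.

# Hypothetical stand-in for the trigger substitution performed on template lines.
def populate_line(line: str, substitutions: dict) -> str:
    for trigger, value in substitutions.items():
        line = line.replace(f'${trigger}$', value)
    return line

tmp_dict = {'group_dset': 'nucleus_charge', 'group_dset_py_dtype': 'float'}
template_line = 'dset_64 = np.$group_dset_py_dtype$64(dset_w)'
print(populate_line(template_line, tmp_dict))
# -> dset_64 = np.float64(dset_w)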