1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-09 04:43:24 +01:00

reshape output arrays by default when reading from the file

This commit is contained in:
q-posev 2021-08-27 16:08:39 +03:00
parent 7b5ebf6272
commit 5a2b4d96a7
2 changed files with 46 additions and 11 deletions

View File

@ -9,7 +9,7 @@ import trexio as tr
#=========================================================#
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT
TEST_TREXIO_BACKEND = 0
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
@ -36,6 +36,8 @@ except:
#============ WRITE THE DATA IN THE TEST FILE ============#
#=========================================================#
# create TREXIO file and open it for writing
#test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND)
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)
@ -114,6 +116,8 @@ tr.write_nucleus_label(test_file,labels)
# tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook)
del test_file
#==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================#
@ -156,7 +160,12 @@ for i in range(nucleus_num):
# read nuclear coordinates without providing optional argument dim
rcoords_np = tr.read_nucleus_coord(test_file2)
assert rcoords_np.size==nucleus_num*3
np.testing.assert_array_almost_equal(rcoords_np, np.array(coords).reshape(nucleus_num,3), decimal=8)
# set doReshape to False to get a flat 1D array (e.g. when reading matrices like nuclear coordinates)
#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False)
# read array of nuclear labels
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)

View File

@ -1993,7 +1993,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
#+end_src
#+begin_src python :tangle read_dset_data_front.py
def read_$group_dset$(trexio_file, dim = None, dtype = None):
def read_$group_dset$(trexio_file, dim = None, doReshape = None, dtype = None):
"""Read the $group_dset$ array of numbers from the TREXIO file.
Parameters:
@ -2008,6 +2008,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
dtype (Optional): type
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
doReshape (Optional): bool
Flag to determine whether the output NumPy array has be reshaped or not. Be default, reshaping is performed
based on the dimensions from the ~trex.json~ file. Otherwise, ~shape~ array (list or tuple) is used if provided by the user.
Returns:
~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.
@ -2015,10 +2019,19 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
"""
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if doReshape is None:
doReshape = True
# if dim is not specified, read dimensions from the TREXIO file
if dim is None:
dims_list = None
if dim is None or doReshape:
$group_dset_dim$ = read_$group_dset_dim$(trexio_file)
dims_list = [$group_dset_dim_list$]
@ -2027,6 +2040,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
dim *= dims_list[i]
shape = tuple(dims_list)
if shape is None and doReshape:
raise ValueError("Reshaping failure: shape is None.")
try:
rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim)
assert rc==TREXIO_SUCCESS
@ -2039,10 +2056,6 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
isConverted = False
dset_converted = None
if dtype is not None:
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
try:
assert isinstance(dtype, type)
@ -2059,8 +2072,21 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
isConverted = True
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
if doReshape:
try:
# in-place reshaping did not work so I have to make a copy
if isConverted:
dset_reshaped = np.reshape(dset_converted, shape, order='C')
else:
dset_reshaped = np.reshape(dset_64, shape, order='C')
except:
raise
if isConverted:
return dset_converted
elif doReshape:
return dset_reshaped
else:
return dset_64
#+end_src