mirror of
https://github.com/TREX-CoE/trexio.git
synced 2025-01-10 04:58:31 +01:00
reshape output arrays by default when reading from the file
This commit is contained in:
parent
7b5ebf6272
commit
5a2b4d96a7
@ -9,7 +9,7 @@ import trexio as tr
|
|||||||
#=========================================================#
|
#=========================================================#
|
||||||
|
|
||||||
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
|
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
|
||||||
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT
|
TEST_TREXIO_BACKEND = 0
|
||||||
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
|
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
|
||||||
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
|
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
|
||||||
|
|
||||||
@ -36,6 +36,8 @@ except:
|
|||||||
#============ WRITE THE DATA IN THE TEST FILE ============#
|
#============ WRITE THE DATA IN THE TEST FILE ============#
|
||||||
#=========================================================#
|
#=========================================================#
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# create TREXIO file and open it for writing
|
# create TREXIO file and open it for writing
|
||||||
#test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND)
|
#test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND)
|
||||||
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)
|
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)
|
||||||
@ -114,6 +116,8 @@ tr.write_nucleus_label(test_file,labels)
|
|||||||
# tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook)
|
# tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook)
|
||||||
del test_file
|
del test_file
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#==========================================================#
|
#==========================================================#
|
||||||
#============ READ THE DATA FROM THE TEST FILE ============#
|
#============ READ THE DATA FROM THE TEST FILE ============#
|
||||||
#==========================================================#
|
#==========================================================#
|
||||||
@ -156,7 +160,12 @@ for i in range(nucleus_num):
|
|||||||
|
|
||||||
# read nuclear coordinates without providing optional argument dim
|
# read nuclear coordinates without providing optional argument dim
|
||||||
rcoords_np = tr.read_nucleus_coord(test_file2)
|
rcoords_np = tr.read_nucleus_coord(test_file2)
|
||||||
|
|
||||||
assert rcoords_np.size==nucleus_num*3
|
assert rcoords_np.size==nucleus_num*3
|
||||||
|
np.testing.assert_array_almost_equal(rcoords_np, np.array(coords).reshape(nucleus_num,3), decimal=8)
|
||||||
|
|
||||||
|
# set doReshape to False to get a flat 1D array (e.g. when reading matrices like nuclear coordinates)
|
||||||
|
#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False)
|
||||||
|
|
||||||
# read array of nuclear labels
|
# read array of nuclear labels
|
||||||
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)
|
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)
|
||||||
|
@ -1993,7 +1993,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
|
|||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
#+begin_src python :tangle read_dset_data_front.py
|
#+begin_src python :tangle read_dset_data_front.py
|
||||||
def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
def read_$group_dset$(trexio_file, dim = None, doReshape = None, dtype = None):
|
||||||
"""Read the $group_dset$ array of numbers from the TREXIO file.
|
"""Read the $group_dset$ array of numbers from the TREXIO file.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
@ -2008,6 +2008,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
|||||||
dtype (Optional): type
|
dtype (Optional): type
|
||||||
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
|
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
|
||||||
|
|
||||||
|
doReshape (Optional): bool
|
||||||
|
Flag to determine whether the output NumPy array has be reshaped or not. Be default, reshaping is performed
|
||||||
|
based on the dimensions from the ~trex.json~ file. Otherwise, ~shape~ array (list or tuple) is used if provided by the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
|
~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
|
||||||
1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.
|
1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file.
|
||||||
@ -2017,8 +2021,17 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
|||||||
- Exception from some other error (e.g. RuntimeError).
|
- Exception from some other error (e.g. RuntimeError).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
raise Exception("NumPy cannot be imported.")
|
||||||
|
|
||||||
|
if doReshape is None:
|
||||||
|
doReshape = True
|
||||||
|
|
||||||
# if dim is not specified, read dimensions from the TREXIO file
|
# if dim is not specified, read dimensions from the TREXIO file
|
||||||
if dim is None:
|
dims_list = None
|
||||||
|
if dim is None or doReshape:
|
||||||
$group_dset_dim$ = read_$group_dset_dim$(trexio_file)
|
$group_dset_dim$ = read_$group_dset_dim$(trexio_file)
|
||||||
|
|
||||||
dims_list = [$group_dset_dim_list$]
|
dims_list = [$group_dset_dim_list$]
|
||||||
@ -2027,6 +2040,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
|||||||
dim *= dims_list[i]
|
dim *= dims_list[i]
|
||||||
|
|
||||||
|
|
||||||
|
shape = tuple(dims_list)
|
||||||
|
if shape is None and doReshape:
|
||||||
|
raise ValueError("Reshaping failure: shape is None.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim)
|
rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim)
|
||||||
assert rc==TREXIO_SUCCESS
|
assert rc==TREXIO_SUCCESS
|
||||||
@ -2039,10 +2056,6 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
|||||||
isConverted = False
|
isConverted = False
|
||||||
dset_converted = None
|
dset_converted = None
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
try:
|
|
||||||
import numpy as np
|
|
||||||
except ImportError:
|
|
||||||
raise Exception("NumPy cannot be imported.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert isinstance(dtype, type)
|
assert isinstance(dtype, type)
|
||||||
@ -2059,8 +2072,21 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
|||||||
isConverted = True
|
isConverted = True
|
||||||
|
|
||||||
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
|
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
|
||||||
|
|
||||||
|
if doReshape:
|
||||||
|
try:
|
||||||
|
# in-place reshaping did not work so I have to make a copy
|
||||||
|
if isConverted:
|
||||||
|
dset_reshaped = np.reshape(dset_converted, shape, order='C')
|
||||||
|
else:
|
||||||
|
dset_reshaped = np.reshape(dset_64, shape, order='C')
|
||||||
|
except:
|
||||||
|
raise
|
||||||
|
|
||||||
if isConverted:
|
if isConverted:
|
||||||
return dset_converted
|
return dset_converted
|
||||||
|
elif doReshape:
|
||||||
|
return dset_reshaped
|
||||||
else:
|
else:
|
||||||
return dset_64
|
return dset_64
|
||||||
#+end_src
|
#+end_src
|
||||||
|
Loading…
Reference in New Issue
Block a user