mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-12-22 20:35:44 +01:00
reshape output arrays by default when reading from the file
This commit is contained in:
parent
7b5ebf6272
commit
5a2b4d96a7
@ -9,7 +9,7 @@ import trexio as tr
|
||||
#=========================================================#
|
||||
|
||||
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
|
||||
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT
|
||||
TEST_TREXIO_BACKEND = 0
|
||||
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
|
||||
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
|
||||
|
||||
@ -36,6 +36,8 @@ except:
|
||||
#============ WRITE THE DATA IN THE TEST FILE ============#
|
||||
#=========================================================#
|
||||
|
||||
|
||||
|
||||
# create TREXIO file and open it for writing
|
||||
#test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND)
|
||||
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)
|
||||
@ -114,6 +116,8 @@ tr.write_nucleus_label(test_file,labels)
|
||||
# tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook)
|
||||
del test_file
|
||||
|
||||
|
||||
|
||||
#==========================================================#
|
||||
#============ READ THE DATA FROM THE TEST FILE ============#
|
||||
#==========================================================#
|
||||
@ -156,7 +160,12 @@ for i in range(nucleus_num):
|
||||
|
||||
# read nuclear coordinates without providing optional argument dim
|
||||
rcoords_np = tr.read_nucleus_coord(test_file2)
|
||||
|
||||
assert rcoords_np.size==nucleus_num*3
|
||||
np.testing.assert_array_almost_equal(rcoords_np, np.array(coords).reshape(nucleus_num,3), decimal=8)
|
||||
|
||||
# set doReshape to False to get a flat 1D array (e.g. when reading matrices like nuclear coordinates)
|
||||
#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False)
|
||||
|
||||
# read array of nuclear labels
|
||||
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)
|
||||
|
@ -1993,7 +1993,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
|
||||
#+end_src
|
||||
|
||||
#+begin_src python :tangle read_dset_data_front.py
|
||||
def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
||||
def read_$group_dset$(trexio_file, dim = None, doReshape = None, dtype = None):
|
||||
"""Read the $group_dset$ array of numbers from the TREXIO file.
|
||||
|
||||
Parameters:
|
||||
@ -2007,6 +2007,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
||||
|
||||
dtype (Optional): type
|
||||
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
|
||||
|
||||
doReshape (Optional): bool
|
||||
Flag to determine whether the output NumPy array has be reshaped or not. Be default, reshaping is performed
|
||||
based on the dimensions from the ~trex.json~ file. Otherwise, ~shape~ array (list or tuple) is used if provided by the user.
|
||||
|
||||
Returns:
|
||||
~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
|
||||
@ -2015,17 +2019,30 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
||||
Raises:
|
||||
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise Exception("NumPy cannot be imported.")
|
||||
|
||||
if doReshape is None:
|
||||
doReshape = True
|
||||
|
||||
# if dim is not specified, read dimensions from the TREXIO file
|
||||
if dim is None:
|
||||
dims_list = None
|
||||
if dim is None or doReshape:
|
||||
$group_dset_dim$ = read_$group_dset_dim$(trexio_file)
|
||||
|
||||
dims_list = [$group_dset_dim_list$]
|
||||
dim = 1
|
||||
for i in range($group_dset_rank$):
|
||||
dim *= dims_list[i]
|
||||
|
||||
|
||||
|
||||
shape = tuple(dims_list)
|
||||
if shape is None and doReshape:
|
||||
raise ValueError("Reshaping failure: shape is None.")
|
||||
|
||||
try:
|
||||
rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim)
|
||||
@ -2039,10 +2056,6 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
||||
isConverted = False
|
||||
dset_converted = None
|
||||
if dtype is not None:
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise Exception("NumPy cannot be imported.")
|
||||
|
||||
try:
|
||||
assert isinstance(dtype, type)
|
||||
@ -2059,8 +2072,21 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
|
||||
isConverted = True
|
||||
|
||||
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
|
||||
if isConverted:
|
||||
|
||||
if doReshape:
|
||||
try:
|
||||
# in-place reshaping did not work so I have to make a copy
|
||||
if isConverted:
|
||||
dset_reshaped = np.reshape(dset_converted, shape, order='C')
|
||||
else:
|
||||
dset_reshaped = np.reshape(dset_64, shape, order='C')
|
||||
except:
|
||||
raise
|
||||
|
||||
if isConverted:
|
||||
return dset_converted
|
||||
elif doReshape:
|
||||
return dset_reshaped
|
||||
else:
|
||||
return dset_64
|
||||
#+end_src
|
||||
|
Loading…
Reference in New Issue
Block a user