1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-07-22 18:57:39 +02:00

reshape output arrays by default when reading from the file

This commit is contained in:
q-posev 2021-08-27 16:08:39 +03:00
parent 7b5ebf6272
commit 5a2b4d96a7
2 changed files with 46 additions and 11 deletions

View File

@ -9,7 +9,7 @@ import trexio as tr
#=========================================================# #=========================================================#
# 0: TREXIO_HDF5 ; 1: TREXIO_TEXT # 0: TREXIO_HDF5 ; 1: TREXIO_TEXT
TEST_TREXIO_BACKEND = tr.TREXIO_TEXT TEST_TREXIO_BACKEND = 0
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir' OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5' OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'
@ -36,6 +36,8 @@ except:
#============ WRITE THE DATA IN THE TEST FILE ============# #============ WRITE THE DATA IN THE TEST FILE ============#
#=========================================================# #=========================================================#
# create TREXIO file and open it for writing # create TREXIO file and open it for writing
#test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND) #test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND)
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND) test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)
@ -114,6 +116,8 @@ tr.write_nucleus_label(test_file,labels)
# tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook) # tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook)
del test_file del test_file
#==========================================================# #==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============# #============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================# #==========================================================#
@ -156,7 +160,12 @@ for i in range(nucleus_num):
# read nuclear coordinates without providing optional argument dim # read nuclear coordinates without providing optional argument dim
rcoords_np = tr.read_nucleus_coord(test_file2) rcoords_np = tr.read_nucleus_coord(test_file2)
assert rcoords_np.size==nucleus_num*3 assert rcoords_np.size==nucleus_num*3
np.testing.assert_array_almost_equal(rcoords_np, np.array(coords).reshape(nucleus_num,3), decimal=8)
# set doReshape to False to get a flat 1D array (e.g. when reading matrices like nuclear coordinates)
#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False)
# read array of nuclear labels # read array of nuclear labels
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num) rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)

View File

@ -1993,7 +1993,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
#+end_src #+end_src
#+begin_src python :tangle read_dset_data_front.py #+begin_src python :tangle read_dset_data_front.py
def read_$group_dset$(trexio_file, dim = None, dtype = None): def read_$group_dset$(trexio_file, dim = None, doReshape = None, dtype = None):
"""Read the $group_dset$ array of numbers from the TREXIO file. """Read the $group_dset$ array of numbers from the TREXIO file.
Parameters: Parameters:
@ -2007,6 +2007,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
dtype (Optional): type dtype (Optional): type
NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision. NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision.
doReshape (Optional): bool
Flag to determine whether the output NumPy array has be reshaped or not. Be default, reshaping is performed
based on the dimensions from the ~trex.json~ file. Otherwise, ~shape~ array (list or tuple) is used if provided by the user.
Returns: Returns:
~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray ~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray
@ -2015,17 +2019,30 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
Raises: Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error. - Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
- Exception from some other error (e.g. RuntimeError). - Exception from some other error (e.g. RuntimeError).
""" """
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if doReshape is None:
doReshape = True
# if dim is not specified, read dimensions from the TREXIO file # if dim is not specified, read dimensions from the TREXIO file
if dim is None: dims_list = None
if dim is None or doReshape:
$group_dset_dim$ = read_$group_dset_dim$(trexio_file) $group_dset_dim$ = read_$group_dset_dim$(trexio_file)
dims_list = [$group_dset_dim_list$] dims_list = [$group_dset_dim_list$]
dim = 1 dim = 1
for i in range($group_dset_rank$): for i in range($group_dset_rank$):
dim *= dims_list[i] dim *= dims_list[i]
shape = tuple(dims_list)
if shape is None and doReshape:
raise ValueError("Reshaping failure: shape is None.")
try: try:
rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim) rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim)
@ -2039,10 +2056,6 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
isConverted = False isConverted = False
dset_converted = None dset_converted = None
if dtype is not None: if dtype is not None:
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
try: try:
assert isinstance(dtype, type) assert isinstance(dtype, type)
@ -2059,8 +2072,21 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None):
isConverted = True isConverted = True
# additional assert can be added here to check that read_safe functions returns numpy array of proper dimension # additional assert can be added here to check that read_safe functions returns numpy array of proper dimension
if isConverted:
if doReshape:
try:
# in-place reshaping did not work so I have to make a copy
if isConverted:
dset_reshaped = np.reshape(dset_converted, shape, order='C')
else:
dset_reshaped = np.reshape(dset_64, shape, order='C')
except:
raise
if isConverted:
return dset_converted return dset_converted
elif doReshape:
return dset_reshaped
else: else:
return dset_64 return dset_64
#+end_src #+end_src