From 5a2b4d96a7767fde03d2bfa275cc91b48877f5b1 Mon Sep 17 00:00:00 2001 From: q-posev Date: Fri, 27 Aug 2021 16:08:39 +0300 Subject: [PATCH] reshape output arrays by default when reading from the file --- python/test/test_api.py | 11 +++++- src/templates_front/templator_front.org | 46 +++++++++++++++++++------ 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index ae421c5..a97b87e 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -9,7 +9,7 @@ import trexio as tr #=========================================================# # 0: TREXIO_HDF5 ; 1: TREXIO_TEXT -TEST_TREXIO_BACKEND = tr.TREXIO_TEXT +TEST_TREXIO_BACKEND = 0 OUTPUT_FILENAME_TEXT = 'test_py_swig.dir' OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5' @@ -36,6 +36,8 @@ except: #============ WRITE THE DATA IN THE TEST FILE ============# #=========================================================# + + # create TREXIO file and open it for writing #test_file = tr.open(output_filename, 'w', TEST_TREXIO_BACKEND) test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND) @@ -114,6 +116,8 @@ tr.write_nucleus_label(test_file,labels) # tr.close function. This is only an issue when the data is getting written and read in the same session (e.g. in Jupyter notebook) del test_file + + #==========================================================# #============ READ THE DATA FROM THE TEST FILE ============# #==========================================================# @@ -156,7 +160,12 @@ for i in range(nucleus_num): # read nuclear coordinates without providing optional argument dim rcoords_np = tr.read_nucleus_coord(test_file2) + assert rcoords_np.size==nucleus_num*3 +np.testing.assert_array_almost_equal(rcoords_np, np.array(coords).reshape(nucleus_num,3), decimal=8) + +# set doReshape to False to get a flat 1D array (e.g. 
when reading matrices like nuclear coordinates) +#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False) # read array of nuclear labels rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num) diff --git a/src/templates_front/templator_front.org b/src/templates_front/templator_front.org index d308d3c..028ed12 100644 --- a/src/templates_front/templator_front.org +++ b/src/templates_front/templator_front.org @@ -1993,7 +1993,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None: #+end_src #+begin_src python :tangle read_dset_data_front.py -def read_$group_dset$(trexio_file, dim = None, dtype = None): +def read_$group_dset$(trexio_file, dim = None, doReshape = None, dtype = None): """Read the $group_dset$ array of numbers from the TREXIO file. Parameters: @@ -2007,6 +2007,10 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None): dtype (Optional): type NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision. + + doReshape (Optional): bool + Flag to determine whether the output NumPy array has to be reshaped or not. By default, reshaping is performed + based on the dimensions from the ~trex.json~ file. Otherwise, ~shape~ array (list or tuple) is used if provided by the user. Returns: ~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray @@ -2015,17 +2019,30 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None): Raises: - Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error. - Exception from some other error (e.g. RuntimeError). 
- """ - +""" + + try: + import numpy as np + except ImportError: + raise Exception("NumPy cannot be imported.") + + if doReshape is None: + doReshape = True + # if dim is not specified, read dimensions from the TREXIO file - if dim is None: + dims_list = None + if dim is None or doReshape: $group_dset_dim$ = read_$group_dset_dim$(trexio_file) dims_list = [$group_dset_dim_list$] dim = 1 for i in range($group_dset_rank$): dim *= dims_list[i] - + + + shape = tuple(dims_list) + if shape is None and doReshape: + raise ValueError("Reshaping failure: shape is None.") try: rc, dset_64 = pytr.trexio_read_safe_$group_dset$_64(trexio_file.pytrexio_s, dim) @@ -2039,10 +2056,6 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None): isConverted = False dset_converted = None if dtype is not None: - try: - import numpy as np - except ImportError: - raise Exception("NumPy cannot be imported.") try: assert isinstance(dtype, type) @@ -2059,8 +2072,21 @@ def read_$group_dset$(trexio_file, dim = None, dtype = None): isConverted = True # additional assert can be added here to check that read_safe functions returns numpy array of proper dimension - if isConverted: + + if doReshape: + try: + # in-place reshaping did not work so I have to make a copy + if isConverted: + dset_reshaped = np.reshape(dset_converted, shape, order='C') + else: + dset_reshaped = np.reshape(dset_64, shape, order='C') + except: + raise + + if isConverted: return dset_converted + elif doReshape: + return dset_reshaped else: return dset_64 #+end_src