From 4c28a4cac8d7c47c2fe0dc7fceb10c2bb220c1a0 Mon Sep 17 00:00:00 2001 From: q-posev Date: Thu, 26 Aug 2021 13:14:46 +0300 Subject: [PATCH] add type converters for numerical arrays in read/write Python functions --- python/test/test_api.py | 25 ++++++--- src/templates_front/templator_front.org | 68 ++++++++++++++++++++++--- tools/generator_tools.py | 6 ++- 3 files changed, 84 insertions(+), 15 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index 802e48a..c962ff7 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -9,7 +9,7 @@ import trexio as tr #=========================================================# # 0: TREXIO_HDF5 ; 1: TREXIO_TEXT -TEST_TREXIO_BACKEND = tr.TREXIO_TEXT +TEST_TREXIO_BACKEND = tr.TREXIO_HDF5 OUTPUT_FILENAME_TEXT = 'test_py_swig.dir' OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5' @@ -49,7 +49,8 @@ tr.write_nucleus_num(test_file, nucleus_num) # initialize charge arrays as a list and convert it to numpy array charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.] -charges_np = np.array(charges, dtype=np.float64) +#charges_np = np.array(charges, dtype=np.float32) +charges_np = np.array(charges, dtype=np.int32) # function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived # from the size of the list/array by SWIG using typemaps from numpy.i @@ -58,7 +59,7 @@ tr.write_nucleus_charge(test_file, charges_np) # initialize arrays of nuclear indices as a list and convert it to numpy array indices = [i for i in range(nucleus_num)] # type cast is important here because by default numpy transforms a list of integers into int64 array -indices_np = np.array(indices, dtype=np.int32) +indices_np = np.array(indices, dtype=np.int64) # function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived # from the size of the list/array by SWIG using typemacs from numpy.i @@ -108,6 +109,7 @@ tr.write_nucleus_label(test_file,labels) # close TREXIO file tr.close(test_file) + #==========================================================# #============ READ THE DATA FROM THE TEST FILE ============# #==========================================================# @@ -131,10 +133,21 @@ except Exception: print("Unsafe call to safe API: checked") # safe call to read array of int values (nuclear indices) -rindices_np = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num) -assert rindices_np.dtype is np.dtype(np.int32) +rindices_np_16 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int16) +assert rindices_np_16.dtype is np.dtype(np.int16) for i in range(nucleus_num): - assert rindices_np[i]==indices_np[i] + assert rindices_np_16[i]==indices_np[i] + +rindices_np_32 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int32) +assert rindices_np_32.dtype is np.dtype(np.int32) +for i in range(nucleus_num): + assert rindices_np_32[i]==indices_np[i] + +rindices_np_64 = tr.read_basis_nucleus_index(test_file2) +assert rindices_np_64.dtype is np.dtype(np.int64) +assert rindices_np_64.size==nucleus_num +for i in range(nucleus_num): + assert rindices_np_64[i]==indices_np[i] # read nuclear coordinates without providing optional argument dim rcoords_np = tr.read_nucleus_coord(test_file2) diff --git a/src/templates_front/templator_front.org b/src/templates_front/templator_front.org index 5b1aa4b..fbbb6ba 100644 --- a/src/templates_front/templator_front.org +++ b/src/templates_front/templator_front.org @@ -987,6 +987,7 @@ def close(trexio_file): | ~$group_dset_f_dtype_single$~ | Single precision datatype of the dataset [Fortran] | ~real(4)/integer(4)~ | | ~$group_dset_f_dtype_double$~ | Double precision datatype of the dataset [Fortran] | ~real(8)/integer(8)~ | | ~$group_dset_f_dims$~ | Dimensions in Fortran | ~(:,:)~ | + | ~$group_dset_py_dtype$~ | Standard datatype of the dataset [Python] | ~float/int~ | Note: parent group name is always added to the child objects upon @@ -1886,7 +1887,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None: TREXIO file handle. dset_w: list OR numpy.ndarray - Array of $group_dset$ values to be written. + Array of $group_dset$ values to be written. If array data type does not correspond to int64 or float64, the conversion is performed. Raises: - Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error. @@ -1894,8 +1895,31 @@ def write_$group_dset$(trexio_file, dset_w) -> None: """ + cutPrecision = False + if not isinstance(dset_w, list): + try: + import numpy as np + except ImportError: + raise Exception("NumPy cannot be imported.") + + if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64): + cutPrecision = True + + + if cutPrecision: + try: + # TODO: we have to make sure that this conversion does not introduce any noise in the data. + dset_64 = np.$group_dset_py_dtype$64(dset_w) + except: + raise + + try: - rc = trexio_write_safe_$group_dset$(trexio_file, dset_w) + if cutPrecision: + rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_64) + else: + rc = trexio_write_safe_$group_dset$_64(trexio_file, dset_w) + assert rc==TREXIO_SUCCESS except AssertionError: raise Exception(trexio_string_of_error(rc)) @@ -1905,7 +1929,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None: #+end_src #+begin_src python :tangle read_dset_data_front.py -def read_$group_dset$(trexio_file, dim = None): +def read_$group_dset$(trexio_file, dim = None, dtype = None): """Read the $group_dset$ array of numbers from the TREXIO file. Parameters: @@ -1917,8 +1941,11 @@ def read_$group_dset$(trexio_file, dim = None): Size of the block to be read from the file (i.e. how many items of $group_dset$ will be returned) If None, the function will read all necessary array dimensions from the file. + dtype (Optional): type + NumPy data type of the output (e.g. np.int32|int16 or np.float32|float16). If specified, the output array will be converted from the default double precision. + Returns: - ~dset_r~: numpy.ndarray + ~dset_64~ if dtype is None or ~dset_converted~ otherwise: numpy.ndarray 1D NumPy array with ~dim~ elements corresponding to $group_dset$ values read from the TREXIO file. Raises: @@ -1938,16 +1965,41 @@ def read_$group_dset$(trexio_file, dim = None): try: - rc, dset_r = trexio_read_safe_$group_dset$(trexio_file, dim) + rc, dset_64 = trexio_read_safe_$group_dset$_64(trexio_file, dim) assert rc==TREXIO_SUCCESS except AssertionError: raise Exception(trexio_string_of_error(rc)) except: - raise + raise + + + isConverted = False + dset_converted = None + if dtype is not None: + try: + import numpy as np + except ImportError: + raise Exception("NumPy cannot be imported.") + + try: + assert isinstance(dtype, type) + except AssertionError: + raise TypeError("dtype argument has to be an instance of the type class (e.g. np.float32).") + + + if not dtype==np.int64 or not dtype==np.float64: + try: + dset_converted = np.array(dset_64, dtype=dtype) + except: + raise + + isConverted = True # additional assert can be added here to check that read_safe functions returns numpy array of proper dimension - - return dset_r + if isConverted: + return dset_converted + else: + return dset_64 #+end_src ** Sparse data structures diff --git a/tools/generator_tools.py b/tools/generator_tools.py index 7f96175..18fb472 100644 --- a/tools/generator_tools.py +++ b/tools/generator_tools.py @@ -100,7 +100,7 @@ def recursive_populate_file(fname: str, paths: dict, detailed_source: dict) -> N fname_new = join('populated',f'pop_{fname}') templ_path = get_template_path(fname, paths) - triggers = ['group_dset_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index', + triggers = ['group_dset_dtype', 'group_dset_py_dtype', 'group_dset_h5_dtype', 'default_prec', 'is_index', 'group_dset_f_dtype_default', 'group_dset_f_dtype_double', 'group_dset_f_dtype_single', 'group_dset_dtype_default', 'group_dset_dtype_double', 'group_dset_dtype_single', 'group_dset_rank', 'group_dset_dim_list', 'group_dset_f_dims', @@ -542,6 +542,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple: default_prec = '64' group_dset_std_dtype_out = '24.16e' group_dset_std_dtype_in = 'lf' + group_dset_py_dtype = 'float' elif v[0] in ['int', 'index']: datatype = 'int64_t' group_dset_h5_dtype = 'native_int64' @@ -554,6 +555,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple: default_prec = '32' group_dset_std_dtype_out = '" PRId64 "' group_dset_std_dtype_in = '" SCNd64 "' + group_dset_py_dtype = 'int' elif v[0] == 'str': datatype = 'char*' group_dset_h5_dtype = '' @@ -566,6 +568,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple: default_prec = '' group_dset_std_dtype_out = 's' group_dset_std_dtype_in = 's' + group_dset_py_dtype = 'str' # add the dset name for templates tmp_dict['group_dset'] = k @@ -587,6 +590,7 @@ def split_dset_dict_detailed (datasets: dict) -> tuple: tmp_dict['default_prec'] = default_prec tmp_dict['group_dset_std_dtype_in'] = group_dset_std_dtype_in tmp_dict['group_dset_std_dtype_out'] = group_dset_std_dtype_out + tmp_dict['group_dset_py_dtype'] = group_dset_py_dtype # add the rank tmp_dict['rank'] = len(v[1]) tmp_dict['group_dset_rank'] = str(tmp_dict['rank'])