From 96678fea2eb5a3d4d31aa2af27844279b4fa8758 Mon Sep 17 00:00:00 2001 From: q-posev Date: Tue, 7 Sep 2021 17:14:23 +0200 Subject: [PATCH] adapt write functions to receive multidimensional arrays or lists --- python/test/test_api.py | 24 ++++----- src/templates_front/templator_front.org | 65 +++++++++++++++---------- 2 files changed, 51 insertions(+), 38 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index 5f23a91..104a7f6 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -85,18 +85,18 @@ trexio.write_basis_nucleus_index(test_file, indices_np) # initialize a list of nuclear coordinates coords = [ - 0.00000000 , 1.39250319 , 0.00000000 , - -1.20594314 , 0.69625160 , 0.00000000 , - -1.20594314 , -0.69625160 , 0.00000000 , - 0.00000000 , -1.39250319 , 0.00000000 , - 1.20594314 , -0.69625160 , 0.00000000 , - 1.20594314 , 0.69625160 , 0.00000000 , - -2.14171677 , 1.23652075 , 0.00000000 , - -2.14171677 , -1.23652075 , 0.00000000 , - 0.00000000 , -2.47304151 , 0.00000000 , - 2.14171677 , -1.23652075 , 0.00000000 , - 2.14171677 , 1.23652075 , 0.00000000 , - 0.00000000 , 2.47304151 , 0.00000000 , + [ 0.00000000 , 1.39250319 , 0.00000000 ], + [-1.20594314 , 0.69625160 , 0.00000000 ], + [-1.20594314 , -0.69625160 , 0.00000000 ], + [ 0.00000000 , -1.39250319 , 0.00000000 ], + [ 1.20594314 , -0.69625160 , 0.00000000 ], + [ 1.20594314 , 0.69625160 , 0.00000000 ], + [-2.14171677 , 1.23652075 , 0.00000000 ], + [-2.14171677 , -1.23652075 , 0.00000000 ], + [ 0.00000000 , -2.47304151 , 0.00000000 ], + [ 2.14171677 , -1.23652075 , 0.00000000 ], + [ 2.14171677 , 1.23652075 , 0.00000000 ], + [ 0.00000000 , 2.47304151 , 0.00000000 ], ] # write coordinates in the file diff --git a/src/templates_front/templator_front.org b/src/templates_front/templator_front.org index f00e9e8..7af827d 100644 --- a/src/templates_front/templator_front.org +++ b/src/templates_front/templator_front.org @@ -1973,35 +1973,48 @@ def write_$group_dset$(trexio_file, dset_w) -> None: - Exception from some other error (e.g. RuntimeError). """ - cutPrecision = False - if not isinstance(dset_w, (list, tuple)): - try: - import numpy as np - except ImportError: - raise Exception("NumPy cannot be imported.") - - if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64): - cutPrecision = True - - - if cutPrecision: - try: - dset_64 = np.$group_dset_py_dtype$64(dset_w) - except: - raise - - try: - if cutPrecision: - rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64) - else: - rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w) + import numpy as np + except ImportError: + raise Exception("NumPy cannot be imported.") - if rc != TREXIO_SUCCESS: - raise Error(rc) - except: - raise + doConversion = False + doFlatten = False + if not isinstance(dset_w, (list, tuple)): + # if input array is not a list or tuple then it is probably a numpy array + if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64): + doConversion = True + if len(dset_w.shape) > 1: + doFlatten = True + if doConversion: + dset_64 = np.$group_dset_py_dtype$64(dset_w).flatten() + else: + dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten() + else: + if doConversion: + dset_64 = np.$group_dset_py_dtype$64(dset_w) + + else: + # if input array is a multidimensional list or tuple, we have to convert it + try: + doFlatten = True + ncol = len(dset_w[0]) + dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten() + except TypeError: + doFlatten = False + pass + + + if doConversion: + rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64) + elif doFlatten: + rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_flat) + else: + rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w) + + if rc != TREXIO_SUCCESS: + raise Error(rc) #+end_src #+begin_src python :tangle read_dset_data_front.py