1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-08 20:33:36 +01:00

adapt write functions to receive multidimensional arrays or lists

This commit is contained in:
q-posev 2021-09-07 17:14:23 +02:00
parent 3e82fd9ae8
commit 96678fea2e
2 changed files with 51 additions and 38 deletions

View File

@ -85,18 +85,18 @@ trexio.write_basis_nucleus_index(test_file, indices_np)
# initialize a list of nuclear coordinates # initialize a list of nuclear coordinates
coords = [ coords = [
0.00000000 , 1.39250319 , 0.00000000 , [ 0.00000000 , 1.39250319 , 0.00000000 ],
-1.20594314 , 0.69625160 , 0.00000000 , [-1.20594314 , 0.69625160 , 0.00000000 ],
-1.20594314 , -0.69625160 , 0.00000000 , [-1.20594314 , -0.69625160 , 0.00000000 ],
0.00000000 , -1.39250319 , 0.00000000 , [ 0.00000000 , -1.39250319 , 0.00000000 ],
1.20594314 , -0.69625160 , 0.00000000 , [ 1.20594314 , -0.69625160 , 0.00000000 ],
1.20594314 , 0.69625160 , 0.00000000 , [ 1.20594314 , 0.69625160 , 0.00000000 ],
-2.14171677 , 1.23652075 , 0.00000000 , [-2.14171677 , 1.23652075 , 0.00000000 ],
-2.14171677 , -1.23652075 , 0.00000000 , [-2.14171677 , -1.23652075 , 0.00000000 ],
0.00000000 , -2.47304151 , 0.00000000 , [ 0.00000000 , -2.47304151 , 0.00000000 ],
2.14171677 , -1.23652075 , 0.00000000 , [ 2.14171677 , -1.23652075 , 0.00000000 ],
2.14171677 , 1.23652075 , 0.00000000 , [ 2.14171677 , 1.23652075 , 0.00000000 ],
0.00000000 , 2.47304151 , 0.00000000 , [ 0.00000000 , 2.47304151 , 0.00000000 ],
] ]
# write coordinates in the file # write coordinates in the file

View File

@ -1973,35 +1973,48 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
- Exception from some other error (e.g. RuntimeError). - Exception from some other error (e.g. RuntimeError).
""" """
cutPrecision = False
if not isinstance(dset_w, (list, tuple)):
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
cutPrecision = True
if cutPrecision:
try:
dset_64 = np.$group_dset_py_dtype$64(dset_w)
except:
raise
try: try:
if cutPrecision: import numpy as np
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64) except ImportError:
else: raise Exception("NumPy cannot be imported.")
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w)
if rc != TREXIO_SUCCESS: doConversion = False
raise Error(rc) doFlatten = False
except: if not isinstance(dset_w, (list, tuple)):
raise # if input array is not a list or tuple then it is probably a numpy array
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
doConversion = True
if len(dset_w.shape) > 1:
doFlatten = True
if doConversion:
dset_64 = np.$group_dset_py_dtype$64(dset_w).flatten()
else:
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
else:
if doConversion:
dset_64 = np.$group_dset_py_dtype$64(dset_w)
else:
# if input array is a multidimensional list or tuple, we have to convert it
try:
doFlatten = True
ncol = len(dset_w[0])
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
except TypeError:
doFlatten = False
pass
if doConversion:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64)
elif doFlatten:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_flat)
else:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w)
if rc != TREXIO_SUCCESS:
raise Error(rc)
#+end_src #+end_src
#+begin_src python :tangle read_dset_data_front.py #+begin_src python :tangle read_dset_data_front.py