1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-12-23 04:43:57 +01:00

adapt write functions to receive multidimensional arrays or lists

This commit is contained in:
q-posev 2021-09-07 17:14:23 +02:00
parent 3e82fd9ae8
commit 96678fea2e
2 changed files with 51 additions and 38 deletions

View File

@ -85,18 +85,18 @@ trexio.write_basis_nucleus_index(test_file, indices_np)
# initialize a list of nuclear coordinates
coords = [
0.00000000 , 1.39250319 , 0.00000000 ,
-1.20594314 , 0.69625160 , 0.00000000 ,
-1.20594314 , -0.69625160 , 0.00000000 ,
0.00000000 , -1.39250319 , 0.00000000 ,
1.20594314 , -0.69625160 , 0.00000000 ,
1.20594314 , 0.69625160 , 0.00000000 ,
-2.14171677 , 1.23652075 , 0.00000000 ,
-2.14171677 , -1.23652075 , 0.00000000 ,
0.00000000 , -2.47304151 , 0.00000000 ,
2.14171677 , -1.23652075 , 0.00000000 ,
2.14171677 , 1.23652075 , 0.00000000 ,
0.00000000 , 2.47304151 , 0.00000000 ,
[ 0.00000000 , 1.39250319 , 0.00000000 ],
[-1.20594314 , 0.69625160 , 0.00000000 ],
[-1.20594314 , -0.69625160 , 0.00000000 ],
[ 0.00000000 , -1.39250319 , 0.00000000 ],
[ 1.20594314 , -0.69625160 , 0.00000000 ],
[ 1.20594314 , 0.69625160 , 0.00000000 ],
[-2.14171677 , 1.23652075 , 0.00000000 ],
[-2.14171677 , -1.23652075 , 0.00000000 ],
[ 0.00000000 , -2.47304151 , 0.00000000 ],
[ 2.14171677 , -1.23652075 , 0.00000000 ],
[ 2.14171677 , 1.23652075 , 0.00000000 ],
[ 0.00000000 , 2.47304151 , 0.00000000 ],
]
# write coordinates in the file

View File

@ -1973,35 +1973,48 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
- Exception from some other error (e.g. RuntimeError).
"""
cutPrecision = False
if not isinstance(dset_w, (list, tuple)):
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
doConversion = False
doFlatten = False
if not isinstance(dset_w, (list, tuple)):
# if input array is not a list or tuple then it is probably a numpy array
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
cutPrecision = True
doConversion = True
if cutPrecision:
try:
if len(dset_w.shape) > 1:
doFlatten = True
if doConversion:
dset_64 = np.$group_dset_py_dtype$64(dset_w).flatten()
else:
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
else:
if doConversion:
dset_64 = np.$group_dset_py_dtype$64(dset_w)
except:
raise
else:
# if input array is a multidimensional list or tuple, we have to convert it
try:
if cutPrecision:
doFlatten = True
ncol = len(dset_w[0])
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
except TypeError:
doFlatten = False
pass
if doConversion:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64)
elif doFlatten:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_flat)
else:
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w)
if rc != TREXIO_SUCCESS:
raise Error(rc)
except:
raise
#+end_src
#+begin_src python :tangle read_dset_data_front.py