mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-12-23 04:43:57 +01:00
adapt write functions to receive multidimensional arrays or lists
This commit is contained in:
parent
3e82fd9ae8
commit
96678fea2e
@ -85,18 +85,18 @@ trexio.write_basis_nucleus_index(test_file, indices_np)
|
||||
|
||||
# initialize a list of nuclear coordinates
|
||||
coords = [
|
||||
0.00000000 , 1.39250319 , 0.00000000 ,
|
||||
-1.20594314 , 0.69625160 , 0.00000000 ,
|
||||
-1.20594314 , -0.69625160 , 0.00000000 ,
|
||||
0.00000000 , -1.39250319 , 0.00000000 ,
|
||||
1.20594314 , -0.69625160 , 0.00000000 ,
|
||||
1.20594314 , 0.69625160 , 0.00000000 ,
|
||||
-2.14171677 , 1.23652075 , 0.00000000 ,
|
||||
-2.14171677 , -1.23652075 , 0.00000000 ,
|
||||
0.00000000 , -2.47304151 , 0.00000000 ,
|
||||
2.14171677 , -1.23652075 , 0.00000000 ,
|
||||
2.14171677 , 1.23652075 , 0.00000000 ,
|
||||
0.00000000 , 2.47304151 , 0.00000000 ,
|
||||
[ 0.00000000 , 1.39250319 , 0.00000000 ],
|
||||
[-1.20594314 , 0.69625160 , 0.00000000 ],
|
||||
[-1.20594314 , -0.69625160 , 0.00000000 ],
|
||||
[ 0.00000000 , -1.39250319 , 0.00000000 ],
|
||||
[ 1.20594314 , -0.69625160 , 0.00000000 ],
|
||||
[ 1.20594314 , 0.69625160 , 0.00000000 ],
|
||||
[-2.14171677 , 1.23652075 , 0.00000000 ],
|
||||
[-2.14171677 , -1.23652075 , 0.00000000 ],
|
||||
[ 0.00000000 , -2.47304151 , 0.00000000 ],
|
||||
[ 2.14171677 , -1.23652075 , 0.00000000 ],
|
||||
[ 2.14171677 , 1.23652075 , 0.00000000 ],
|
||||
[ 0.00000000 , 2.47304151 , 0.00000000 ],
|
||||
]
|
||||
|
||||
# write coordinates in the file
|
||||
|
@ -1973,35 +1973,48 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
cutPrecision = False
|
||||
if not isinstance(dset_w, (list, tuple)):
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise Exception("NumPy cannot be imported.")
|
||||
|
||||
doConversion = False
|
||||
doFlatten = False
|
||||
if not isinstance(dset_w, (list, tuple)):
|
||||
# if input array is not a list or tuple then it is probably a numpy array
|
||||
if isinstance(dset_w, np.ndarray) and (not dset_w.dtype==np.int64 or not dset_w.dtype==np.float64):
|
||||
cutPrecision = True
|
||||
doConversion = True
|
||||
|
||||
|
||||
if cutPrecision:
|
||||
try:
|
||||
if len(dset_w.shape) > 1:
|
||||
doFlatten = True
|
||||
if doConversion:
|
||||
dset_64 = np.$group_dset_py_dtype$64(dset_w).flatten()
|
||||
else:
|
||||
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
|
||||
else:
|
||||
if doConversion:
|
||||
dset_64 = np.$group_dset_py_dtype$64(dset_w)
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
else:
|
||||
# if input array is a multidimensional list or tuple, we have to convert it
|
||||
try:
|
||||
if cutPrecision:
|
||||
doFlatten = True
|
||||
ncol = len(dset_w[0])
|
||||
dset_flat = np.array(dset_w, dtype=np.$group_dset_py_dtype$64).flatten()
|
||||
except TypeError:
|
||||
doFlatten = False
|
||||
pass
|
||||
|
||||
|
||||
if doConversion:
|
||||
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_64)
|
||||
elif doFlatten:
|
||||
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_flat)
|
||||
else:
|
||||
rc = pytr.trexio_write_safe_$group_dset$_64(trexio_file.pytrexio_s, dset_w)
|
||||
|
||||
if rc != TREXIO_SUCCESS:
|
||||
raise Error(rc)
|
||||
except:
|
||||
raise
|
||||
|
||||
#+end_src
|
||||
|
||||
#+begin_src python :tangle read_dset_data_front.py
|
||||
|
Loading…
Reference in New Issue
Block a user