1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-10-02 06:21:05 +02:00

get rid of _safe suffix and add dimension to read_dset_str functions

This commit is contained in:
q-posev 2021-08-20 13:43:15 +03:00
parent 831973fc8e
commit 5d5025ae1d
2 changed files with 11 additions and 10 deletions

View File

@ -48,7 +48,7 @@ charges_np = np.array(charges, dtype=np.float64)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemaps from numpy.i
rc = tr.write_safe_nucleus_charge(test_file, charges_np)
rc = tr.write_nucleus_charge(test_file, charges_np)
# initialize arrays of nuclear indices as a list and convert it to numpy array
indices = [i for i in range(nucleus_num)]
@ -57,7 +57,7 @@ indices_np = np.array(indices, dtype=np.int32)
# function call below works with both lists and numpy arrays, dimension needed for memory-safety is derived
# from the size of the list/array by SWIG using typemacs from numpy.i
tr.write_safe_basis_nucleus_index(test_file, indices_np)
tr.write_basis_nucleus_index(test_file, indices_np)
point_group = 'B3U'
@ -91,23 +91,23 @@ rnum = tr.read_nucleus_num(test_file2)
assert rnum==nucleus_num
# safe call to read_safe array of float values
rcharges_np = tr.read_safe_nucleus_charge(test_file2, nucleus_num)
rcharges_np = tr.read_nucleus_charge(test_file2, dim=nucleus_num)
assert rcharges_np.dtype is np.dtype(np.float64)
np.testing.assert_array_almost_equal(rcharges_np, charges_np, decimal=8)
# unsafe call to read_safe should fail with error message corresponding to TREXIO_UNSAFE_ARRAY_DIM
try:
rcharges_fail = tr.read_safe_nucleus_charge(test_file2, nucleus_num*5)
rcharges_fail = tr.read_nucleus_charge(test_file2, dim=nucleus_num*5)
except Exception:
print("Unsafe call to safe API: successful")
# safe call to read_safe array of int values
rindices_np = tr.read_safe_basis_nucleus_index(test_file2, nucleus_num)
rindices_np = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num)
assert rindices_np.dtype is np.dtype(np.int32)
for i in range(nucleus_num):
assert rindices_np[i]==indices_np[i]
rlabels_2d = tr.read_nucleus_label(test_file2)
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)
print(rlabels_2d)
for i in range(nucleus_num):
assert rlabels_2d[i]==labels[i]

View File

@ -1749,7 +1749,7 @@ end interface
*** Python templates for front end
#+begin_src python :tangle write_dset_data_front.py
def write_safe_$group_dset$(trexio_file, dset_w) -> None:
def write_$group_dset$(trexio_file, dset_w) -> None:
try:
rc = trexio_write_safe_$group_dset$(trexio_file, dset_w)
@ -1762,7 +1762,7 @@ def write_safe_$group_dset$(trexio_file, dset_w) -> None:
#+end_src
#+begin_src python :tangle read_dset_data_front.py
def read_safe_$group_dset$(trexio_file, dim):
def read_$group_dset$(trexio_file, dim):
try:
rc, dset_r = trexio_read_safe_$group_dset$(trexio_file, dim)
@ -2249,7 +2249,7 @@ def write_$group_dset$(trexio_file, dset_w) -> None:
#+end_src
#+begin_src python :tangle read_dset_str_front.py
def read_$group_dset$(trexio_file):
def read_$group_dset$(trexio_file, dim):
try:
rc, dset_1d_r = trexio_read_$group_dset$_low(trexio_file, PYTREXIO_MAX_STR_LENGTH)
@ -2260,7 +2260,8 @@ def read_$group_dset$(trexio_file):
raise
try:
dset_2d_r = [item for item in dset_1d_r.split(TREXIO_DELIM) if item]
dset_full = dset_1d_r.split(TREXIO_DELIM)
dset_2d_r = [dset_full[i] for i in range(dim) if dset_full[i]]
assert dset_2d_r
except AssertionError:
raise TypeError(f"Output of {read_$group_dset$.__name__} function cannot be an empty list.")