1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-09 20:48:53 +01:00

finish top-level Python API for sparse data

This commit is contained in:
q-posev 2021-12-27 14:00:46 +01:00
parent 722c546113
commit 31ffa574ab

View File

@ -2766,6 +2766,221 @@ interface
end interface end interface
#+end_src #+end_src
*** Python templates for front end
#+begin_src python :tangle write_dset_sparse_front.py
def write_$group_dset$(trexio_file: File, offset_file: int, buffer_size: int, indices: list, values: list) -> None:
"""Write the $group_dset$ indices and values in the TREXIO file.
Parameters:
trexio_file:
TREXIO File object.
offset_file: int
The number of integrals to be skipped in the file when writing.
buffer_size: int
The number of integrals to write in the file from the provided sparse arrays.
values: list OR numpy.ndarray
Array of $group_dset$ indices to be written. If array data type does not correspond to int32, the conversion is performed.
values: list OR numpy.ndarray
Array of $group_dset$ values to be written. If array data type does not correspond to float64, the conversion is performed.
Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if not isinstance(offset_file, int):
raise TypeError("offset_file argument has to be an integer.")
if not isinstance(buffer_size, int):
raise TypeError("buffer_size argument has to be an integer.")
if not isinstance(indices, (list, tuple, np.ndarray)):
raise TypeError("indices argument has to be an array (list, tuple or NumPy ndarray).")
if not isinstance(values, (list, tuple, np.ndarray)):
raise TypeError("values argument has to be an array (list, tuple or NumPy ndarray).")
convertIndices = False
convertValues = False
flattenIndices = False
if isinstance(indices, np.ndarray):
# convert to int32 if input indices are in a different precision
if not indices.dtype==np.int32:
convertIndices = True
if len(indices.shape) > 1:
flattenIndices = True
if convertIndices:
indices_32 = np.int32(indices).flatten()
else:
indices_32 = np.array(indices, dtype=np.int32).flatten()
else:
if convertIndices:
indices_32 = np.int32(indices)
else:
# if input array is a multidimensional list or tuple, we have to convert it
try:
doFlatten = True
# if list of indices is flat - the attempt to compute len(indices[0]) will raise a TypeError
ncol = len(indices[0])
indices_32 = np.array(indices, dtype=np.int32).flatten()
except TypeError:
doFlatten = False
pass
if isinstance(values, np.ndarray):
# convert to float64 if input values are in a different precision
if not values.dtype==np.float64:
convertValues = True
if convertValues:
values_64 = np.float64(values)
if (convertIndices or flattenIndices) and convertValues:
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices_32, values_64)
elif (convertIndices or flattenIndices) and not convertValues:
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices_32, values)
elif not (convertIndices or flattenIndices) and convertValues:
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices, values_64)
else:
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices, values)
if rc != TREXIO_SUCCESS:
raise Error(rc)
#+end_src
#+begin_src python :tangle read_dset_sparse_front.py
def read_$group_dset$(trexio_file: File, offset_file: int, buffer_size: int) -> tuple:
"""Read the $group_dset$ indices and values from the TREXIO file.
Parameters:
trexio_file:
TREXIO File object.
offset_file: int
The number of integrals to be skipped in the file when reading.
buffer_size: int
The number of integrals to read from the file.
Returns:
(indices, values, read_buf_size, eof_flag) tuple where
- indices and values are NumPy arrays [numpy.ndarray] with the default int32 and float64 precision, respectively;
- read_buf_size [int] is the number of integrals read from the trexio_file
(either strictly equal to buffer_size or less than buffer_size if EOF has been reached);
- eof_flag [bool] is True when EOF has been reached (i.e. when call to low-level pytrexio API returns TREXIO_END)
False otherwise.
Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS
and prints the error message using trexio_string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
try:
import numpy as np
except ImportError:
raise Exception("NumPy cannot be imported.")
if not isinstance(offset_file, int):
raise TypeError("offset_file argument has to be an integer.")
if not isinstance(buffer_size, int):
raise TypeError("buffer_size argument has to be an integer.")
# read the number of integrals already in the file
integral_num = read_$group_dset$_size(trexio_file)
# additional modification needed to avoid allocating more memory than needed if EOF will be reached during read
overflow = offset_file + buffer_size - integral_num
eof_flag = False
if overflow > 0:
verified_size = buffer_size - overflow
eof_flag = True
else:
verified_size = buffer_size
# main call to the low-level (SWIG-wrapped) trexio_read function, which also requires the sizes of the output to be provided
# as the last 2 arguments (for numpy arrays of indices and values, respectively)
# read_buf_size contains the number of elements being read from the file, useful when EOF has been reached
rc, read_buf_size, indices, values = pytr.trexio_read_safe_$group_dset$(trexio_file.pytrexio_s,
offset_file,
verified_size,
verified_size * $group_dset_rank$,
verified_size)
if rc != TREXIO_SUCCESS:
raise Error(rc)
if read_buf_size == 0:
raise ValueError("No integrals have been read from the file.")
if indices is None or values is None:
raise ValueError("Returned NULL array from the low-level pytrexio API.")
# conversion to custom types can be performed on the user side, here we only reshape the returned flat array of indices according to group_dset_rank
shape = tuple([verified_size, $group_dset_rank$])
indices_reshaped = np.reshape(indices, shape, order='C')
return (indices_reshaped, values, read_buf_size, eof_flag)
def read_$group_dset$_size(trexio_file) -> int:
"""Read the number of $group_dset$ integrals stored in the TREXIO file.
Parameter is a ~TREXIO File~ object that has been created by a call to ~open~ function.
Returns:
~num_integral~: int
Integer value of corresponding to the size of the $group_dset$ sparse array from ~trexio_file~.
Raises:
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
try:
rc, num_integral = pytr.trexio_read_$group_dset$_size(trexio_file.pytrexio_s)
if rc != TREXIO_SUCCESS:
raise Error(rc)
except:
raise
return num_integral
#+end_src
#+begin_src python :tangle has_dset_sparse_front.py
def has_$group_dset$(trexio_file) -> bool:
"""Check that $group_dset$ variable exists in the TREXIO file.
Parameter is a ~TREXIO File~ object that has been created by a call to ~open~ function.
Returns:
True if the variable exists, False otherwise
Raises:
- Exception from trexio.Error class if TREXIO return code ~rc~ is TREXIO_FAILURE and prints the error message using string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
try:
rc = pytr.trexio_has_$group_dset$(trexio_file.pytrexio_s)
if rc == TREXIO_FAILURE:
raise Error(rc)
except:
raise
if rc == TREXIO_SUCCESS:
return True
else:
return False
#+end_src
** Templates for front end has/read/write a dataset of strings ** Templates for front end has/read/write a dataset of strings
*** Introduction *** Introduction
This section concerns API calls related to datasets of strings. This section concerns API calls related to datasets of strings.