mirror of
https://github.com/TREX-CoE/trexio.git
synced 2024-11-03 20:54:07 +01:00
finish top-level Python API for sparse data
This commit is contained in:
parent
722c546113
commit
31ffa574ab
@ -2766,6 +2766,221 @@ interface
|
||||
end interface
|
||||
#+end_src
|
||||
|
||||
*** Python templates for front end
|
||||
|
||||
#+begin_src python :tangle write_dset_sparse_front.py
|
||||
def write_$group_dset$(trexio_file: File, offset_file: int, buffer_size: int, indices: list, values: list) -> None:
|
||||
"""Write the $group_dset$ indices and values in the TREXIO file.
|
||||
|
||||
Parameters:
|
||||
|
||||
trexio_file:
|
||||
TREXIO File object.
|
||||
|
||||
offset_file: int
|
||||
The number of integrals to be skipped in the file when writing.
|
||||
|
||||
buffer_size: int
|
||||
The number of integrals to write in the file from the provided sparse arrays.
|
||||
|
||||
values: list OR numpy.ndarray
|
||||
Array of $group_dset$ indices to be written. If array data type does not correspond to int32, the conversion is performed.
|
||||
|
||||
values: list OR numpy.ndarray
|
||||
Array of $group_dset$ values to be written. If array data type does not correspond to float64, the conversion is performed.
|
||||
|
||||
Raises:
|
||||
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise Exception("NumPy cannot be imported.")
|
||||
|
||||
if not isinstance(offset_file, int):
|
||||
raise TypeError("offset_file argument has to be an integer.")
|
||||
if not isinstance(buffer_size, int):
|
||||
raise TypeError("buffer_size argument has to be an integer.")
|
||||
if not isinstance(indices, (list, tuple, np.ndarray)):
|
||||
raise TypeError("indices argument has to be an array (list, tuple or NumPy ndarray).")
|
||||
if not isinstance(values, (list, tuple, np.ndarray)):
|
||||
raise TypeError("values argument has to be an array (list, tuple or NumPy ndarray).")
|
||||
|
||||
convertIndices = False
|
||||
convertValues = False
|
||||
flattenIndices = False
|
||||
if isinstance(indices, np.ndarray):
|
||||
# convert to int32 if input indices are in a different precision
|
||||
if not indices.dtype==np.int32:
|
||||
convertIndices = True
|
||||
|
||||
if len(indices.shape) > 1:
|
||||
flattenIndices = True
|
||||
if convertIndices:
|
||||
indices_32 = np.int32(indices).flatten()
|
||||
else:
|
||||
indices_32 = np.array(indices, dtype=np.int32).flatten()
|
||||
else:
|
||||
if convertIndices:
|
||||
indices_32 = np.int32(indices)
|
||||
else:
|
||||
# if input array is a multidimensional list or tuple, we have to convert it
|
||||
try:
|
||||
doFlatten = True
|
||||
# if list of indices is flat - the attempt to compute len(indices[0]) will raise a TypeError
|
||||
ncol = len(indices[0])
|
||||
indices_32 = np.array(indices, dtype=np.int32).flatten()
|
||||
except TypeError:
|
||||
doFlatten = False
|
||||
pass
|
||||
|
||||
if isinstance(values, np.ndarray):
|
||||
# convert to float64 if input values are in a different precision
|
||||
if not values.dtype==np.float64:
|
||||
convertValues = True
|
||||
if convertValues:
|
||||
values_64 = np.float64(values)
|
||||
|
||||
if (convertIndices or flattenIndices) and convertValues:
|
||||
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices_32, values_64)
|
||||
elif (convertIndices or flattenIndices) and not convertValues:
|
||||
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices_32, values)
|
||||
elif not (convertIndices or flattenIndices) and convertValues:
|
||||
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices, values_64)
|
||||
else:
|
||||
rc = pytr.trexio_write_safe_$group_dset$(trexio_file.pytrexio_s, offset_file, buffer_size, indices, values)
|
||||
|
||||
if rc != TREXIO_SUCCESS:
|
||||
raise Error(rc)
|
||||
#+end_src
|
||||
|
||||
#+begin_src python :tangle read_dset_sparse_front.py
|
||||
def read_$group_dset$(trexio_file: File, offset_file: int, buffer_size: int) -> tuple:
|
||||
"""Read the $group_dset$ indices and values from the TREXIO file.
|
||||
|
||||
Parameters:
|
||||
|
||||
trexio_file:
|
||||
TREXIO File object.
|
||||
|
||||
offset_file: int
|
||||
The number of integrals to be skipped in the file when reading.
|
||||
|
||||
buffer_size: int
|
||||
The number of integrals to read from the file.
|
||||
|
||||
Returns:
|
||||
(indices, values, read_buf_size, eof_flag) tuple where
|
||||
- indices and values are NumPy arrays [numpy.ndarray] with the default int32 and float64 precision, respectively;
|
||||
- read_buf_size [int] is the number of integrals read from the trexio_file
|
||||
(either strictly equal to buffer_size or less than buffer_size if EOF has been reached);
|
||||
- eof_flag [bool] is True when EOF has been reached (i.e. when call to low-level pytrexio API returns TREXIO_END)
|
||||
False otherwise.
|
||||
|
||||
Raises:
|
||||
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS
|
||||
and prints the error message using trexio_string_of_error.
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
raise Exception("NumPy cannot be imported.")
|
||||
|
||||
if not isinstance(offset_file, int):
|
||||
raise TypeError("offset_file argument has to be an integer.")
|
||||
if not isinstance(buffer_size, int):
|
||||
raise TypeError("buffer_size argument has to be an integer.")
|
||||
|
||||
|
||||
# read the number of integrals already in the file
|
||||
integral_num = read_$group_dset$_size(trexio_file)
|
||||
|
||||
# additional modification needed to avoid allocating more memory than needed if EOF will be reached during read
|
||||
overflow = offset_file + buffer_size - integral_num
|
||||
eof_flag = False
|
||||
if overflow > 0:
|
||||
verified_size = buffer_size - overflow
|
||||
eof_flag = True
|
||||
else:
|
||||
verified_size = buffer_size
|
||||
|
||||
# main call to the low-level (SWIG-wrapped) trexio_read function, which also requires the sizes of the output to be provided
|
||||
# as the last 2 arguments (for numpy arrays of indices and values, respectively)
|
||||
# read_buf_size contains the number of elements being read from the file, useful when EOF has been reached
|
||||
rc, read_buf_size, indices, values = pytr.trexio_read_safe_$group_dset$(trexio_file.pytrexio_s,
|
||||
offset_file,
|
||||
verified_size,
|
||||
verified_size * $group_dset_rank$,
|
||||
verified_size)
|
||||
if rc != TREXIO_SUCCESS:
|
||||
raise Error(rc)
|
||||
if read_buf_size == 0:
|
||||
raise ValueError("No integrals have been read from the file.")
|
||||
if indices is None or values is None:
|
||||
raise ValueError("Returned NULL array from the low-level pytrexio API.")
|
||||
|
||||
# conversion to custom types can be performed on the user side, here we only reshape the returned flat array of indices according to group_dset_rank
|
||||
shape = tuple([verified_size, $group_dset_rank$])
|
||||
indices_reshaped = np.reshape(indices, shape, order='C')
|
||||
|
||||
return (indices_reshaped, values, read_buf_size, eof_flag)
|
||||
|
||||
|
||||
def read_$group_dset$_size(trexio_file) -> int:
|
||||
"""Read the number of $group_dset$ integrals stored in the TREXIO file.
|
||||
|
||||
Parameter is a ~TREXIO File~ object that has been created by a call to ~open~ function.
|
||||
|
||||
Returns:
|
||||
~num_integral~: int
|
||||
Integer value of corresponding to the size of the $group_dset$ sparse array from ~trexio_file~.
|
||||
|
||||
Raises:
|
||||
- Exception from AssertionError if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message using trexio_string_of_error.
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
try:
|
||||
rc, num_integral = pytr.trexio_read_$group_dset$_size(trexio_file.pytrexio_s)
|
||||
if rc != TREXIO_SUCCESS:
|
||||
raise Error(rc)
|
||||
except:
|
||||
raise
|
||||
|
||||
return num_integral
|
||||
#+end_src
|
||||
|
||||
#+begin_src python :tangle has_dset_sparse_front.py
|
||||
def has_$group_dset$(trexio_file) -> bool:
|
||||
"""Check that $group_dset$ variable exists in the TREXIO file.
|
||||
|
||||
Parameter is a ~TREXIO File~ object that has been created by a call to ~open~ function.
|
||||
|
||||
Returns:
|
||||
True if the variable exists, False otherwise
|
||||
|
||||
Raises:
|
||||
- Exception from trexio.Error class if TREXIO return code ~rc~ is TREXIO_FAILURE and prints the error message using string_of_error.
|
||||
- Exception from some other error (e.g. RuntimeError).
|
||||
"""
|
||||
|
||||
try:
|
||||
rc = pytr.trexio_has_$group_dset$(trexio_file.pytrexio_s)
|
||||
if rc == TREXIO_FAILURE:
|
||||
raise Error(rc)
|
||||
except:
|
||||
raise
|
||||
|
||||
if rc == TREXIO_SUCCESS:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
#+end_src
|
||||
|
||||
** Templates for front end has/read/write a dataset of strings
|
||||
*** Introduction
|
||||
This section concerns API calls related to datasets of strings.
|
||||
|
Loading…
Reference in New Issue
Block a user