1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-11-03 12:43:55 +01:00

Add Python API for external array I/O

This commit is contained in:
q-posev 2023-05-08 14:52:50 +02:00
parent 11bf772a34
commit 8ac64c2a75
No known key found for this signature in database
4 changed files with 336 additions and 2 deletions

View File

@ -47,3 +47,6 @@ det_test = [1, 2, 3, 2, 1, 3]
orb_up_test = [0, 65, 128, 129]
orb_dn_test = [1, 64, 128, 129]
external_2Dfloat_name = "test external float matrix"
external_1Dint32_name = "test external int32 vector"

View File

@ -149,6 +149,19 @@ class TestIO:
assert trexio.has_nucleus_coord(self.test_file)
def test_external_array(self):
"""Write external arrays."""
self.open()
assert not trexio.has_external_array(self.test_file, external_2Dfloat_name)
trexio.write_external_array(self.test_file, nucleus_coord, external_2Dfloat_name)
assert trexio.has_external_array(self.test_file, external_2Dfloat_name)
assert not trexio.has_external_array(self.test_file, external_1Dint32_name)
trexio.write_external_array(self.test_file, np.array(nucleus_charge,dtype=np.int32), external_1Dint32_name)
assert trexio.has_external_array(self.test_file, external_1Dint32_name)
def test_indices(self):
"""Write array of indices."""
self.open()
@ -252,6 +265,21 @@ class TestIO:
np.testing.assert_array_almost_equal(coords_np, np.array(nucleus_coord).reshape(nucleus_num,3), decimal=8)
def test_read_external_array(self):
"""Read external arrays."""
self.open(mode='r')
# read nuclear coordinates without providing optional argument dim
coords_external_np = trexio.read_external_array(self.test_file, name=external_2Dfloat_name, dtype="float64", size=nucleus_num*3)
assert coords_external_np.dtype is np.dtype(np.float64)
assert coords_external_np.size == nucleus_num * 3
np.testing.assert_array_almost_equal(coords_external_np.reshape(nucleus_num,3), np.array(nucleus_coord).reshape(nucleus_num,3), decimal=8)
charge_external_np = trexio.read_external_array(self.test_file, name=external_1Dint32_name, dtype="int32", size=nucleus_num)
assert charge_external_np.dtype is np.dtype(np.int32)
assert charge_external_np.size == nucleus_num
np.testing.assert_array_almost_equal(charge_external_np, np.array(nucleus_charge, dtype=np.int32))
def test_read_errors(self):
"""Test some reading errors."""
self.open(mode='r')

View File

@ -108,6 +108,9 @@ import_array();
/* For some reasons SWIG does not apply the proper bitfield_t typemap, so one has to manually specify int64_t* ARGOUT_ARRAY1 below */
%apply (int64_t* ARGOUT_ARRAY1, int32_t DIM1) {(bitfield_t* const bit_list, const int32_t N_int)};
/* For passing dimensions of external arrays fron Python front to C back */
%apply (uint64_t* IN_ARRAY1, int32_t DIM1) {(const uint64_t* dims_in, const int32_t dims_dim_in)};
/* This tells SWIG to treat char ** dset_in pattern as a special case
Enables access to trexio_[...]_write_dset_str set of functions directly, i.e.
by converting input list of strings from Python into char ** of C

View File

@ -1993,6 +1993,23 @@ trexio_write_external_$suffix$_array(trexio_t* const file, const $c_type$* array
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_write_safe_external_$suffix$_array(trexio_t* const file, const $c_type$* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_in == NULL) return TREXIO_INVALID_ARG_2;
if (dim_in <= 0) return TREXIO_INVALID_ARG_3;
if (rank == 0) return TREXIO_INVALID_ARG_4;
if (dims_in == NULL) return TREXIO_INVALID_ARG_5;
if (dims_dim_in == 0) return TREXIO_INVALID_ARG_6;
if (name == NULL) return TREXIO_INVALID_ARG_7;
for (uint32_t i=0; i<rank; i++){
if (dims_in[i] == 0) return TREXIO_INVALID_ARG_5;
}
return trexio_write_external_$suffix$_array(file, dset_in, rank, dims_in, name);
}
#+end_src
#+NAME:template_read_func_c
@ -2022,13 +2039,26 @@ trexio_read_external_$suffix$_array(trexio_t* const file, $c_type$* const array,
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_read_safe_external_$suffix$_array(trexio_t* const file, $c_type$* const dset_out, const int64_t dim_out, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_out == NULL) return TREXIO_INVALID_ARG_2;
if (dim_out <= 0) return TREXIO_INVALID_ARG_3;
if (name == NULL) return TREXIO_INVALID_ARG_4;
return trexio_read_external_$suffix$_array(file, dset_out, name);
}
#+end_src
#+begin_src python :var table=table-external-datatypes :results drawer :noweb yes
""" This script generates the C functions for generic I/O (external group) """
template_write_func_h = "trexio_exit_code trexio_write_external_$suffix$_array(trexio_t* const file, const $c_type$* array, const uint32_t rank, const uint64_t* dimensions, const char* name);"
template_read_func_h = "trexio_exit_code trexio_read_external_$suffix$_array(trexio_t* const file, $c_type$* const array, const char* name);"
template_write_func_h = "trexio_exit_code trexio_write_external_$suffix$_array(trexio_t* const file, const $c_type$* array, const uint32_t rank, const uint64_t* dimensions, const char* name);\n"
template_write_func_h += "trexio_exit_code trexio_write_safe_external_$suffix$_array(trexio_t* const file, const $c_type$* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name);"
template_read_func_h = "trexio_exit_code trexio_read_external_$suffix$_array(trexio_t* const file, $c_type$* const array, const char* name);\n"
template_read_func_h += "trexio_exit_code trexio_read_safe_external_$suffix$_array(trexio_t* const file, $c_type$* const dset_out, const int64_t dim_out, const char* name);"
template_write_func_c = """
<<template_write_func_c>>
"""
@ -2060,13 +2090,21 @@ return '\n'.join(result_h + ['\n'] + result_c)
:results:
#+begin_src c :tangle prefix_front.h :exports none
trexio_exit_code trexio_write_external_int32_array(trexio_t* const file, const int32_t* array, const uint32_t rank, const uint64_t* dimensions, const char* name);
trexio_exit_code trexio_write_safe_external_int32_array(trexio_t* const file, const int32_t* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name);
trexio_exit_code trexio_read_external_int32_array(trexio_t* const file, int32_t* const array, const char* name);
trexio_exit_code trexio_read_safe_external_int32_array(trexio_t* const file, int32_t* const dset_out, const int64_t dim_out, const char* name);
trexio_exit_code trexio_write_external_float32_array(trexio_t* const file, const float* array, const uint32_t rank, const uint64_t* dimensions, const char* name);
trexio_exit_code trexio_write_safe_external_float32_array(trexio_t* const file, const float* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name);
trexio_exit_code trexio_read_external_float32_array(trexio_t* const file, float* const array, const char* name);
trexio_exit_code trexio_read_safe_external_float32_array(trexio_t* const file, float* const dset_out, const int64_t dim_out, const char* name);
trexio_exit_code trexio_write_external_int64_array(trexio_t* const file, const int64_t* array, const uint32_t rank, const uint64_t* dimensions, const char* name);
trexio_exit_code trexio_write_safe_external_int64_array(trexio_t* const file, const int64_t* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name);
trexio_exit_code trexio_read_external_int64_array(trexio_t* const file, int64_t* const array, const char* name);
trexio_exit_code trexio_read_safe_external_int64_array(trexio_t* const file, int64_t* const dset_out, const int64_t dim_out, const char* name);
trexio_exit_code trexio_write_external_float64_array(trexio_t* const file, const double* array, const uint32_t rank, const uint64_t* dimensions, const char* name);
trexio_exit_code trexio_write_safe_external_float64_array(trexio_t* const file, const double* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name);
trexio_exit_code trexio_read_external_float64_array(trexio_t* const file, double* const array, const char* name);
trexio_exit_code trexio_read_safe_external_float64_array(trexio_t* const file, double* const dset_out, const int64_t dim_out, const char* name);
trexio_exit_code trexio_has_external_array(trexio_t* const file, const char* name);
trexio_exit_code trexio_has_external(trexio_t* const file);
#+end_src
@ -2105,6 +2143,23 @@ trexio_write_external_int32_array(trexio_t* const file, const int32_t* array, co
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_write_safe_external_int32_array(trexio_t* const file, const int32_t* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_in == NULL) return TREXIO_INVALID_ARG_2;
if (dim_in <= 0) return TREXIO_INVALID_ARG_3;
if (rank == 0) return TREXIO_INVALID_ARG_4;
if (dims_in == NULL) return TREXIO_INVALID_ARG_5;
if (dims_dim_in == 0) return TREXIO_INVALID_ARG_6;
if (name == NULL) return TREXIO_INVALID_ARG_7;
for (uint32_t i=0; i<rank; i++){
if (dims_in[i] == 0) return TREXIO_INVALID_ARG_5;
}
return trexio_write_external_int32_array(file, dset_in, rank, dims_in, name);
}
trexio_exit_code
trexio_read_external_int32_array(trexio_t* const file, int32_t* const array, const char* name)
@ -2132,6 +2187,17 @@ trexio_read_external_int32_array(trexio_t* const file, int32_t* const array, con
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_read_safe_external_int32_array(trexio_t* const file, int32_t* const dset_out, const int64_t dim_out, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_out == NULL) return TREXIO_INVALID_ARG_2;
if (dim_out <= 0) return TREXIO_INVALID_ARG_3;
if (name == NULL) return TREXIO_INVALID_ARG_4;
return trexio_read_external_int32_array(file, dset_out, name);
}
trexio_exit_code
trexio_write_external_float32_array(trexio_t* const file, const float* array, const uint32_t rank, const uint64_t* dimensions, const char* name)
@ -2164,6 +2230,23 @@ trexio_write_external_float32_array(trexio_t* const file, const float* array, co
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_write_safe_external_float32_array(trexio_t* const file, const float* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_in == NULL) return TREXIO_INVALID_ARG_2;
if (dim_in <= 0) return TREXIO_INVALID_ARG_3;
if (rank == 0) return TREXIO_INVALID_ARG_4;
if (dims_in == NULL) return TREXIO_INVALID_ARG_5;
if (dims_dim_in == 0) return TREXIO_INVALID_ARG_6;
if (name == NULL) return TREXIO_INVALID_ARG_7;
for (uint32_t i=0; i<rank; i++){
if (dims_in[i] == 0) return TREXIO_INVALID_ARG_5;
}
return trexio_write_external_float32_array(file, dset_in, rank, dims_in, name);
}
trexio_exit_code
trexio_read_external_float32_array(trexio_t* const file, float* const array, const char* name)
@ -2191,6 +2274,17 @@ trexio_read_external_float32_array(trexio_t* const file, float* const array, con
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_read_safe_external_float32_array(trexio_t* const file, float* const dset_out, const int64_t dim_out, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_out == NULL) return TREXIO_INVALID_ARG_2;
if (dim_out <= 0) return TREXIO_INVALID_ARG_3;
if (name == NULL) return TREXIO_INVALID_ARG_4;
return trexio_read_external_float32_array(file, dset_out, name);
}
trexio_exit_code
trexio_write_external_int64_array(trexio_t* const file, const int64_t* array, const uint32_t rank, const uint64_t* dimensions, const char* name)
@ -2223,6 +2317,23 @@ trexio_write_external_int64_array(trexio_t* const file, const int64_t* array, co
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_write_safe_external_int64_array(trexio_t* const file, const int64_t* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_in == NULL) return TREXIO_INVALID_ARG_2;
if (dim_in <= 0) return TREXIO_INVALID_ARG_3;
if (rank == 0) return TREXIO_INVALID_ARG_4;
if (dims_in == NULL) return TREXIO_INVALID_ARG_5;
if (dims_dim_in == 0) return TREXIO_INVALID_ARG_6;
if (name == NULL) return TREXIO_INVALID_ARG_7;
for (uint32_t i=0; i<rank; i++){
if (dims_in[i] == 0) return TREXIO_INVALID_ARG_5;
}
return trexio_write_external_int64_array(file, dset_in, rank, dims_in, name);
}
trexio_exit_code
trexio_read_external_int64_array(trexio_t* const file, int64_t* const array, const char* name)
@ -2250,6 +2361,17 @@ trexio_read_external_int64_array(trexio_t* const file, int64_t* const array, con
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_read_safe_external_int64_array(trexio_t* const file, int64_t* const dset_out, const int64_t dim_out, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_out == NULL) return TREXIO_INVALID_ARG_2;
if (dim_out <= 0) return TREXIO_INVALID_ARG_3;
if (name == NULL) return TREXIO_INVALID_ARG_4;
return trexio_read_external_int64_array(file, dset_out, name);
}
trexio_exit_code
trexio_write_external_float64_array(trexio_t* const file, const double* array, const uint32_t rank, const uint64_t* dimensions, const char* name)
@ -2282,6 +2404,23 @@ trexio_write_external_float64_array(trexio_t* const file, const double* array, c
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_write_safe_external_float64_array(trexio_t* const file, const double* dset_in, const int64_t dim_in, const uint32_t rank, const uint64_t* dims_in, const int32_t dims_dim_in, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_in == NULL) return TREXIO_INVALID_ARG_2;
if (dim_in <= 0) return TREXIO_INVALID_ARG_3;
if (rank == 0) return TREXIO_INVALID_ARG_4;
if (dims_in == NULL) return TREXIO_INVALID_ARG_5;
if (dims_dim_in == 0) return TREXIO_INVALID_ARG_6;
if (name == NULL) return TREXIO_INVALID_ARG_7;
for (uint32_t i=0; i<rank; i++){
if (dims_in[i] == 0) return TREXIO_INVALID_ARG_5;
}
return trexio_write_external_float64_array(file, dset_in, rank, dims_in, name);
}
trexio_exit_code
trexio_read_external_float64_array(trexio_t* const file, double* const array, const char* name)
@ -2309,6 +2448,17 @@ trexio_read_external_float64_array(trexio_t* const file, double* const array, co
return TREXIO_FAILURE;
}
trexio_exit_code
trexio_read_safe_external_float64_array(trexio_t* const file, double* const dset_out, const int64_t dim_out, const char* name)
{
if (file == NULL) return TREXIO_INVALID_ARG_1;
if (dset_out == NULL) return TREXIO_INVALID_ARG_2;
if (dim_out <= 0) return TREXIO_INVALID_ARG_3;
if (name == NULL) return TREXIO_INVALID_ARG_4;
return trexio_read_external_float64_array(file, dset_out, name);
}
#+end_src
:end:
@ -2357,6 +2507,156 @@ trexio_has_external_array(trexio_t* const file, const char* name)
}
#+end_src
*** Python
#+begin_src python :tangle basic_python.py
def write_external_array(trexio_file, dset_w, name) -> None:
"""Write an arbitrary array of numbers in the TREXIO file.
Parameters:
trexio_file:
TREXIO File object.
dset_w: list, tuple OR numpy.ndarray
Array of values to be written.
name: string
Name of the array as it will be stored in the external group of TREXIO file
Raises:
- trexio.Error if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message.
- Exception from some other error (e.g. RuntimeError).
"""
# get dimensions and rank from input array
if not isinstance(dset_w, (list, tuple)):
# if input array is not a list or tuple then it is probably a numpy array
rank = len(dset_w.shape)
dimensions = np.array(dset_w.shape, dtype=np.uint64)
else:
get_shape = lambda l: [len(l)] + get_shape(l[0]) if (type(l) == list or type(l) == tuple) else []
get_type = lambda l: [type(l)] + get_type(l[0]) if (type(l) == list or type(l) == tuple) else [type(l)]
dset_shape = get_shape(dset_w)
dset_dtype = get_type(dset_w)[-1]
rank = len(dset_shape)
dimensions = np.array(dset_shape, dtype=np.uint64)
# decide whether to flatten or not
doFlatten = False
if rank > 1:
doFlatten = True
# handle list/typle
if isinstance(dset_w, (list, tuple)):
if dset_dtype is int:
if doFlatten:
rc = pytr.trexio_write_safe_external_int64_array(trexio_file.pytrexio_s, np.array(dset_w, dtype=np.int64).flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_int64_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
elif dset_dtype is float:
if doFlatten:
rc = pytr.trexio_write_safe_external_float64_array(trexio_file.pytrexio_s, np.array(dset_w, dtype=np.float64).flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_float64_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
else:
raise TypeError("Unsupported type of a list/tuple for generic I/O of arrays.")
# handle numpy array
elif isinstance(dset_w, np.ndarray):
if dset_w.dtype==np.int64:
if doFlatten:
rc = pytr.trexio_write_safe_external_int64_array(trexio_file.pytrexio_s, dset_w.flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_int64_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
elif dset_w.dtype==np.int32:
if doFlatten:
rc = pytr.trexio_write_safe_external_int32_array(trexio_file.pytrexio_s, dset_w.flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_int32_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
elif dset_w.dtype==np.float64:
if doFlatten:
rc = pytr.trexio_write_safe_external_float64_array(trexio_file.pytrexio_s, dset_w.flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_float64_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
elif dset_w.dtype==np.float32:
if doFlatten:
rc = pytr.trexio_write_safe_external_float32_array(trexio_file.pytrexio_s, dset_w.flatten(), rank, dimensions, name)
else:
rc = pytr.trexio_write_safe_external_float32_array(trexio_file.pytrexio_s, dset_w, rank, dimensions, name)
else:
raise TypeError("Unsupported type of a NumPy array for generic I/O of arrays.")
else:
raise TypeError("Unsupported array type for generic I/O.")
if rc != TREXIO_SUCCESS:
raise Error(rc)
#+end_src
#+begin_src python :tangle basic_python.py
def read_external_array(trexio_file, name, size, dtype):
"""Read an external array of numbers from the TREXIO file.
Parameters:
trexio_file:
TREXIO File object.
name:
string name of an array
size:
integer value corresponding to the total number of elements to read
dtype:
string indicating the datatype of the array (int/int32/int64/float/float32/float64/double)
Returns:
~dset_r~: 1D NumPy array with ~dim~ elements corresponding to of "name" array read from the TREXIO file.
Raises:
- trexio.Error if TREXIO return code ~rc~ is different from TREXIO_SUCCESS and prints the error message.
- Exception from some other error (e.g. RuntimeError).
"""
if dtype in ['int', 'int64']:
rc, dset_r = pytr.trexio_read_safe_external_int64_array(trexio_file.pytrexio_s, size, name)
elif dtype in ['int32']:
rc, dset_r = pytr.trexio_read_safe_external_int32_array(trexio_file.pytrexio_s, size, name)
elif dtype in ['float', 'float64', 'double']:
rc, dset_r = pytr.trexio_read_safe_external_float64_array(trexio_file.pytrexio_s, size, name)
elif dtype in ['float32']:
rc, dset_r = pytr.trexio_read_safe_external_float32_array(trexio_file.pytrexio_s, size, name)
else:
raise ValueError("Unsupported dtype passed to read_external_array.")
if rc != TREXIO_SUCCESS:
raise Error(rc)
return dset_r
#+end_src
#+begin_src python :tangle basic_python.py
def has_external_array(trexio_file, name) -> bool:
"""Check that external array exists in the TREXIO file.
trexio_file:
Parameter is a ~TREXIO File~ object that has been created by a call to ~open~ function.
name:
String name of the array from the TREXIO file
Returns:
True if the variable exists, False otherwise
Raises:
- trexio.Error if TREXIO return code ~rc~ is TREXIO_FAILURE and prints the error message using string_of_error.
- Exception from some other error (e.g. RuntimeError).
"""
rc = pytr.trexio_has_external_array(trexio_file.pytrexio_s, name)
if rc == TREXIO_FAILURE:
raise Error(rc)
return rc == TREXIO_SUCCESS
#+end_src
* Templates for front end
** Description