1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2024-08-25 06:31:43 +02:00

return double datasets as numpy arrays using numpy.i and safe API

This commit is contained in:
q-posev 2021-07-28 15:59:29 +02:00
parent ea0ea0ac38
commit e0162d4570
2 changed files with 43 additions and 1 deletions

View File

@ -46,6 +46,22 @@
#define TREXIO_HDF5 0
#define TREXIO_TEXT 0
*/
/* This is an attempt to make SWIG treat double * dset_out, const uint64_t dim_out pattern
as a special case in order to return the NumPy array to Python from C pointer to array
provided by trexio_read_safe_[dset_num] function.
NOTE: numpy.i is currently not part of SWIG but included in the numpy distribution (under numpy/tools/swig/numpy.i)
This means that the interface file have to be provided to SWIG upon compilation either by
copying it to the local working directory or by providing -l/path/to/numpy.i flag upon SWIG compilation
*/
%include "numpy.i"
%init %{
import_array();
%}
%apply (double* ARGOUT_ARRAY1, int DIM1) {(double * const dset_out, const uint64_t dim_out)};
/* This tells SWIG to treat char ** dset_in pattern as a special case
Enables access to trexio_[...]_write_dset_str set of functions directly, i.e.
by converting input list of strings from Python into char ** of C
@ -79,7 +95,7 @@
/* [WIP] This is an attempt to make SWIG treat char ** dset_out as a special case
In order to return list of string to Python from C-native char ** dset_out,
which is modified (but not allocated) within the trexio_[...}read_dset_str function
which is modified (but not allocated) within the trexio_[...]_read_dset_str function
*/
%typemap(in, numinputs=0) char ** dset_out (char * temp) {
/*temp = (char *) malloc(1028*sizeof(char));*/

View File

@ -1,5 +1,6 @@
import os
import shutil
#import numpy as np
from pytrexio import *
@ -95,9 +96,26 @@ assert rc==TREXIO_SUCCESS
for i in range(nucleus_num):
assert charges2[i]==charges[i]
#charge_numpy = np.zeros(nucleus_num, dtype=np.float64)
#print(charge_numpy)
rc, charge_numpy = trexio_read_safe_nucleus_charge(test_file2, 12)
print(charge_numpy)
print(charge_numpy[11])
assert rc==TREXIO_SUCCESS
# unsafe call to read_safe should not only have return code = TREXIO_UNSAFE_ARRAY_DIM
# but also should not return numpy array filled with garbage
rc, charge_numpy = trexio_read_safe_nucleus_charge(test_file2, 12*5)
#print(charge_numpy)
assert rc==TREXIO_UNSAFE_ARRAY_DIM
# [WIP]: ideally, the list of strings should be returned as below
#rc, label_2d = trexio_read_nucleus_label(test_file2, 10)
# [WIP]: currently only low-level routines (return one long string instead of an array of strings) work
rc, labels_1d = trexio_read_nucleus_label_low(test_file2, 10)
assert rc==TREXIO_SUCCESS
@ -109,5 +127,13 @@ for i in range(nucleus_num):
rc = trexio_close(test_file2)
assert rc==TREXIO_SUCCESS
try:
if TEST_TREXIO_BACKEND == TREXIO_HDF5:
os.remove(output_filename)
elif TEST_TREXIO_BACKEND == TREXIO_TEXT:
shutil.rmtree(output_filename)
except:
print (f'No output file {output_filename} has been produced')
#==========================================================#