From abc7e8e11de14f9e9947f2f4b85c453c1e845460 Mon Sep 17 00:00:00 2001 From: q-posev Date: Wed, 4 May 2022 10:54:51 +0200 Subject: [PATCH] Add exit codes to the Python API --- python/src/process_header.py | 12 +++++++++++- python/src/pyqmckl.i | 2 +- python/test/test_api.py | 12 ++++++------ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/python/src/process_header.py b/python/src/process_header.py index 9544916..07a2966 100644 --- a/python/src/process_header.py +++ b/python/src/process_header.py @@ -11,11 +11,16 @@ func_name = '' arrays = {} numbers = {} qmckl_public_api = [] - +qmckl_errors = [] with open("qmckl.h", 'r') as f_in: for line in f_in: + # get the errors but without the type cast because SWIG does not recognize it + if '#define' in line and 'qmckl_exit_code' in line: + qmckl_errors.append(line.strip().replace('(qmckl_exit_code)','')) + continue + if get_name: words = line.strip().split() if '(' in words[0]: @@ -142,6 +147,11 @@ processed = list(arrays.keys()) + list(numbers.keys()) with open("pyqmckl_include.i", 'w') as f_out: + # write the list of errors as constants without the type cast + for e in qmckl_errors: + line = e.replace('#define', '%constant qmckl_exit_code').replace('(','=').replace(')',';') + f_out.write(line + '\n') + swig_type = '' for v in numbers.values(): diff --git a/python/src/pyqmckl.i b/python/src/pyqmckl.i index 04c45a6..1627210 100644 --- a/python/src/pyqmckl.i +++ b/python/src/pyqmckl.i @@ -52,7 +52,7 @@ if sizeof(result) == sizeof(qmckl_exit_code), e.g. for functions that return non */ %exception { $action - if (result != 0 && sizeof(result) == sizeof(qmckl_exit_code)) { + if (result != QMCKL_SUCCESS && sizeof(result) == sizeof(qmckl_exit_code)) { SWIG_exception_fail(SWIG_RuntimeError, qmckl_string_of_error(result)); } } diff --git a/python/test/test_api.py b/python/test/test_api.py index 4efb9cc..69c7a5e 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -21,30 +21,30 @@ ctx = pq.qmckl_context_create() fname = join('data', 'Alz_small.h5') rc = pq.qmckl_trexio_read(ctx, fname) -assert rc==0 +assert rc==pq.QMCKL_SUCCESS print(pq.qmckl_string_of_error(rc)) rc = pq.qmckl_set_electron_walk_num(ctx, walk_num) -assert rc==0 +assert rc==pq.QMCKL_SUCCESS rc, mo_num = pq.qmckl_get_mo_basis_mo_num(ctx) -assert rc==0 +assert rc==pq.QMCKL_SUCCESS rc = pq.qmckl_set_electron_coord(ctx, 'T', coord) -assert rc==0 +assert rc==pq.QMCKL_SUCCESS size_max = 5*walk_num*elec_num*mo_num rc, mo_vgl = pq.qmckl_get_mo_basis_mo_vgl(ctx, size_max) -assert rc==0 +assert rc==pq.QMCKL_SUCCESS start = time.clock_gettime_ns(time.CLOCK_REALTIME) for _ in range(ITERMAX): rc, mo_vgl_in = pq.qmckl_get_mo_basis_mo_vgl_inplace(ctx, size_max) - assert rc==0 + assert rc==pq.QMCKL_SUCCESS end = time.clock_gettime_ns(time.CLOCK_REALTIME)