From 04367c1824be745af2d4fdb579322a2d70504726 Mon Sep 17 00:00:00 2001
From: q-posev <posenitskiy@irsamc.ups-tlse.fr>
Date: Wed, 4 May 2022 11:47:30 +0200
Subject: [PATCH] More Python-ic error handling

---
 python/src/pyqmckl.i    | 17 +++++++++++++++--
 python/test/test_api.py | 29 ++++++++++++++---------------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/python/src/pyqmckl.i b/python/src/pyqmckl.i
index 1627210..df36681 100644
--- a/python/src/pyqmckl.i
+++ b/python/src/pyqmckl.i
@@ -45,21 +45,34 @@ import_array();
 
 /* exception.i is a generic (language-independent) module */
 %include "exception.i"
-/* Error handling 
+
+/* Error handling */
+%typemap(out) qmckl_exit_code %{
+    if ($1 != QMCKL_SUCCESS) {
+        SWIG_exception(SWIG_RuntimeError, qmckl_string_of_error($1));
+    }
+    $result = Py_None;
+    Py_INCREF(Py_None); /* Py_None is a singleton so increment its reference if used. */
+%}
+
+/* More swig-y solution (e.g. compatible beyond Python) BUT it does not consume the qmckl_exit_code output as the solution above 
 TODO: the sizeof() check below if a dummy workaround
 It is good to skip exception raise for functions like context_create and others, but might fail
 if sizeof(result) == sizeof(qmckl_exit_code), e.g. for functions that return non-zero integers or floats
 */
+/*
 %exception {
   $action
   if (result != QMCKL_SUCCESS && sizeof(result) == sizeof(qmckl_exit_code)) {
     SWIG_exception_fail(SWIG_RuntimeError, qmckl_string_of_error(result));
   }
 }
-
+*/
 /* The exception handling above does not work for void functions like lock/unlock so exclude them for now */
+/*
 %ignore qmckl_lock;
 %ignore qmckl_unlock;
+*/
 
 /* Parse the header files to generate wrappers */
 %include "qmckl.h"
diff --git a/python/test/test_api.py b/python/test/test_api.py
index 69c7a5e..99c57d8 100644
--- a/python/test/test_api.py
+++ b/python/test/test_api.py
@@ -18,33 +18,32 @@ ITERMAX = 10
 
 ctx = pq.qmckl_context_create()
 
+try:
+    pq.qmckl_trexio_read(ctx, 'fake.h5')
+except RuntimeError:
+    print('Error handling check: passed')
+
 fname = join('data', 'Alz_small.h5')
 
-rc = pq.qmckl_trexio_read(ctx, fname)
-assert rc==pq.QMCKL_SUCCESS
-print(pq.qmckl_string_of_error(rc))
+pq.qmckl_trexio_read(ctx, fname)
+print('trexio_read: passed')
 
-rc = pq.qmckl_set_electron_walk_num(ctx, walk_num)
-assert rc==pq.QMCKL_SUCCESS
+pq.qmckl_set_electron_walk_num(ctx, walk_num)
 
-rc, mo_num = pq.qmckl_get_mo_basis_mo_num(ctx)
-assert rc==pq.QMCKL_SUCCESS
+mo_num = pq.qmckl_get_mo_basis_mo_num(ctx)
+assert mo_num == 404
 
-rc = pq.qmckl_set_electron_coord(ctx, 'T', coord)
-assert rc==pq.QMCKL_SUCCESS
+pq.qmckl_set_electron_coord(ctx, 'T', coord)
 
 size_max = 5*walk_num*elec_num*mo_num
 
-
-
-rc, mo_vgl = pq.qmckl_get_mo_basis_mo_vgl(ctx, size_max)
-assert rc==pq.QMCKL_SUCCESS
+mo_vgl = pq.qmckl_get_mo_basis_mo_vgl(ctx, size_max)
+assert mo_vgl.size == size_max
 
 start = time.clock_gettime_ns(time.CLOCK_REALTIME)
 
 for _ in range(ITERMAX):
-    rc, mo_vgl_in = pq.qmckl_get_mo_basis_mo_vgl_inplace(ctx, size_max)
-    assert rc==pq.QMCKL_SUCCESS
+    mo_vgl_in = pq.qmckl_get_mo_basis_mo_vgl_inplace(ctx, size_max)
 
 end = time.clock_gettime_ns(time.CLOCK_REALTIME)