3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-25 13:53:40 +01:00

Fix bug in building regular type from python

- the flag is really enforce_copy and should force a copy
- for a view : it is false, no change
- for a regular type : it is true, and now this will enforce the
  copy in the call of numpy. numpy does it for us.
- The problematic case was when we construct a regular type
  from a complicated view in python, which is an array
  but is not C contiguous.
  (hence the PyArray_Check was false, and the C_Contiguous flag was not set)
  Now it is fine, since we ask numpy to systematically copy the data for us
  and build a C contigous array.
  --> the constructor from python does not support custom memory layout
  because numpy only support C and Fortran

Conflicts:

	triqs/arrays/impl/indexmap_storage_pair.hpp
This commit is contained in:
Olivier Parcollet 2014-09-18 14:55:43 +02:00
parent 1a85b9eb81
commit c7a1a25846
2 changed files with 10 additions and 10 deletions

View File

@ -97,10 +97,9 @@ namespace triqs { namespace arrays {
protected: protected:
#ifdef TRIQS_WITH_PYTHON_SUPPORT #ifdef TRIQS_WITH_PYTHON_SUPPORT
indexmap_storage_pair (PyObject * X, bool allow_copy, const char * name ) { indexmap_storage_pair (PyObject * X, bool enforce_copy, const char * name ) {
//std::cout << " Enter IPS ref count = "<< X->ob_refcnt << std::endl;
try { try {
numpy_interface::numpy_extractor<indexmap_type,value_type> E(X, allow_copy); numpy_interface::numpy_extractor<indexmap_type,value_type> E(X, enforce_copy);
this->indexmap_ = E.indexmap(); this->storage_ = E.storage(); this->indexmap_ = E.indexmap(); this->storage_ = E.storage();
} }
catch(numpy_interface::copy_exception s){// intercept only this one... catch(numpy_interface::copy_exception s){// intercept only this one...

View File

@ -27,9 +27,8 @@
namespace triqs { namespace arrays { namespace numpy_interface { namespace triqs { namespace arrays { namespace numpy_interface {
PyObject *numpy_extractor_impl(PyObject *X, bool enforce_copy, std::string type_name, int elementsType, int rank,
PyObject * numpy_extractor_impl ( PyObject * X, bool allow_copy, std::string type_name, size_t *lengths, std::ptrdiff_t *strides, size_t size_of_ValueType) {
int elementsType, int rank, size_t * lengths, std::ptrdiff_t * strides, size_t size_of_ValueType) {
PyObject * numpy_obj; PyObject * numpy_obj;
@ -38,7 +37,7 @@ namespace triqs { namespace arrays { namespace numpy_interface {
static const char * error_msg = " A deep copy of the object would be necessary while views are supposed to guarantee to present a *view* of the python data.\n"; static const char * error_msg = " A deep copy of the object would be necessary while views are supposed to guarantee to present a *view* of the python data.\n";
if (!allow_copy) { if (!enforce_copy) {
if (!PyArray_Check(X)) throw copy_exception () << error_msg<<" Indeed the object was not even an array !\n"; if (!PyArray_Check(X)) throw copy_exception () << error_msg<<" Indeed the object was not even an array !\n";
if ( elementsType != PyArray_TYPE((PyArrayObject*)X)) if ( elementsType != PyArray_TYPE((PyArrayObject*)X))
throw copy_exception () << error_msg<<" The deep copy is caused by a type mismatch of the elements. Expected "<< type_name<< " and found XXX \n"; throw copy_exception () << error_msg<<" The deep copy is caused by a type mismatch of the elements. Expected "<< type_name<< " and found XXX \n";
@ -64,12 +63,14 @@ namespace triqs { namespace arrays { namespace numpy_interface {
//bool ForceCast = false;// Unless FORCECAST is present in flags, this call will generate an error if the data type cannot be safely obtained from the object. //bool ForceCast = false;// Unless FORCECAST is present in flags, this call will generate an error if the data type cannot be safely obtained from the object.
int flags = 0; //(ForceCast ? NPY_FORCECAST : 0) ;// do NOT force a copy | (make_copy ? NPY_ENSURECOPY : 0); int flags = 0; //(ForceCast ? NPY_FORCECAST : 0) ;// do NOT force a copy | (make_copy ? NPY_ENSURECOPY : 0);
if (!(PyArray_Check(X) )) //if (!(PyArray_Check(X) ))
//flags |= ( IndexMapType::traversal_order == indexmaps::mem_layout::c_order(rank) ? NPY_C_CONTIGUOUS : NPY_F_CONTIGUOUS); //impose mem order //flags |= ( IndexMapType::traversal_order == indexmaps::mem_layout::c_order(rank) ? NPY_C_CONTIGUOUS : NPY_F_CONTIGUOUS); //impose mem order
#ifdef TRIQS_NUMPY_VERSION_LT_17 #ifdef TRIQS_NUMPY_VERSION_LT_17
flags |= (NPY_C_CONTIGUOUS); //impose mem order flags |= (NPY_C_CONTIGUOUS); //impose mem order
flags |= (NPY_ENSURECOPY);
#else #else
flags |= (NPY_ARRAY_C_CONTIGUOUS); //impose mem order flags |= (NPY_ARRAY_C_CONTIGUOUS); // impose mem order
flags |= (NPY_ARRAY_ENSURECOPY);
#endif #endif
numpy_obj= PyArray_FromAny(X,PyArray_DescrFromType(elementsType), rank,rank, flags , NULL ); numpy_obj= PyArray_FromAny(X,PyArray_DescrFromType(elementsType), rank,rank, flags , NULL );