From 7cf7d09c77ca53685818cbf74d98016923d83a38 Mon Sep 17 00:00:00 2001 From: Olivier Parcollet Date: Sat, 6 Sep 2014 15:53:01 +0200 Subject: [PATCH] Fix #112 and put back g +=/-= matrix for imfreq - The issue comes from the fact that the default generated += and co by the Python API is the one for immutable types, like int. - Indeed, in python, for an int : x=1 id(x) 140266967205832 x+=1 id(x) 140266967205808 - For a mutable type, like a gf, it is necessary to add explicitly the xxx_inplace_add functions. - Added : - the generation of the inplace_xxx functions - a method in class_ in the wrapper generator that deduce all += operator from the + operators. - this assumes that the +=, ... are defined in C++. - The generation of such operators are optional, with option with_inplace_operators in the arithmetic flag. - Also, added the overload g += M and g -= M for g : GfImfreq, M a complex matrix. Mainly for legacy Python codes. --- pytriqs/gf/local/gf_desc.py | 7 ++- pytriqs/wrap_generator/wrap_generator.py | 23 ++++++++++ pytriqs/wrap_generator/wrapper.mako.cpp | 10 ++++- test/pytriqs/base/CMakeLists.txt | 4 ++ test/pytriqs/base/gf_inplace_112.output | 12 ++++++ test/pytriqs/base/gf_inplace_112.py | 55 ++++++++++++++++++++++++ triqs/gfs/gf.hpp | 4 +- triqs/gfs/gf_expr.hpp | 20 +++++++++ triqs/gfs/imfreq.hpp | 12 ++++++ triqs/gfs/local/tail.hpp | 14 ++++++ triqs/parameters/parameters.hpp | 5 +++ 11 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 test/pytriqs/base/gf_inplace_112.output create mode 100644 test/pytriqs/base/gf_inplace_112.py diff --git a/pytriqs/gf/local/gf_desc.py b/pytriqs/gf/local/gf_desc.py index f1c458e0..0169aca0 100644 --- a/pytriqs/gf/local/gf_desc.py +++ b/pytriqs/gf/local/gf_desc.py @@ -218,7 +218,7 @@ def make_gf( py_type, c_tag, is_complex_data = True, is_im = False, has_tail = T serializable= "tuple", is_printable= True, hdf5 = True, - arithmetic = ("algebra",data_type) + arithmetic = ("algebra",data_type, "with_inplace_operators") ) g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector shape, std::vector> indices = std::vector>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag) @@ -372,6 +372,11 @@ g.add_method(name = "set_from_legendre", g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail") g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail") +# For legacy Python code : authorize g + Matrix +#g.number_protocol['add'].add_overload(calling_pattern = "+", signature = "gf(gf x,matrix> y)") +g.number_protocol['inplace_add'].add_overload(calling_pattern = "+=", signature = "void(gf_view x,matrix> y)") +g.number_protocol['inplace_subtract'].add_overload(calling_pattern = "-=", signature = "void(gf_view x,matrix> y)") + module.add_class(g) ######################## diff --git a/pytriqs/wrap_generator/wrap_generator.py b/pytriqs/wrap_generator/wrap_generator.py index 3e040f9a..d6b73154 100644 --- a/pytriqs/wrap_generator/wrap_generator.py +++ b/pytriqs/wrap_generator/wrap_generator.py @@ -269,6 +269,11 @@ class class_ : - with_unit : +/- of an element with a scalar (injection of the scalar with the unit) - with_unary_minus : implement unary minus - "add_only" : implements only + + - with_inplace_operators : option to deduce the +=, -=, ... + operators from +,-, .. It deduces the possibles terms to put at the rhs, looking at the + case of the +,- operators where the lhs is of the type of self. + NB : The operator is mapped to the corresponding C++ operators (for some objects, this may be faster) + so it has to be defined in C++ as well.... - .... more to be defined. - serializable : Whether and how the object is to be serialized. Possible values are : - "tuple" : reduce it to a tuple of smaller objects, using the @@ -309,6 +314,7 @@ class class_ : # read the with_... option and clean them for the list with_unary_minus = 'with_unary_minus' in arithmetic with_unit = 'with_unit' in arithmetic + with_inplace_operators = 'with_inplace_operators' in arithmetic arithmetic = [x for x in arithmetic if not x.startswith("with_")] add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add") abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space") @@ -359,6 +365,23 @@ class class_ : neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type}) self.number_protocol['negative'] = neg + if with_inplace_operators : self.deduce_inplace_arithmetic() + + def deduce_inplace_arithmetic(self) : + """Deduce all the +=, -=, *=, /= operators from the +, -, *, / operators""" + def one_op(op, name, iname) : + if name not in self.number_protocol : return + impl = pyfunction(name = iname, arity = 2) + for overload in self.number_protocol[name].overloads : + x_t,y_t = overload.args[0][0], overload.args[1][0] + if x_t == self.c_type : # only when first the object + impl.add_overload (calling_pattern = op+"=", signature = {'args' : [(x_t,'x'), (y_t,'y')], 'rtype' :overload.rtype}) + self.number_protocol['inplace_'+name] = impl + one_op('+',"add","__iadd__") + one_op('-',"subtract","__isub__") + one_op('*',"multiply","__imul__") + one_op('/',"divide","__idiv__") + def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''): """ - signature : signature of the function, with types, parameter names and defaut value diff --git a/pytriqs/wrap_generator/wrapper.mako.cpp b/pytriqs/wrap_generator/wrapper.mako.cpp index 5fac2508..54f9a9f3 100644 --- a/pytriqs/wrap_generator/wrapper.mako.cpp +++ b/pytriqs/wrap_generator/wrapper.mako.cpp @@ -939,8 +939,14 @@ static PyObject * ${c.py_type}_${op_name} (PyObject* v, PyObject *w){ %for overload in op.overloads : if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) { try { - ${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w); - return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++ + %if not op_name.startswith("inplace") : + ${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w); + return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++ + %else: + convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w); + Py_INCREF(v); + return v; + %endif } CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL) } diff --git a/test/pytriqs/base/CMakeLists.txt b/test/pytriqs/base/CMakeLists.txt index 870f7fd5..3d074d33 100644 --- a/test/pytriqs/base/CMakeLists.txt +++ b/test/pytriqs/base/CMakeLists.txt @@ -17,3 +17,7 @@ add_triqs_test_hdf(dos " -d 1.e-6") # Pade approximation add_triqs_test_hdf(pade " -d 1.e-6") + +# Bug fix #112 +add_triqs_test_txt(gf_inplace_112) + diff --git a/test/pytriqs/base/gf_inplace_112.output b/test/pytriqs/base/gf_inplace_112.output new file mode 100644 index 00000000..a7971b2f --- /dev/null +++ b/test/pytriqs/base/gf_inplace_112.output @@ -0,0 +1,12 @@ +Before: +G['up'] = 3.14159265359j +G['dn'] = 3.14159265359j +After G['up'] += G['dn']: +G['up'] = 6.28318530718j +G['dn'] = 3.14159265359j +After g_up += g_dn: +G['up'] = 9.42477796077j +G['dn'] = 3.14159265359j +After G += G: +G['up'] = 18.8495559215j +G['dn'] = 6.28318530718j diff --git a/test/pytriqs/base/gf_inplace_112.py b/test/pytriqs/base/gf_inplace_112.py new file mode 100644 index 00000000..66c92c69 --- /dev/null +++ b/test/pytriqs/base/gf_inplace_112.py @@ -0,0 +1,55 @@ +# Test from I. Krivenko. +from __future__ import print_function +from pytriqs.gf.local import * +from pytriqs.gf.local.descriptors import * +import sys +def print_err(*x) : print (*x, file= sys.stderr) + +g_up = GfImFreq(indices = [0], beta = 1) +g_dn = GfImFreq(indices = [0], beta = 1) + +g_up <<= iOmega_n +g_dn <<= iOmega_n + +G = BlockGf(name_list=['up','dn'], block_list=[g_up,g_dn], make_copies=False) + +print("Before:") +print("G['up'] =", G['up'].data[0,0,0]) +print("G['dn'] =", G['dn'].data[0,0,0]) +print_err('id(g_up) =', id(g_up)) +print_err('id(g_dn) =', id(g_dn)) +print_err ("(id=", id(G['up']),")") +print_err ("(id=", id(G['dn']),")") + +G['up'] += G['dn'] + +print("After G['up'] += G['dn']:") +print("G['up'] =", G['up'].data[0,0,0]) +print("G['dn'] =", G['dn'].data[0,0,0]) +print_err('id(g_up) =', id(g_up)) +print_err('id(g_dn) =', id(g_dn)) +print_err ("(id=", id(G['up']),")") +print_err ("(id=", id(G['dn']),")") + + +g_up += g_dn + +print("After g_up += g_dn:") +print("G['up'] =", G['up'].data[0,0,0]) +print("G['dn'] =", G['dn'].data[0,0,0]) +print_err('id(g_up) =', id(g_up)) +print_err('id(g_dn) =', id(g_dn)) +print_err ("(id=", id(G['up']),")") +print_err ("(id=", id(G['dn']),")") + + +G += G +print("After G += G:") +print("G['up'] =", G['up'].data[0,0,0]) +print("G['dn'] =", G['dn'].data[0,0,0]) +print_err('id(g_up) =', id(g_up)) +print_err('id(g_dn) =', id(g_dn)) +print_err ("(id=", id(G['up']),")") +print_err ("(id=", id(G['dn']),")") + + diff --git a/triqs/gfs/gf.hpp b/triqs/gfs/gf.hpp index 06f07625..e366e3ec 100644 --- a/triqs/gfs/gf.hpp +++ b/triqs/gfs/gf.hpp @@ -668,8 +668,8 @@ namespace gfs { // auxiliary function : invert the data : one function for all matrix valued gf (save code). template void _gf_invert_data_in_place(A3 && a) { for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering - auto v = a(i, arrays::range(), arrays::range()); - v = inverse(v); + auto v = make_matrix_view(a(i, arrays::range(), arrays::range())); + v = triqs::arrays::inverse(v); } } diff --git a/triqs/gfs/gf_expr.hpp b/triqs/gfs/gf_expr.hpp index b6b69e30..5f659cfc 100644 --- a/triqs/gfs/gf_expr.hpp +++ b/triqs/gfs/gf_expr.hpp @@ -142,6 +142,26 @@ namespace triqs { namespace gfs { >::type operator - (A1 && a1) { return {std::forward(a1)};} +// Now the inplace operator. Because of expression template, there are useless for speed +// we implement them trivially. + +#define DEFINE_OPERATOR(OP1, OP2) \ + template \ + void operator OP1(gf_view g, T const &x) { \ + g = g OP2 x; \ + } \ + template \ + void operator OP1(gf &g, T const &x) { \ + g = g OP2 x; \ + } + + DEFINE_OPERATOR(+=, +); + DEFINE_OPERATOR(-=, -); + DEFINE_OPERATOR(*=, *); + DEFINE_OPERATOR(/=, / ); + +#undef DEFINE_OPERATOR + }}//namespace triqs::gf #endif diff --git a/triqs/gfs/imfreq.hpp b/triqs/gfs/imfreq.hpp index 6923c8ef..ac175997 100644 --- a/triqs/gfs/imfreq.hpp +++ b/triqs/gfs/imfreq.hpp @@ -166,5 +166,17 @@ namespace gfs { template struct data_proxy : data_proxy_array, 1> {}; } // gfs_implementation + + // specific operations (for legacy python code). + // +=, -= with a matrix + inline void operator+=(gf_view g, arrays::matrix> m) { + for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) += m; + g.singularity()(0) += m; + } + + inline void operator-=(gf_view g, arrays::matrix> m) { + for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) -= m; + g.singularity()(0) -= m; + } } } diff --git a/triqs/gfs/local/tail.hpp b/triqs/gfs/local/tail.hpp index a8a893da..b665403f 100644 --- a/triqs/gfs/local/tail.hpp +++ b/triqs/gfs/local/tail.hpp @@ -393,5 +393,19 @@ namespace triqs { namespace gfs { namespace local { template TYPE_ENABLE_IF(tail,mpl::and_, is_scalar_or_element>) operator - (T1 const & t, T2 const & a) { return (-a) + t;} +// inplace operators +#define DEFINE_OPERATOR(OP1, OP2) \ + template void operator OP1(tail_view g, T &&x) { g = g OP2 x; } \ + template void operator OP1(tail &g, T &&x) { g = g OP2 x; } + + DEFINE_OPERATOR(+=, +); + DEFINE_OPERATOR(-=, -); + DEFINE_OPERATOR(*=, *); + DEFINE_OPERATOR(/=, / ); + +#undef DEFINE_OPERATOR + + + }}} #endif diff --git a/triqs/parameters/parameters.hpp b/triqs/parameters/parameters.hpp index d5f0fd5a..88ba49d2 100644 --- a/triqs/parameters/parameters.hpp +++ b/triqs/parameters/parameters.hpp @@ -146,6 +146,11 @@ namespace params { return p1; } + inline parameters & operator+=(parameters & p1, parameters const& p2) { + p1.update(p2); + return p1; + } + // can only be implemented after complete declaration of parameters template _field& _field::add_field(T&&... x) { auto* pp = dynamic_cast<_data_impl*>(p.get());