3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-25 13:53:40 +01:00

Fix #112 and put back g +=/-= matrix for imfreq

- The issue comes from the fact that the default generated
  += and co by the Python API is the one for immutable types, like int.
- Indeed, in python, for an int :
  x=1
  id(x)
  140266967205832
  x+=1
  id(x)
  140266967205808
- For a mutable type, like a gf, it is necessary to
  add explicitly the xxx_inplace_add functions.
- Added :
   - the generation of the inplace_xxx functions
   - a method in class_ in the wrapper generator that
     deduce all += operator from the + operators.
   - this assumes that the +=, ... are defined in C++.
   - The generation of such operators are optional, with option
     with_inplace_operators in the arithmetic flag.
- Also, added the overload g += M and g -= M for
  g : GfImfreq, M a complex matrix.
  Mainly for legacy Python codes.
This commit is contained in:
Olivier Parcollet 2014-09-06 15:53:01 +02:00
parent dcbdd5bc54
commit 7cf7d09c77
11 changed files with 161 additions and 5 deletions

View File

@ -218,7 +218,7 @@ def make_gf( py_type, c_tag, is_complex_data = True, is_im = False, has_tail = T
serializable= "tuple", serializable= "tuple",
is_printable= True, is_printable= True,
hdf5 = True, hdf5 = True,
arithmetic = ("algebra",data_type) arithmetic = ("algebra",data_type, "with_inplace_operators")
) )
g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector<size_t,2> shape, std::vector<std::vector<std::string>> indices = std::vector<std::vector<std::string>>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag) g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector<size_t,2> shape, std::vector<std::vector<std::string>> indices = std::vector<std::vector<std::string>>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag)
@ -372,6 +372,11 @@ g.add_method(name = "set_from_legendre",
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail") g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail")
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail") g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail")
# For legacy Python code : authorize g + Matrix
#g.number_protocol['add'].add_overload(calling_pattern = "+", signature = "gf<imfreq>(gf<imfreq> x,matrix<std:complex<double>> y)")
g.number_protocol['inplace_add'].add_overload(calling_pattern = "+=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
g.number_protocol['inplace_subtract'].add_overload(calling_pattern = "-=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
module.add_class(g) module.add_class(g)
######################## ########################

View File

@ -269,6 +269,11 @@ class class_ :
- with_unit : +/- of an element with a scalar (injection of the scalar with the unit) - with_unit : +/- of an element with a scalar (injection of the scalar with the unit)
- with_unary_minus : implement unary minus - with_unary_minus : implement unary minus
- "add_only" : implements only + - "add_only" : implements only +
- with_inplace_operators : option to deduce the +=, -=, ...
operators from +,-, .. It deduces the possibles terms to put at the rhs, looking at the
case of the +,- operators where the lhs is of the type of self.
NB : The operator is mapped to the corresponding C++ operators (for some objects, this may be faster)
so it has to be defined in C++ as well....
- .... more to be defined. - .... more to be defined.
- serializable : Whether and how the object is to be serialized. Possible values are : - serializable : Whether and how the object is to be serialized. Possible values are :
- "tuple" : reduce it to a tuple of smaller objects, using the - "tuple" : reduce it to a tuple of smaller objects, using the
@ -309,6 +314,7 @@ class class_ :
# read the with_... option and clean them for the list # read the with_... option and clean them for the list
with_unary_minus = 'with_unary_minus' in arithmetic with_unary_minus = 'with_unary_minus' in arithmetic
with_unit = 'with_unit' in arithmetic with_unit = 'with_unit' in arithmetic
with_inplace_operators = 'with_inplace_operators' in arithmetic
arithmetic = [x for x in arithmetic if not x.startswith("with_")] arithmetic = [x for x in arithmetic if not x.startswith("with_")]
add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add") add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add")
abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space") abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space")
@ -359,6 +365,23 @@ class class_ :
neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type}) neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type})
self.number_protocol['negative'] = neg self.number_protocol['negative'] = neg
if with_inplace_operators : self.deduce_inplace_arithmetic()
def deduce_inplace_arithmetic(self) :
"""Deduce all the +=, -=, *=, /= operators from the +, -, *, / operators"""
def one_op(op, name, iname) :
if name not in self.number_protocol : return
impl = pyfunction(name = iname, arity = 2)
for overload in self.number_protocol[name].overloads :
x_t,y_t = overload.args[0][0], overload.args[1][0]
if x_t == self.c_type : # only when first the object
impl.add_overload (calling_pattern = op+"=", signature = {'args' : [(x_t,'x'), (y_t,'y')], 'rtype' :overload.rtype})
self.number_protocol['inplace_'+name] = impl
one_op('+',"add","__iadd__")
one_op('-',"subtract","__isub__")
one_op('*',"multiply","__imul__")
one_op('/',"divide","__idiv__")
def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''): def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''):
""" """
- signature : signature of the function, with types, parameter names and defaut value - signature : signature of the function, with types, parameter names and defaut value

View File

@ -939,8 +939,14 @@ static PyObject * ${c.py_type}_${op_name} (PyObject* v, PyObject *w){
%for overload in op.overloads : %for overload in op.overloads :
if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) { if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) {
try { try {
%if not op_name.startswith("inplace") :
${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w); ${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++ return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++
%else:
convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
Py_INCREF(v);
return v;
%endif
} }
CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL) CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL)
} }

View File

@ -17,3 +17,7 @@ add_triqs_test_hdf(dos " -d 1.e-6")
# Pade approximation # Pade approximation
add_triqs_test_hdf(pade " -d 1.e-6") add_triqs_test_hdf(pade " -d 1.e-6")
# Bug fix #112
add_triqs_test_txt(gf_inplace_112)

View File

@ -0,0 +1,12 @@
Before:
G['up'] = 3.14159265359j
G['dn'] = 3.14159265359j
After G['up'] += G['dn']:
G['up'] = 6.28318530718j
G['dn'] = 3.14159265359j
After g_up += g_dn:
G['up'] = 9.42477796077j
G['dn'] = 3.14159265359j
After G += G:
G['up'] = 18.8495559215j
G['dn'] = 6.28318530718j

View File

@ -0,0 +1,55 @@
# Test from I. Krivenko.
from __future__ import print_function
from pytriqs.gf.local import *
from pytriqs.gf.local.descriptors import *
import sys
def print_err(*x) : print (*x, file= sys.stderr)
g_up = GfImFreq(indices = [0], beta = 1)
g_dn = GfImFreq(indices = [0], beta = 1)
g_up <<= iOmega_n
g_dn <<= iOmega_n
G = BlockGf(name_list=['up','dn'], block_list=[g_up,g_dn], make_copies=False)
print("Before:")
print("G['up'] =", G['up'].data[0,0,0])
print("G['dn'] =", G['dn'].data[0,0,0])
print_err('id(g_up) =', id(g_up))
print_err('id(g_dn) =', id(g_dn))
print_err ("(id=", id(G['up']),")")
print_err ("(id=", id(G['dn']),")")
G['up'] += G['dn']
print("After G['up'] += G['dn']:")
print("G['up'] =", G['up'].data[0,0,0])
print("G['dn'] =", G['dn'].data[0,0,0])
print_err('id(g_up) =', id(g_up))
print_err('id(g_dn) =', id(g_dn))
print_err ("(id=", id(G['up']),")")
print_err ("(id=", id(G['dn']),")")
g_up += g_dn
print("After g_up += g_dn:")
print("G['up'] =", G['up'].data[0,0,0])
print("G['dn'] =", G['dn'].data[0,0,0])
print_err('id(g_up) =', id(g_up))
print_err('id(g_dn) =', id(g_dn))
print_err ("(id=", id(G['up']),")")
print_err ("(id=", id(G['dn']),")")
G += G
print("After G += G:")
print("G['up'] =", G['up'].data[0,0,0])
print("G['dn'] =", G['dn'].data[0,0,0])
print_err('id(g_up) =', id(g_up))
print_err('id(g_dn) =', id(g_dn))
print_err ("(id=", id(G['up']),")")
print_err ("(id=", id(G['dn']),")")

View File

@ -668,8 +668,8 @@ namespace gfs {
// auxiliary function : invert the data : one function for all matrix valued gf (save code). // auxiliary function : invert the data : one function for all matrix valued gf (save code).
template <typename A3> void _gf_invert_data_in_place(A3 && a) { template <typename A3> void _gf_invert_data_in_place(A3 && a) {
for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering
auto v = a(i, arrays::range(), arrays::range()); auto v = make_matrix_view(a(i, arrays::range(), arrays::range()));
v = inverse(v); v = triqs::arrays::inverse(v);
} }
} }

View File

@ -142,6 +142,26 @@ namespace triqs { namespace gfs {
>::type >::type
operator - (A1 && a1) { return {std::forward<A1>(a1)};} operator - (A1 && a1) { return {std::forward<A1>(a1)};}
// Now the inplace operator. Because of expression template, there are useless for speed
// we implement them trivially.
#define DEFINE_OPERATOR(OP1, OP2) \
template <typename Variable, typename Target, typename Opt, typename T> \
void operator OP1(gf_view<Variable, Target, Opt> g, T const &x) { \
g = g OP2 x; \
} \
template <typename Variable, typename Target, typename Opt, typename T> \
void operator OP1(gf<Variable, Target, Opt> &g, T const &x) { \
g = g OP2 x; \
}
DEFINE_OPERATOR(+=, +);
DEFINE_OPERATOR(-=, -);
DEFINE_OPERATOR(*=, *);
DEFINE_OPERATOR(/=, / );
#undef DEFINE_OPERATOR
}}//namespace triqs::gf }}//namespace triqs::gf
#endif #endif

View File

@ -166,5 +166,17 @@ namespace gfs {
template <typename Opt> struct data_proxy<imfreq, scalar_valued, Opt> : data_proxy_array<std::complex<double>, 1> {}; template <typename Opt> struct data_proxy<imfreq, scalar_valued, Opt> : data_proxy_array<std::complex<double>, 1> {};
} // gfs_implementation } // gfs_implementation
// specific operations (for legacy python code).
// +=, -= with a matrix
inline void operator+=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) += m;
g.singularity()(0) += m;
}
inline void operator-=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) -= m;
g.singularity()(0) -= m;
}
} }
} }

View File

@ -393,5 +393,19 @@ namespace triqs { namespace gfs { namespace local {
template<typename T1, typename T2> TYPE_ENABLE_IF(tail,mpl::and_<LocalTail<T1>, is_scalar_or_element<T2>>) template<typename T1, typename T2> TYPE_ENABLE_IF(tail,mpl::and_<LocalTail<T1>, is_scalar_or_element<T2>>)
operator - (T1 const & t, T2 const & a) { return (-a) + t;} operator - (T1 const & t, T2 const & a) { return (-a) + t;}
// inplace operators
#define DEFINE_OPERATOR(OP1, OP2) \
template <typename T> void operator OP1(tail_view g, T &&x) { g = g OP2 x; } \
template <typename T> void operator OP1(tail &g, T &&x) { g = g OP2 x; }
DEFINE_OPERATOR(+=, +);
DEFINE_OPERATOR(-=, -);
DEFINE_OPERATOR(*=, *);
DEFINE_OPERATOR(/=, / );
#undef DEFINE_OPERATOR
}}} }}}
#endif #endif

View File

@ -146,6 +146,11 @@ namespace params {
return p1; return p1;
} }
inline parameters & operator+=(parameters & p1, parameters const& p2) {
p1.update(p2);
return p1;
}
// can only be implemented after complete declaration of parameters // can only be implemented after complete declaration of parameters
template <typename... T> _field& _field::add_field(T&&... x) { template <typename... T> _field& _field::add_field(T&&... x) {
auto* pp = dynamic_cast<_data_impl<parameters>*>(p.get()); auto* pp = dynamic_cast<_data_impl<parameters>*>(p.get());