mirror of
https://github.com/triqs/dft_tools
synced 2025-01-12 05:58:18 +01:00
Fix #112 and put back g +=/-= matrix for imfreq
- The issue comes from the fact that the default generated += and co by the Python API is the one for immutable types, like int. - Indeed, in python, for an int : x=1 id(x) 140266967205832 x+=1 id(x) 140266967205808 - For a mutable type, like a gf, it is necessary to add explicitly the xxx_inplace_add functions. - Added : - the generation of the inplace_xxx functions - a method in class_ in the wrapper generator that deduce all += operator from the + operators. - this assumes that the +=, ... are defined in C++. - The generation of such operators are optional, with option with_inplace_operators in the arithmetic flag. - Also, added the overload g += M and g -= M for g : GfImfreq, M a complex matrix. Mainly for legacy Python codes.
This commit is contained in:
parent
dcbdd5bc54
commit
7cf7d09c77
@ -218,7 +218,7 @@ def make_gf( py_type, c_tag, is_complex_data = True, is_im = False, has_tail = T
|
|||||||
serializable= "tuple",
|
serializable= "tuple",
|
||||||
is_printable= True,
|
is_printable= True,
|
||||||
hdf5 = True,
|
hdf5 = True,
|
||||||
arithmetic = ("algebra",data_type)
|
arithmetic = ("algebra",data_type, "with_inplace_operators")
|
||||||
)
|
)
|
||||||
|
|
||||||
g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector<size_t,2> shape, std::vector<std::vector<std::string>> indices = std::vector<std::vector<std::string>>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag)
|
g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector<size_t,2> shape, std::vector<std::vector<std::string>> indices = std::vector<std::vector<std::string>>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag)
|
||||||
@ -372,6 +372,11 @@ g.add_method(name = "set_from_legendre",
|
|||||||
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail")
|
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail")
|
||||||
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail")
|
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail")
|
||||||
|
|
||||||
|
# For legacy Python code : authorize g + Matrix
|
||||||
|
#g.number_protocol['add'].add_overload(calling_pattern = "+", signature = "gf<imfreq>(gf<imfreq> x,matrix<std:complex<double>> y)")
|
||||||
|
g.number_protocol['inplace_add'].add_overload(calling_pattern = "+=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
|
||||||
|
g.number_protocol['inplace_subtract'].add_overload(calling_pattern = "-=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
|
||||||
|
|
||||||
module.add_class(g)
|
module.add_class(g)
|
||||||
|
|
||||||
########################
|
########################
|
||||||
|
@ -269,6 +269,11 @@ class class_ :
|
|||||||
- with_unit : +/- of an element with a scalar (injection of the scalar with the unit)
|
- with_unit : +/- of an element with a scalar (injection of the scalar with the unit)
|
||||||
- with_unary_minus : implement unary minus
|
- with_unary_minus : implement unary minus
|
||||||
- "add_only" : implements only +
|
- "add_only" : implements only +
|
||||||
|
- with_inplace_operators : option to deduce the +=, -=, ...
|
||||||
|
operators from +,-, .. It deduces the possibles terms to put at the rhs, looking at the
|
||||||
|
case of the +,- operators where the lhs is of the type of self.
|
||||||
|
NB : The operator is mapped to the corresponding C++ operators (for some objects, this may be faster)
|
||||||
|
so it has to be defined in C++ as well....
|
||||||
- .... more to be defined.
|
- .... more to be defined.
|
||||||
- serializable : Whether and how the object is to be serialized. Possible values are :
|
- serializable : Whether and how the object is to be serialized. Possible values are :
|
||||||
- "tuple" : reduce it to a tuple of smaller objects, using the
|
- "tuple" : reduce it to a tuple of smaller objects, using the
|
||||||
@ -309,6 +314,7 @@ class class_ :
|
|||||||
# read the with_... option and clean them for the list
|
# read the with_... option and clean them for the list
|
||||||
with_unary_minus = 'with_unary_minus' in arithmetic
|
with_unary_minus = 'with_unary_minus' in arithmetic
|
||||||
with_unit = 'with_unit' in arithmetic
|
with_unit = 'with_unit' in arithmetic
|
||||||
|
with_inplace_operators = 'with_inplace_operators' in arithmetic
|
||||||
arithmetic = [x for x in arithmetic if not x.startswith("with_")]
|
arithmetic = [x for x in arithmetic if not x.startswith("with_")]
|
||||||
add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add")
|
add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add")
|
||||||
abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space")
|
abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space")
|
||||||
@ -359,6 +365,23 @@ class class_ :
|
|||||||
neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type})
|
neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type})
|
||||||
self.number_protocol['negative'] = neg
|
self.number_protocol['negative'] = neg
|
||||||
|
|
||||||
|
if with_inplace_operators : self.deduce_inplace_arithmetic()
|
||||||
|
|
||||||
|
def deduce_inplace_arithmetic(self) :
|
||||||
|
"""Deduce all the +=, -=, *=, /= operators from the +, -, *, / operators"""
|
||||||
|
def one_op(op, name, iname) :
|
||||||
|
if name not in self.number_protocol : return
|
||||||
|
impl = pyfunction(name = iname, arity = 2)
|
||||||
|
for overload in self.number_protocol[name].overloads :
|
||||||
|
x_t,y_t = overload.args[0][0], overload.args[1][0]
|
||||||
|
if x_t == self.c_type : # only when first the object
|
||||||
|
impl.add_overload (calling_pattern = op+"=", signature = {'args' : [(x_t,'x'), (y_t,'y')], 'rtype' :overload.rtype})
|
||||||
|
self.number_protocol['inplace_'+name] = impl
|
||||||
|
one_op('+',"add","__iadd__")
|
||||||
|
one_op('-',"subtract","__isub__")
|
||||||
|
one_op('*',"multiply","__imul__")
|
||||||
|
one_op('/',"divide","__idiv__")
|
||||||
|
|
||||||
def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''):
|
def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''):
|
||||||
"""
|
"""
|
||||||
- signature : signature of the function, with types, parameter names and defaut value
|
- signature : signature of the function, with types, parameter names and defaut value
|
||||||
|
@ -939,8 +939,14 @@ static PyObject * ${c.py_type}_${op_name} (PyObject* v, PyObject *w){
|
|||||||
%for overload in op.overloads :
|
%for overload in op.overloads :
|
||||||
if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) {
|
if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) {
|
||||||
try {
|
try {
|
||||||
${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
%if not op_name.startswith("inplace") :
|
||||||
return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++
|
${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
||||||
|
return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++
|
||||||
|
%else:
|
||||||
|
convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
||||||
|
Py_INCREF(v);
|
||||||
|
return v;
|
||||||
|
%endif
|
||||||
}
|
}
|
||||||
CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL)
|
CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL)
|
||||||
}
|
}
|
||||||
|
@ -17,3 +17,7 @@ add_triqs_test_hdf(dos " -d 1.e-6")
|
|||||||
|
|
||||||
# Pade approximation
|
# Pade approximation
|
||||||
add_triqs_test_hdf(pade " -d 1.e-6")
|
add_triqs_test_hdf(pade " -d 1.e-6")
|
||||||
|
|
||||||
|
# Bug fix #112
|
||||||
|
add_triqs_test_txt(gf_inplace_112)
|
||||||
|
|
||||||
|
12
test/pytriqs/base/gf_inplace_112.output
Normal file
12
test/pytriqs/base/gf_inplace_112.output
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
Before:
|
||||||
|
G['up'] = 3.14159265359j
|
||||||
|
G['dn'] = 3.14159265359j
|
||||||
|
After G['up'] += G['dn']:
|
||||||
|
G['up'] = 6.28318530718j
|
||||||
|
G['dn'] = 3.14159265359j
|
||||||
|
After g_up += g_dn:
|
||||||
|
G['up'] = 9.42477796077j
|
||||||
|
G['dn'] = 3.14159265359j
|
||||||
|
After G += G:
|
||||||
|
G['up'] = 18.8495559215j
|
||||||
|
G['dn'] = 6.28318530718j
|
55
test/pytriqs/base/gf_inplace_112.py
Normal file
55
test/pytriqs/base/gf_inplace_112.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Test from I. Krivenko.
|
||||||
|
from __future__ import print_function
|
||||||
|
from pytriqs.gf.local import *
|
||||||
|
from pytriqs.gf.local.descriptors import *
|
||||||
|
import sys
|
||||||
|
def print_err(*x) : print (*x, file= sys.stderr)
|
||||||
|
|
||||||
|
g_up = GfImFreq(indices = [0], beta = 1)
|
||||||
|
g_dn = GfImFreq(indices = [0], beta = 1)
|
||||||
|
|
||||||
|
g_up <<= iOmega_n
|
||||||
|
g_dn <<= iOmega_n
|
||||||
|
|
||||||
|
G = BlockGf(name_list=['up','dn'], block_list=[g_up,g_dn], make_copies=False)
|
||||||
|
|
||||||
|
print("Before:")
|
||||||
|
print("G['up'] =", G['up'].data[0,0,0])
|
||||||
|
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||||
|
print_err('id(g_up) =', id(g_up))
|
||||||
|
print_err('id(g_dn) =', id(g_dn))
|
||||||
|
print_err ("(id=", id(G['up']),")")
|
||||||
|
print_err ("(id=", id(G['dn']),")")
|
||||||
|
|
||||||
|
G['up'] += G['dn']
|
||||||
|
|
||||||
|
print("After G['up'] += G['dn']:")
|
||||||
|
print("G['up'] =", G['up'].data[0,0,0])
|
||||||
|
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||||
|
print_err('id(g_up) =', id(g_up))
|
||||||
|
print_err('id(g_dn) =', id(g_dn))
|
||||||
|
print_err ("(id=", id(G['up']),")")
|
||||||
|
print_err ("(id=", id(G['dn']),")")
|
||||||
|
|
||||||
|
|
||||||
|
g_up += g_dn
|
||||||
|
|
||||||
|
print("After g_up += g_dn:")
|
||||||
|
print("G['up'] =", G['up'].data[0,0,0])
|
||||||
|
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||||
|
print_err('id(g_up) =', id(g_up))
|
||||||
|
print_err('id(g_dn) =', id(g_dn))
|
||||||
|
print_err ("(id=", id(G['up']),")")
|
||||||
|
print_err ("(id=", id(G['dn']),")")
|
||||||
|
|
||||||
|
|
||||||
|
G += G
|
||||||
|
print("After G += G:")
|
||||||
|
print("G['up'] =", G['up'].data[0,0,0])
|
||||||
|
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||||
|
print_err('id(g_up) =', id(g_up))
|
||||||
|
print_err('id(g_dn) =', id(g_dn))
|
||||||
|
print_err ("(id=", id(G['up']),")")
|
||||||
|
print_err ("(id=", id(G['dn']),")")
|
||||||
|
|
||||||
|
|
@ -668,8 +668,8 @@ namespace gfs {
|
|||||||
// auxiliary function : invert the data : one function for all matrix valued gf (save code).
|
// auxiliary function : invert the data : one function for all matrix valued gf (save code).
|
||||||
template <typename A3> void _gf_invert_data_in_place(A3 && a) {
|
template <typename A3> void _gf_invert_data_in_place(A3 && a) {
|
||||||
for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering
|
for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering
|
||||||
auto v = a(i, arrays::range(), arrays::range());
|
auto v = make_matrix_view(a(i, arrays::range(), arrays::range()));
|
||||||
v = inverse(v);
|
v = triqs::arrays::inverse(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,6 +142,26 @@ namespace triqs { namespace gfs {
|
|||||||
>::type
|
>::type
|
||||||
operator - (A1 && a1) { return {std::forward<A1>(a1)};}
|
operator - (A1 && a1) { return {std::forward<A1>(a1)};}
|
||||||
|
|
||||||
|
// Now the inplace operator. Because of expression template, there are useless for speed
|
||||||
|
// we implement them trivially.
|
||||||
|
|
||||||
|
#define DEFINE_OPERATOR(OP1, OP2) \
|
||||||
|
template <typename Variable, typename Target, typename Opt, typename T> \
|
||||||
|
void operator OP1(gf_view<Variable, Target, Opt> g, T const &x) { \
|
||||||
|
g = g OP2 x; \
|
||||||
|
} \
|
||||||
|
template <typename Variable, typename Target, typename Opt, typename T> \
|
||||||
|
void operator OP1(gf<Variable, Target, Opt> &g, T const &x) { \
|
||||||
|
g = g OP2 x; \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_OPERATOR(+=, +);
|
||||||
|
DEFINE_OPERATOR(-=, -);
|
||||||
|
DEFINE_OPERATOR(*=, *);
|
||||||
|
DEFINE_OPERATOR(/=, / );
|
||||||
|
|
||||||
|
#undef DEFINE_OPERATOR
|
||||||
|
|
||||||
}}//namespace triqs::gf
|
}}//namespace triqs::gf
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -166,5 +166,17 @@ namespace gfs {
|
|||||||
template <typename Opt> struct data_proxy<imfreq, scalar_valued, Opt> : data_proxy_array<std::complex<double>, 1> {};
|
template <typename Opt> struct data_proxy<imfreq, scalar_valued, Opt> : data_proxy_array<std::complex<double>, 1> {};
|
||||||
|
|
||||||
} // gfs_implementation
|
} // gfs_implementation
|
||||||
|
|
||||||
|
// specific operations (for legacy python code).
|
||||||
|
// +=, -= with a matrix
|
||||||
|
inline void operator+=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
|
||||||
|
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) += m;
|
||||||
|
g.singularity()(0) += m;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void operator-=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
|
||||||
|
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) -= m;
|
||||||
|
g.singularity()(0) -= m;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -393,5 +393,19 @@ namespace triqs { namespace gfs { namespace local {
|
|||||||
template<typename T1, typename T2> TYPE_ENABLE_IF(tail,mpl::and_<LocalTail<T1>, is_scalar_or_element<T2>>)
|
template<typename T1, typename T2> TYPE_ENABLE_IF(tail,mpl::and_<LocalTail<T1>, is_scalar_or_element<T2>>)
|
||||||
operator - (T1 const & t, T2 const & a) { return (-a) + t;}
|
operator - (T1 const & t, T2 const & a) { return (-a) + t;}
|
||||||
|
|
||||||
|
// inplace operators
|
||||||
|
#define DEFINE_OPERATOR(OP1, OP2) \
|
||||||
|
template <typename T> void operator OP1(tail_view g, T &&x) { g = g OP2 x; } \
|
||||||
|
template <typename T> void operator OP1(tail &g, T &&x) { g = g OP2 x; }
|
||||||
|
|
||||||
|
DEFINE_OPERATOR(+=, +);
|
||||||
|
DEFINE_OPERATOR(-=, -);
|
||||||
|
DEFINE_OPERATOR(*=, *);
|
||||||
|
DEFINE_OPERATOR(/=, / );
|
||||||
|
|
||||||
|
#undef DEFINE_OPERATOR
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}}}
|
}}}
|
||||||
#endif
|
#endif
|
||||||
|
@ -146,6 +146,11 @@ namespace params {
|
|||||||
return p1;
|
return p1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline parameters & operator+=(parameters & p1, parameters const& p2) {
|
||||||
|
p1.update(p2);
|
||||||
|
return p1;
|
||||||
|
}
|
||||||
|
|
||||||
// can only be implemented after complete declaration of parameters
|
// can only be implemented after complete declaration of parameters
|
||||||
template <typename... T> _field& _field::add_field(T&&... x) {
|
template <typename... T> _field& _field::add_field(T&&... x) {
|
||||||
auto* pp = dynamic_cast<_data_impl<parameters>*>(p.get());
|
auto* pp = dynamic_cast<_data_impl<parameters>*>(p.get());
|
||||||
|
Loading…
Reference in New Issue
Block a user