mirror of
https://github.com/triqs/dft_tools
synced 2024-12-25 05:43:40 +01:00
Fix #112 and put back g +=/-= matrix for imfreq
- The issue comes from the fact that the default generated += and co by the Python API is the one for immutable types, like int. - Indeed, in python, for an int : x=1 id(x) 140266967205832 x+=1 id(x) 140266967205808 - For a mutable type, like a gf, it is necessary to add explicitly the xxx_inplace_add functions. - Added : - the generation of the inplace_xxx functions - a method in class_ in the wrapper generator that deduce all += operator from the + operators. - this assumes that the +=, ... are defined in C++. - The generation of such operators are optional, with option with_inplace_operators in the arithmetic flag. - Also, added the overload g += M and g -= M for g : GfImfreq, M a complex matrix. Mainly for legacy Python codes.
This commit is contained in:
parent
dcbdd5bc54
commit
7cf7d09c77
@ -218,7 +218,7 @@ def make_gf( py_type, c_tag, is_complex_data = True, is_im = False, has_tail = T
|
||||
serializable= "tuple",
|
||||
is_printable= True,
|
||||
hdf5 = True,
|
||||
arithmetic = ("algebra",data_type)
|
||||
arithmetic = ("algebra",data_type, "with_inplace_operators")
|
||||
)
|
||||
|
||||
g.add_constructor(signature = "(gf_mesh<%s> mesh, mini_vector<size_t,2> shape, std::vector<std::vector<std::string>> indices = std::vector<std::vector<std::string>>{}, std::string name = "")"%c_tag, python_precall = "pytriqs.gf.local._gf_%s.init"%c_tag)
|
||||
@ -372,6 +372,11 @@ g.add_method(name = "set_from_legendre",
|
||||
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.replace_by_tail")
|
||||
g.add_pure_python_method("pytriqs.gf.local._gf_imfreq.fit_tail")
|
||||
|
||||
# For legacy Python code : authorize g + Matrix
|
||||
#g.number_protocol['add'].add_overload(calling_pattern = "+", signature = "gf<imfreq>(gf<imfreq> x,matrix<std:complex<double>> y)")
|
||||
g.number_protocol['inplace_add'].add_overload(calling_pattern = "+=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
|
||||
g.number_protocol['inplace_subtract'].add_overload(calling_pattern = "-=", signature = "void(gf_view<imfreq> x,matrix<std::complex<double>> y)")
|
||||
|
||||
module.add_class(g)
|
||||
|
||||
########################
|
||||
|
@ -269,6 +269,11 @@ class class_ :
|
||||
- with_unit : +/- of an element with a scalar (injection of the scalar with the unit)
|
||||
- with_unary_minus : implement unary minus
|
||||
- "add_only" : implements only +
|
||||
- with_inplace_operators : option to deduce the +=, -=, ...
|
||||
operators from +,-, .. It deduces the possibles terms to put at the rhs, looking at the
|
||||
case of the +,- operators where the lhs is of the type of self.
|
||||
NB : The operator is mapped to the corresponding C++ operators (for some objects, this may be faster)
|
||||
so it has to be defined in C++ as well....
|
||||
- .... more to be defined.
|
||||
- serializable : Whether and how the object is to be serialized. Possible values are :
|
||||
- "tuple" : reduce it to a tuple of smaller objects, using the
|
||||
@ -309,6 +314,7 @@ class class_ :
|
||||
# read the with_... option and clean them for the list
|
||||
with_unary_minus = 'with_unary_minus' in arithmetic
|
||||
with_unit = 'with_unit' in arithmetic
|
||||
with_inplace_operators = 'with_inplace_operators' in arithmetic
|
||||
arithmetic = [x for x in arithmetic if not x.startswith("with_")]
|
||||
add = arithmetic[0] in ("algebra", "abelian_group", "vector_space", "only_add")
|
||||
abelian_group = arithmetic[0] in ("algebra", "abelian_group", "vector_space")
|
||||
@ -359,6 +365,23 @@ class class_ :
|
||||
neg.add_overload (calling_pattern = "-", signature = {'args' :[(self.c_type,'x')], 'rtype' : self.c_type})
|
||||
self.number_protocol['negative'] = neg
|
||||
|
||||
if with_inplace_operators : self.deduce_inplace_arithmetic()
|
||||
|
||||
def deduce_inplace_arithmetic(self) :
|
||||
"""Deduce all the +=, -=, *=, /= operators from the +, -, *, / operators"""
|
||||
def one_op(op, name, iname) :
|
||||
if name not in self.number_protocol : return
|
||||
impl = pyfunction(name = iname, arity = 2)
|
||||
for overload in self.number_protocol[name].overloads :
|
||||
x_t,y_t = overload.args[0][0], overload.args[1][0]
|
||||
if x_t == self.c_type : # only when first the object
|
||||
impl.add_overload (calling_pattern = op+"=", signature = {'args' : [(x_t,'x'), (y_t,'y')], 'rtype' :overload.rtype})
|
||||
self.number_protocol['inplace_'+name] = impl
|
||||
one_op('+',"add","__iadd__")
|
||||
one_op('-',"subtract","__isub__")
|
||||
one_op('*',"multiply","__imul__")
|
||||
one_op('/',"divide","__idiv__")
|
||||
|
||||
def add_constructor(self, signature, calling_pattern = None, python_precall = None, python_postcall = None, build_from_regular_type_if_view = True, doc = ''):
|
||||
"""
|
||||
- signature : signature of the function, with types, parameter names and defaut value
|
||||
|
@ -939,8 +939,14 @@ static PyObject * ${c.py_type}_${op_name} (PyObject* v, PyObject *w){
|
||||
%for overload in op.overloads :
|
||||
if (convertible_from_python<${overload.args[0][0]}>(v,false) && convertible_from_python<${overload.args[1][0]}>(w,false)) {
|
||||
try {
|
||||
${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
||||
return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++
|
||||
%if not op_name.startswith("inplace") :
|
||||
${regular_type_if_view_else_type(overload.rtype)} r = convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
||||
return convert_to_python(std::move(r)); // in two steps to force type for expression templates in C++
|
||||
%else:
|
||||
convert_from_python<${overload.args[0][0]}>(v) ${overload._get_calling_pattern()} convert_from_python<${overload.args[1][0]}>(w);
|
||||
Py_INCREF(v);
|
||||
return v;
|
||||
%endif
|
||||
}
|
||||
CATCH_AND_RETURN("in calling C++ overload \n ${overload._get_c_signature()} \nin implementation of operator ${overload._get_calling_pattern()} ", NULL)
|
||||
}
|
||||
|
@ -17,3 +17,7 @@ add_triqs_test_hdf(dos " -d 1.e-6")
|
||||
|
||||
# Pade approximation
|
||||
add_triqs_test_hdf(pade " -d 1.e-6")
|
||||
|
||||
# Bug fix #112
|
||||
add_triqs_test_txt(gf_inplace_112)
|
||||
|
||||
|
12
test/pytriqs/base/gf_inplace_112.output
Normal file
12
test/pytriqs/base/gf_inplace_112.output
Normal file
@ -0,0 +1,12 @@
|
||||
Before:
|
||||
G['up'] = 3.14159265359j
|
||||
G['dn'] = 3.14159265359j
|
||||
After G['up'] += G['dn']:
|
||||
G['up'] = 6.28318530718j
|
||||
G['dn'] = 3.14159265359j
|
||||
After g_up += g_dn:
|
||||
G['up'] = 9.42477796077j
|
||||
G['dn'] = 3.14159265359j
|
||||
After G += G:
|
||||
G['up'] = 18.8495559215j
|
||||
G['dn'] = 6.28318530718j
|
55
test/pytriqs/base/gf_inplace_112.py
Normal file
55
test/pytriqs/base/gf_inplace_112.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Test from I. Krivenko.
|
||||
from __future__ import print_function
|
||||
from pytriqs.gf.local import *
|
||||
from pytriqs.gf.local.descriptors import *
|
||||
import sys
|
||||
def print_err(*x) : print (*x, file= sys.stderr)
|
||||
|
||||
g_up = GfImFreq(indices = [0], beta = 1)
|
||||
g_dn = GfImFreq(indices = [0], beta = 1)
|
||||
|
||||
g_up <<= iOmega_n
|
||||
g_dn <<= iOmega_n
|
||||
|
||||
G = BlockGf(name_list=['up','dn'], block_list=[g_up,g_dn], make_copies=False)
|
||||
|
||||
print("Before:")
|
||||
print("G['up'] =", G['up'].data[0,0,0])
|
||||
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||
print_err('id(g_up) =', id(g_up))
|
||||
print_err('id(g_dn) =', id(g_dn))
|
||||
print_err ("(id=", id(G['up']),")")
|
||||
print_err ("(id=", id(G['dn']),")")
|
||||
|
||||
G['up'] += G['dn']
|
||||
|
||||
print("After G['up'] += G['dn']:")
|
||||
print("G['up'] =", G['up'].data[0,0,0])
|
||||
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||
print_err('id(g_up) =', id(g_up))
|
||||
print_err('id(g_dn) =', id(g_dn))
|
||||
print_err ("(id=", id(G['up']),")")
|
||||
print_err ("(id=", id(G['dn']),")")
|
||||
|
||||
|
||||
g_up += g_dn
|
||||
|
||||
print("After g_up += g_dn:")
|
||||
print("G['up'] =", G['up'].data[0,0,0])
|
||||
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||
print_err('id(g_up) =', id(g_up))
|
||||
print_err('id(g_dn) =', id(g_dn))
|
||||
print_err ("(id=", id(G['up']),")")
|
||||
print_err ("(id=", id(G['dn']),")")
|
||||
|
||||
|
||||
G += G
|
||||
print("After G += G:")
|
||||
print("G['up'] =", G['up'].data[0,0,0])
|
||||
print("G['dn'] =", G['dn'].data[0,0,0])
|
||||
print_err('id(g_up) =', id(g_up))
|
||||
print_err('id(g_dn) =', id(g_dn))
|
||||
print_err ("(id=", id(G['up']),")")
|
||||
print_err ("(id=", id(G['dn']),")")
|
||||
|
||||
|
@ -668,8 +668,8 @@ namespace gfs {
|
||||
// auxiliary function : invert the data : one function for all matrix valued gf (save code).
|
||||
template <typename A3> void _gf_invert_data_in_place(A3 && a) {
|
||||
for (int i = 0; i < first_dim(a); ++i) {// Rely on the ordering
|
||||
auto v = a(i, arrays::range(), arrays::range());
|
||||
v = inverse(v);
|
||||
auto v = make_matrix_view(a(i, arrays::range(), arrays::range()));
|
||||
v = triqs::arrays::inverse(v);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,6 +142,26 @@ namespace triqs { namespace gfs {
|
||||
>::type
|
||||
operator - (A1 && a1) { return {std::forward<A1>(a1)};}
|
||||
|
||||
// Now the inplace operator. Because of expression template, there are useless for speed
|
||||
// we implement them trivially.
|
||||
|
||||
#define DEFINE_OPERATOR(OP1, OP2) \
|
||||
template <typename Variable, typename Target, typename Opt, typename T> \
|
||||
void operator OP1(gf_view<Variable, Target, Opt> g, T const &x) { \
|
||||
g = g OP2 x; \
|
||||
} \
|
||||
template <typename Variable, typename Target, typename Opt, typename T> \
|
||||
void operator OP1(gf<Variable, Target, Opt> &g, T const &x) { \
|
||||
g = g OP2 x; \
|
||||
}
|
||||
|
||||
DEFINE_OPERATOR(+=, +);
|
||||
DEFINE_OPERATOR(-=, -);
|
||||
DEFINE_OPERATOR(*=, *);
|
||||
DEFINE_OPERATOR(/=, / );
|
||||
|
||||
#undef DEFINE_OPERATOR
|
||||
|
||||
}}//namespace triqs::gf
|
||||
#endif
|
||||
|
||||
|
@ -166,5 +166,17 @@ namespace gfs {
|
||||
template <typename Opt> struct data_proxy<imfreq, scalar_valued, Opt> : data_proxy_array<std::complex<double>, 1> {};
|
||||
|
||||
} // gfs_implementation
|
||||
|
||||
// specific operations (for legacy python code).
|
||||
// +=, -= with a matrix
|
||||
inline void operator+=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
|
||||
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) += m;
|
||||
g.singularity()(0) += m;
|
||||
}
|
||||
|
||||
inline void operator-=(gf_view<imfreq> g, arrays::matrix<std::complex<double>> m) {
|
||||
for (int u = 0; u < first_dim(g.data()); ++u) g.data()(u, arrays::ellipsis()) -= m;
|
||||
g.singularity()(0) -= m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -393,5 +393,19 @@ namespace triqs { namespace gfs { namespace local {
|
||||
template<typename T1, typename T2> TYPE_ENABLE_IF(tail,mpl::and_<LocalTail<T1>, is_scalar_or_element<T2>>)
|
||||
operator - (T1 const & t, T2 const & a) { return (-a) + t;}
|
||||
|
||||
// inplace operators
|
||||
#define DEFINE_OPERATOR(OP1, OP2) \
|
||||
template <typename T> void operator OP1(tail_view g, T &&x) { g = g OP2 x; } \
|
||||
template <typename T> void operator OP1(tail &g, T &&x) { g = g OP2 x; }
|
||||
|
||||
DEFINE_OPERATOR(+=, +);
|
||||
DEFINE_OPERATOR(-=, -);
|
||||
DEFINE_OPERATOR(*=, *);
|
||||
DEFINE_OPERATOR(/=, / );
|
||||
|
||||
#undef DEFINE_OPERATOR
|
||||
|
||||
|
||||
|
||||
}}}
|
||||
#endif
|
||||
|
@ -146,6 +146,11 @@ namespace params {
|
||||
return p1;
|
||||
}
|
||||
|
||||
inline parameters & operator+=(parameters & p1, parameters const& p2) {
|
||||
p1.update(p2);
|
||||
return p1;
|
||||
}
|
||||
|
||||
// can only be implemented after complete declaration of parameters
|
||||
template <typename... T> _field& _field::add_field(T&&... x) {
|
||||
auto* pp = dynamic_cast<_data_impl<parameters>*>(p.get());
|
||||
|
Loading…
Reference in New Issue
Block a user