mirror of
https://github.com/triqs/dft_tools
synced 2024-10-31 19:23:45 +01:00
[arrays] Fix #118. Assignment to object convertible to scalar
- For a matrix, A = s, where s is scalar was coded properly, but not general enough when s is not a scalar, but can casted/converted to scalar, e.g. matsubara_freq, mesh_points. - As a result, doing with a gf : g(om_) << om_ + 0.0 // was fine because the expression evaluated to complex g(om_) << om_ + om_ // was not : evaluated to matsubara_freq g(om_) << om_ // was not : evaluated to mesh_point_t - Solution : In the case A = s, when s is not a scalar, but convertible to it, we use the convertion. Impl : allow is_scalar_for to be true when the type is only convertible to scalar. Also : split implementation of matrix = scalar into true scalar case/ generic case to avoid the strange RHS(0*rhs) construction. TODO: make a make_zero function ? - added a test
This commit is contained in:
parent
9366c29eab
commit
07fbd77669
44
test/triqs/gfs/bug1.cpp
Normal file
44
test/triqs/gfs/bug1.cpp
Normal file
@ -0,0 +1,44 @@
|
||||
#include <triqs/gfs.hpp>
|
||||
#include "./common.hpp"
|
||||
|
||||
using namespace triqs::gfs;
|
||||
|
||||
int main() {
|
||||
|
||||
using A = triqs::arrays::matrix_tensor_proxy<triqs::arrays::array<std::complex<double>, 3, void> &, true>;
|
||||
|
||||
static_assert(std::is_constructible<std::complex<double>, matsubara_freq>::value, "oops");
|
||||
static_assert(triqs::arrays::is_scalar_or_convertible<matsubara_freq>::value, "oops2");
|
||||
static_assert(triqs::arrays::is_scalar_for<std::complex<double>, A>::value, "oops2");
|
||||
static_assert(triqs::arrays::is_scalar_for<matsubara_freq, A>::value, "oops2");
|
||||
|
||||
triqs::clef::placeholder<0> om_;
|
||||
auto g = gf<imfreq>{{10, Fermion, 10}, {2, 2}};
|
||||
auto g2 = g;
|
||||
auto g3 = g;
|
||||
|
||||
g(om_) << om_ + 0.0; // Works
|
||||
|
||||
g2(om_) << om_; // Did not compile : rhs = mesh_point
|
||||
g3(om_) << om_ + om_; // Did not compile : rhs = matsubara_freq
|
||||
|
||||
assert_equal_array(g.data(), g2.data(), "bug !");
|
||||
assert_equal_array(g.data(), g3.data()/2, "bug !");
|
||||
|
||||
std::cerr << g.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
|
||||
std::cerr << g.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
|
||||
std::cerr << g.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
|
||||
std::cerr << g.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
|
||||
|
||||
std::cerr << g2.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
|
||||
std::cerr << g2.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
|
||||
std::cerr << g2.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
|
||||
std::cerr << g2.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
|
||||
|
||||
std::cerr << g3.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
|
||||
std::cerr << g3.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
|
||||
std::cerr << g3.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
|
||||
std::cerr << g3.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
11
test/triqs/gfs/common.hpp
Normal file
11
test/triqs/gfs/common.hpp
Normal file
@ -0,0 +1,11 @@
|
||||
template<typename T>
|
||||
void assert_equal(T const& x, T const& y, std::string mess) {
|
||||
if (std::abs(x - y) > 1.e-13) TRIQS_RUNTIME_ERROR << mess;
|
||||
}
|
||||
|
||||
template<typename T1, typename T2>
|
||||
void assert_equal_array(T1 const& x, T2 const& y, std::string mess) {
|
||||
if (max_element(abs(x - y)) > 1.e-13) TRIQS_RUNTIME_ERROR << mess << "\n" << x << "\n" << y << "\n" << max_element(abs(x - y));
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <triqs/gfs.hpp>
|
||||
#include <triqs/gfs/bz.hpp>
|
||||
#include <triqs/gfs/m_tail.hpp>
|
||||
#include "../common.hpp"
|
||||
|
||||
namespace h5 = triqs::h5;
|
||||
using namespace triqs::gfs;
|
||||
@ -9,16 +10,6 @@ using namespace triqs::clef;
|
||||
using namespace triqs::arrays;
|
||||
using namespace triqs::lattice;
|
||||
|
||||
template<typename T>
|
||||
void assert_equal(T const& x, T const& y, std::string mess) {
|
||||
if (std::abs(x - y) > 1.e-13) TRIQS_RUNTIME_ERROR << mess;
|
||||
}
|
||||
|
||||
template<typename T1, typename T2>
|
||||
void assert_equal_array(T1 const& x, T2 const& y, std::string mess) {
|
||||
if (max_element(abs(x - y)) > 1.e-13) TRIQS_RUNTIME_ERROR << mess << "\n" << x << "\n" << y << "\n" << max_element(abs(x - y));
|
||||
}
|
||||
|
||||
#define TEST(X) std::cout << BOOST_PP_STRINGIZE((X)) << " ---> " << (X) << std::endl << std::endl;
|
||||
|
||||
int main() {
|
||||
|
@ -135,31 +135,52 @@ namespace triqs { namespace arrays {
|
||||
TRIQS_REJECT_ASSIGN_TO_CONST;
|
||||
typedef typename LHS::value_type value_type;
|
||||
LHS & lhs; const RHS & rhs;
|
||||
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){}//, p(*(lhs_.data_start())) {}
|
||||
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){}
|
||||
template<typename ... Args> void operator()(Args const & ...args) const {_ops_<value_type, RHS, OP>::invoke(lhs(args...), rhs);}
|
||||
void invoke() { foreach(lhs,*this); }
|
||||
};
|
||||
|
||||
// ----------------- assignment for scalar RHS for Matrices --------------------------------------------------
|
||||
|
||||
template <typename T, int R> bool kronecker(mini_vector<T,R> const & key) { return ( (R==2) && (key[0]==key[1]));}
|
||||
//template <typename T, int R> bool kronecker(mini_vector<T,R> const & key) { return ( (R==2) && (key[0]==key[1]));}
|
||||
template <typename T> bool kronecker(T const & x0, T const & x1) { return ( (x0==x1));}
|
||||
|
||||
// CONCEPT : reunifiy the 2 class, put require on operator() for the 2 cases
|
||||
// Specialisation for Matrix Classes : scalar is a unity matrix, and operation is E, A, S, but NOT M, D
|
||||
template<typename LHS, typename RHS, char OP>
|
||||
struct impl<LHS,RHS,OP, ENABLE_IFC(is_scalar_for<RHS,LHS>::value && (MutableMatrix<LHS>::value && (OP=='A'||OP=='S'||OP=='E')))> {
|
||||
TRIQS_REJECT_ASSIGN_TO_CONST;
|
||||
typedef typename LHS::value_type value_type;
|
||||
LHS & lhs; const RHS & rhs;
|
||||
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){} //, p(*(lhs_.data_start())) {}
|
||||
// we MUST make off_diag like this, if value_type is a complicated type (i.e. gf, matrix) with a size
|
||||
// off diagonal element is 0*rhs, i.e. a 0, but with the SAME SIZE as the diagonal part.
|
||||
// otherwise further operation may fail later.
|
||||
// TO DO : look at performance issue ?? (we can remote the multiplication by 0 using an auxiliary function)
|
||||
template<typename ... Args>
|
||||
void operator()(Args const & ... args) const {_ops_<value_type, RHS, OP>::invoke(lhs(args...), (kronecker(args...) ? rhs : RHS{0*rhs}));}
|
||||
void invoke() { foreach(lhs,*this); }
|
||||
};
|
||||
// First case : when it is a true scalar or convertible to
|
||||
template <typename LHS, typename RHS, char OP>
|
||||
struct impl<LHS, RHS, OP, ENABLE_IFC(is_scalar_for<RHS, LHS>::value&& is_scalar_or_convertible<RHS>::value&&(
|
||||
MutableMatrix<LHS>::value&&(OP == 'A' || OP == 'S' || OP == 'E')))> {
|
||||
TRIQS_REJECT_ASSIGN_TO_CONST;
|
||||
using value_type = typename LHS::value_type;
|
||||
static_assert(is_scalar<value_type>::value, "Internal error");
|
||||
LHS& lhs;
|
||||
const RHS& rhs;
|
||||
impl(LHS& lhs_, const RHS& rhs_) : lhs(lhs_), rhs(rhs_) {}
|
||||
template <typename... Args> void operator()(Args const&... args) const {
|
||||
if (kronecker(args...))
|
||||
_ops_<value_type, RHS, OP>::invoke(lhs(args...), rhs);
|
||||
else
|
||||
_ops_<value_type, value_type, OP>::invoke(lhs(args...), 0);
|
||||
}
|
||||
void invoke() { foreach(lhs, *this); }
|
||||
};
|
||||
|
||||
// Specialisation for Matrix Classes : scalar is a unity matrix, and operation is E, A, S, but NOT M, D
|
||||
// Second generic case : we should introduce make_zero function ?
|
||||
template <typename LHS, typename RHS, char OP>
|
||||
struct impl<LHS, RHS, OP, ENABLE_IFC(is_scalar_for<RHS, LHS>::value&&(!is_scalar_or_convertible<RHS>::value) &&
|
||||
(MutableMatrix<LHS>::value&&(OP == 'A' || OP == 'S' || OP == 'E')))> {
|
||||
TRIQS_REJECT_ASSIGN_TO_CONST;
|
||||
typedef typename LHS::value_type value_type;
|
||||
LHS& lhs;
|
||||
const RHS& rhs;
|
||||
impl(LHS& lhs_, const RHS& rhs_) : lhs(lhs_), rhs(rhs_) {}
|
||||
template <typename... Args> void operator()(Args const&... args) const {
|
||||
_ops_<value_type, RHS, OP>::invoke(lhs(args...), (kronecker(args...) ? rhs : RHS{0 * rhs}));
|
||||
}
|
||||
void invoke() { foreach(lhs, *this); }
|
||||
};
|
||||
|
||||
#undef TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR
|
||||
#undef TRIQS_REJECT_ASSIGN_TO_CONST
|
||||
|
@ -70,10 +70,13 @@ namespace arrays {
|
||||
template <class T> struct is_amv_value_or_view_class : _or<is_amv_value_class<T>, is_amv_view_class<T>> {};
|
||||
|
||||
template <class S> struct is_scalar : _or<std::is_arithmetic<S>, triqs::is_complex<S>> {};
|
||||
template <class S>
|
||||
struct is_scalar_or_convertible
|
||||
: std::integral_constant<bool, is_scalar<S>::value || std::is_constructible<std::complex<double>, S>::value> {};
|
||||
|
||||
template <class S, class A>
|
||||
struct is_scalar_for
|
||||
: std::conditional<is_scalar<typename A::value_type>::value, is_scalar<S>, std::is_same<S, typename A::value_type>>::type {
|
||||
: std::conditional<is_scalar<typename A::value_type>::value, is_scalar_or_convertible<S>, std::is_same<S, typename A::value_type>>::type {
|
||||
};
|
||||
}
|
||||
} // namespace triqs::arrays
|
||||
|
Loading…
Reference in New Issue
Block a user