3
0
mirror of https://github.com/triqs/dft_tools synced 2024-10-31 19:23:45 +01:00

[arrays] Fix #118. Assignment to object convertible to scalar

- For a matrix, A = s, where s is scalar was coded properly,
  but not general enough when s is not a scalar, but can casted/converted to
  scalar, e.g. matsubara_freq, mesh_points.

- As a result, doing with a gf :
  g(om_) << om_ + 0.0  // was fine because the expression evaluated to complex
  g(om_) << om_ + om_  // was not : evaluated to matsubara_freq
  g(om_) << om_        // was not : evaluated to mesh_point_t

- Solution : In the case A = s, when s is not a scalar, but convertible
  to it, we use the convertion.
  Impl : allow is_scalar_for to be true when the type is only convertible to scalar.
  Also : split implementation of matrix = scalar into true scalar case/ generic case
         to avoid the strange RHS(0*rhs) construction.
        TODO: make a make_zero function ?

- added a test
This commit is contained in:
Olivier Parcollet 2014-10-26 14:27:45 +01:00
parent 9366c29eab
commit 07fbd77669
5 changed files with 97 additions and 27 deletions

44
test/triqs/gfs/bug1.cpp Normal file
View File

@ -0,0 +1,44 @@
#include <triqs/gfs.hpp>
#include "./common.hpp"
using namespace triqs::gfs;
int main() {
using A = triqs::arrays::matrix_tensor_proxy<triqs::arrays::array<std::complex<double>, 3, void> &, true>;
static_assert(std::is_constructible<std::complex<double>, matsubara_freq>::value, "oops");
static_assert(triqs::arrays::is_scalar_or_convertible<matsubara_freq>::value, "oops2");
static_assert(triqs::arrays::is_scalar_for<std::complex<double>, A>::value, "oops2");
static_assert(triqs::arrays::is_scalar_for<matsubara_freq, A>::value, "oops2");
triqs::clef::placeholder<0> om_;
auto g = gf<imfreq>{{10, Fermion, 10}, {2, 2}};
auto g2 = g;
auto g3 = g;
g(om_) << om_ + 0.0; // Works
g2(om_) << om_; // Did not compile : rhs = mesh_point
g3(om_) << om_ + om_; // Did not compile : rhs = matsubara_freq
assert_equal_array(g.data(), g2.data(), "bug !");
assert_equal_array(g.data(), g3.data()/2, "bug !");
std::cerr << g.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
std::cerr << g.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
std::cerr << g.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
std::cerr << g.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
std::cerr << g2.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
std::cerr << g2.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
std::cerr << g2.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
std::cerr << g2.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
std::cerr << g3.data()(triqs::arrays::ellipsis(), 0,0) << std::endl;
std::cerr << g3.data()(triqs::arrays::ellipsis(), 1,0) << std::endl;
std::cerr << g3.data()(triqs::arrays::ellipsis(), 0,1) << std::endl;
std::cerr << g3.data()(triqs::arrays::ellipsis(), 1,1) << std::endl;
return 0;
}

11
test/triqs/gfs/common.hpp Normal file
View File

@ -0,0 +1,11 @@
template<typename T>
void assert_equal(T const& x, T const& y, std::string mess) {
if (std::abs(x - y) > 1.e-13) TRIQS_RUNTIME_ERROR << mess;
}
template<typename T1, typename T2>
void assert_equal_array(T1 const& x, T2 const& y, std::string mess) {
if (max_element(abs(x - y)) > 1.e-13) TRIQS_RUNTIME_ERROR << mess << "\n" << x << "\n" << y << "\n" << max_element(abs(x - y));
}

View File

@ -2,6 +2,7 @@
#include <triqs/gfs.hpp>
#include <triqs/gfs/bz.hpp>
#include <triqs/gfs/m_tail.hpp>
#include "../common.hpp"
namespace h5 = triqs::h5;
using namespace triqs::gfs;
@ -9,16 +10,6 @@ using namespace triqs::clef;
using namespace triqs::arrays;
using namespace triqs::lattice;
template<typename T>
void assert_equal(T const& x, T const& y, std::string mess) {
if (std::abs(x - y) > 1.e-13) TRIQS_RUNTIME_ERROR << mess;
}
template<typename T1, typename T2>
void assert_equal_array(T1 const& x, T2 const& y, std::string mess) {
if (max_element(abs(x - y)) > 1.e-13) TRIQS_RUNTIME_ERROR << mess << "\n" << x << "\n" << y << "\n" << max_element(abs(x - y));
}
#define TEST(X) std::cout << BOOST_PP_STRINGIZE((X)) << " ---> " << (X) << std::endl << std::endl;
int main() {

View File

@ -135,31 +135,52 @@ namespace triqs { namespace arrays {
TRIQS_REJECT_ASSIGN_TO_CONST;
typedef typename LHS::value_type value_type;
LHS & lhs; const RHS & rhs;
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){}//, p(*(lhs_.data_start())) {}
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){}
template<typename ... Args> void operator()(Args const & ...args) const {_ops_<value_type, RHS, OP>::invoke(lhs(args...), rhs);}
void invoke() { foreach(lhs,*this); }
};
// ----------------- assignment for scalar RHS for Matrices --------------------------------------------------
template <typename T, int R> bool kronecker(mini_vector<T,R> const & key) { return ( (R==2) && (key[0]==key[1]));}
//template <typename T, int R> bool kronecker(mini_vector<T,R> const & key) { return ( (R==2) && (key[0]==key[1]));}
template <typename T> bool kronecker(T const & x0, T const & x1) { return ( (x0==x1));}
// CONCEPT : reunifiy the 2 class, put require on operator() for the 2 cases
// Specialisation for Matrix Classes : scalar is a unity matrix, and operation is E, A, S, but NOT M, D
template<typename LHS, typename RHS, char OP>
struct impl<LHS,RHS,OP, ENABLE_IFC(is_scalar_for<RHS,LHS>::value && (MutableMatrix<LHS>::value && (OP=='A'||OP=='S'||OP=='E')))> {
TRIQS_REJECT_ASSIGN_TO_CONST;
typedef typename LHS::value_type value_type;
LHS & lhs; const RHS & rhs;
impl(LHS & lhs_, const RHS & rhs_): lhs(lhs_), rhs(rhs_){} //, p(*(lhs_.data_start())) {}
// we MUST make off_diag like this, if value_type is a complicated type (i.e. gf, matrix) with a size
// off diagonal element is 0*rhs, i.e. a 0, but with the SAME SIZE as the diagonal part.
// otherwise further operation may fail later.
// TO DO : look at performance issue ?? (we can remote the multiplication by 0 using an auxiliary function)
template<typename ... Args>
void operator()(Args const & ... args) const {_ops_<value_type, RHS, OP>::invoke(lhs(args...), (kronecker(args...) ? rhs : RHS{0*rhs}));}
void invoke() { foreach(lhs,*this); }
};
// First case : when it is a true scalar or convertible to
template <typename LHS, typename RHS, char OP>
struct impl<LHS, RHS, OP, ENABLE_IFC(is_scalar_for<RHS, LHS>::value&& is_scalar_or_convertible<RHS>::value&&(
MutableMatrix<LHS>::value&&(OP == 'A' || OP == 'S' || OP == 'E')))> {
TRIQS_REJECT_ASSIGN_TO_CONST;
using value_type = typename LHS::value_type;
static_assert(is_scalar<value_type>::value, "Internal error");
LHS& lhs;
const RHS& rhs;
impl(LHS& lhs_, const RHS& rhs_) : lhs(lhs_), rhs(rhs_) {}
template <typename... Args> void operator()(Args const&... args) const {
if (kronecker(args...))
_ops_<value_type, RHS, OP>::invoke(lhs(args...), rhs);
else
_ops_<value_type, value_type, OP>::invoke(lhs(args...), 0);
}
void invoke() { foreach(lhs, *this); }
};
// Specialisation for Matrix Classes : scalar is a unity matrix, and operation is E, A, S, but NOT M, D
// Second generic case : we should introduce make_zero function ?
template <typename LHS, typename RHS, char OP>
struct impl<LHS, RHS, OP, ENABLE_IFC(is_scalar_for<RHS, LHS>::value&&(!is_scalar_or_convertible<RHS>::value) &&
(MutableMatrix<LHS>::value&&(OP == 'A' || OP == 'S' || OP == 'E')))> {
TRIQS_REJECT_ASSIGN_TO_CONST;
typedef typename LHS::value_type value_type;
LHS& lhs;
const RHS& rhs;
impl(LHS& lhs_, const RHS& rhs_) : lhs(lhs_), rhs(rhs_) {}
template <typename... Args> void operator()(Args const&... args) const {
_ops_<value_type, RHS, OP>::invoke(lhs(args...), (kronecker(args...) ? rhs : RHS{0 * rhs}));
}
void invoke() { foreach(lhs, *this); }
};
#undef TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR
#undef TRIQS_REJECT_ASSIGN_TO_CONST

View File

@ -70,10 +70,13 @@ namespace arrays {
template <class T> struct is_amv_value_or_view_class : _or<is_amv_value_class<T>, is_amv_view_class<T>> {};
template <class S> struct is_scalar : _or<std::is_arithmetic<S>, triqs::is_complex<S>> {};
template <class S>
struct is_scalar_or_convertible
: std::integral_constant<bool, is_scalar<S>::value || std::is_constructible<std::complex<double>, S>::value> {};
template <class S, class A>
struct is_scalar_for
: std::conditional<is_scalar<typename A::value_type>::value, is_scalar<S>, std::is_same<S, typename A::value_type>>::type {
: std::conditional<is_scalar<typename A::value_type>::value, is_scalar_or_convertible<S>, std::is_same<S, typename A::value_type>>::type {
};
}
} // namespace triqs::arrays