3
0
mirror of https://github.com/triqs/dft_tools synced 2025-01-12 22:18:23 +01:00

arrays: Remove dim0, dim1, .shape in various matrix object.

Not in the concept, not needed, just an annoyance.
replaced by free functions :
first_dim(A), second_dim(A), get_shape(A) and so on...
This commit is contained in:
Olivier Parcollet 2013-07-29 22:44:43 +02:00
parent e241f85a7d
commit 072b45ac1c
21 changed files with 66 additions and 88 deletions

View File

@ -39,14 +39,6 @@ namespace triqs { namespace arrays {
//
friend std::ostream & operator<<(std::ostream & out, immutable_diagonal_matrix_view const & d) { return out<<"diagonal_matrix "<<d.data;}
// ----------------------
// should be remove from concept. redundant....
// need to clean dim0, dim1 and shape and make them free function everywhere (deduced from domain)
size_t dim0() const { return data.shape()[0];}
size_t dim1() const { return data.shape()[0];}
mini_vector<size_t,2> shape() const { auto s = data.shape()[0]; return mini_vector<size_t,2>(s,s);}
};
}}

View File

@ -66,7 +66,7 @@ namespace triqs { namespace arrays { namespace blas {
gemm (typename MT1::value_type alpha, MT1 const & A, MT2 const & B, typename MT1::value_type beta, MTOut & C) {
//std::cerr << "gemm: blas call "<< std::endl ;
// first resize if necessary and possible
resize_or_check_if_view(C,make_shape(A.dim0(),B.dim1()));
resize_or_check_if_view(C,make_shape(first_dim(A),second_dim(B)));
// now we use qcache instead of the matrix to make a copy if necessary ...
// not optimal : if stride == 1, N ---> use LDA parameters
@ -77,16 +77,16 @@ namespace triqs { namespace arrays { namespace blas {
// then tC = tB tA !
const_qcache<MT1> Cb(A); // note the inversion A <-> B
const_qcache<MT2> Ca(B); // note the inversion A <-> B
if (!(Ca().dim0() == Cb().dim1())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemm : A : "<< Ca().shape() <<" while B : "<<Cb().shape();
if (!(first_dim(Ca()) == second_dim(Cb()))) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemm : A : "<< get_shape(Ca()) <<" while B : "<<get_shape(Cb());
char trans_a= get_trans(Ca(), true);
char trans_b= get_trans(Cb(), true);
int m = (trans_a == 'N' ? get_n_rows(Ca()) : get_n_cols(Ca()));
int n = (trans_b == 'N' ? get_n_cols(Cb()) : get_n_rows(Cb()));
int k = (trans_a == 'N' ? get_n_cols(Ca()) : get_n_rows(Ca()));
//std::cerr<< " about to call GEMM"<< std::endl ;
//std::cerr<< "A = "<< Ca().shape()<< Ca()<< std::endl;
//std::cerr<< "B = "<< Cb().shape()<< Cb()<< std::endl;
//std::cerr<< "C c" << Cc().shape() << Cc().indexmap().strides() << std::endl;
//std::cerr<< "A = "<< get_shape(Ca())<< Ca()<< std::endl;
//std::cerr<< "B = "<< get_shape(Cb())<< Cb()<< std::endl;
//std::cerr<< "C c" << get_shape(Cc()) << Cc().indexmap().strides() << std::endl;
//std::cerr<<Ca().memory_layout_is_c() <<Ca().memory_layout_is_fortran()<<std::endl;
//std::cerr<< get_n_rows(Ca())<<get_n_cols(Cb())<<get_n_cols(Ca()) << std::endl ;
f77::gemm(trans_a,trans_b,m,n,k,
@ -96,7 +96,7 @@ namespace triqs { namespace arrays { namespace blas {
else {
const_qcache<MT1> Ca(A);
const_qcache<MT2> Cb(B);
if (!(Ca().dim1() == Cb().dim0())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemm : A : "<< Ca().shape() <<" while B : "<<Cb().shape();
if (!(second_dim(Ca()) == first_dim(Cb()))) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemm : A : "<< get_shape(Ca()) <<" while B : "<<get_shape(Cb());
char trans_a= get_trans(Ca(), false);
char trans_b= get_trans(Cb(), false);
int m = (trans_a == 'N' ? get_n_rows(Ca()) : get_n_cols(Ca()));
@ -113,12 +113,12 @@ namespace triqs { namespace arrays { namespace blas {
void gemm_generic (typename MT1::value_type alpha, MT1 const & A, MT2 const & B, typename MT1::value_type beta, MTOut & C) {
//std::cerr << "gemm: generic call "<< std::endl ;
// first resize if necessary and possible
resize_or_check_if_view(C,make_shape(A.dim0(),B.dim1()));
if (A.dim1() != B.dim0()) TRIQS_RUNTIME_ERROR << "gemm generic : dimension mismatch "<< A.shape() << B.shape();
resize_or_check_if_view(C,make_shape(first_dim(A),second_dim(B)));
if (second_dim(A) != first_dim(B)) TRIQS_RUNTIME_ERROR << "gemm generic : dimension mismatch "<< get_shape(A) << get_shape(B);
C() = 0;
for (int i=0; i<A.dim0(); ++i)
for (int k=0; k<A.dim1(); ++k)
for (int j=0; j<B.dim1(); ++j)
for (int i=0; i<first_dim(A); ++i)
for (int k=0; k<second_dim(A); ++k)
for (int j=0; j<second_dim(B); ++j)
C(i,j) += A(i,k)*B(k,j);
}

View File

@ -65,10 +65,10 @@ namespace triqs { namespace arrays { namespace blas {
typename std::enable_if< use_blas_gemv<MT,VT,VTOut>::value >::type
gemv (typename MT::value_type alpha, MT const & A, VT const & X, typename MT::value_type beta, VTOut & Y) {
//std::cerr << "gemm: blas call "<< std::endl ;
resize_or_check_if_view(Y,make_shape(A.dim0()));// first resize if necessary and possible
resize_or_check_if_view(Y,make_shape(first_dim(A)));// first resize if necessary and possible
const_qcache<MT> Ca(A);
const_qcache<VT> Cx(X); // mettre la condition a la main
if (!(Ca().dim1() == Cx().size())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemv : A : "<< Ca().shape() <<" while X : "<<Cx().shape();
if (!(second_dim(Ca()) == Cx().size())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in gemv : A : "<< get_shape(Ca()) <<" while X : "<<get_shape(Cx());
char trans_a= get_trans(Ca(), false);
int m1 = get_n_rows(Ca()), m2 = get_n_cols(Ca());
int lda = get_ld(Ca());
@ -81,11 +81,11 @@ namespace triqs { namespace arrays { namespace blas {
void gemv_generic (typename MT::value_type alpha, MT const & A, VT const & X, typename MT::value_type beta, VTOut & C) {
//std::cerr << "gemm: generic call "<< std::endl ;
// first resize if necessary and possible
resize_or_check_if_view(C,make_shape(A.dim0()));
if (A.dim1() != X.size()) TRIQS_RUNTIME_ERROR << "gemm generic : dimension mismatch "<< A.dim1() << " vs " << X.size();
resize_or_check_if_view(C,make_shape(first_dim(A)));
if (second_dim(A) != X.size()) TRIQS_RUNTIME_ERROR << "gemm generic : dimension mismatch "<< second_dim(A) << " vs " << X.size();
C() = 0;
for (int i=0; i<A.dim0(); ++i)
for (int k=0; k<A.dim1(); ++k)
for (int i=0; i<first_dim(A); ++i)
for (int k=0; k<second_dim(A); ++k)
C(i) += A(i,k)*X(k);
}

View File

@ -50,7 +50,7 @@ namespace triqs { namespace arrays { namespace blas {
typename std::enable_if< is_blas_lapack_type<typename VTX::value_type>::value && have_same_value_type< VTX, VTY, MT>::value >::type
ger (typename VTX::value_type alpha, VTX const & X, VTY const & Y, MT & A) {
static_assert( is_amv_value_or_view_class<MT>::value, "ger : A must be a matrix or a matrix_view");
if (( A.dim0() != Y.size()) || (A.dim1() != X.size())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in ger : A : "<< A().shape() <<" while X : "<<X().shape()<<" and Y : "<<Y().shape();
if (( first_dim(A) != Y.size()) || (second_dim(A) != X.size())) TRIQS_RUNTIME_ERROR << "Dimension mismatch in ger : A : "<< get_shape(A()) <<" while X : "<<get_shape(X())<<" and Y : "<<get_shape(Y());
const_qcache<VTX> Cx(X); // mettre la condition a la main
const_qcache<VTY> Cy(Y); // mettre la condition a la main
reflexive_qcache<MT> Ca(A);

View File

@ -52,7 +52,7 @@ namespace triqs { namespace arrays { namespace lapack {
getrf (MT & A, arrays::vector<int> & ipiv, bool assert_fortran_order = false ) {
if (assert_fortran_order && A.memory_layout_is_c()) TRIQS_RUNTIME_ERROR<< "matrix passed to getrf is not in Fortran order";
reflexive_qcache<MT> Ca(A);
auto dm = std::min(Ca().dim0(), Ca().dim1());
auto dm = std::min(first_dim(Ca()), second_dim(Ca()));
if (ipiv.size() < dm) ipiv.resize(dm);
int info;
f77::getrf ( get_n_rows(Ca()), get_n_cols(Ca()), Ca().data_start(), get_ld(Ca()), ipiv.data_start(), info);

View File

@ -53,7 +53,7 @@ namespace triqs { namespace arrays { namespace lapack {
getri (MT & A, arrays::vector<int> & ipiv) {
//getri (MT & A, arrays::vector<int> & ipiv, arrays::vector<typename MT::value_type> & work ) {
reflexive_qcache<MT> Ca(A);
auto dm = std::min(Ca().dim0(), Ca().dim1());
auto dm = std::min(first_dim(Ca()), second_dim(Ca()));
if (ipiv.size() < dm) TRIQS_RUNTIME_ERROR << "getri : error in ipiv size : found "<<ipiv.size()<< " while it should be at least" << dm;
int info;
typename MT::value_type work1[2];

View File

@ -49,11 +49,11 @@ namespace triqs { namespace arrays {
// returns the # of rows of the matrix *seen* as fortran matrix
template <typename MatrixType> int get_n_rows (MatrixType const & A) {
return (A.memory_layout_is_fortran() ? A.dim0() : A.dim1());
return (A.memory_layout_is_fortran() ? first_dim(A) : second_dim(A));
}
// returns the # of cols of the matrix *seen* as fortran matrix
template <typename MatrixType> int get_n_cols (MatrixType const & A) {
return (A.memory_layout_is_fortran() ? A.dim1() : A.dim0());
return (A.memory_layout_is_fortran() ? second_dim(A) : first_dim(A));
}
template <typename MatrixType> int get_ld (MatrixType const & A) {

View File

@ -39,9 +39,6 @@ namespace triqs { namespace arrays {
template<typename LL, typename RR> matrix_expr(LL && l_, RR && r_) : l(std::forward<LL>(l_)), r(std::forward<RR>(r_)) {}
domain_type domain() const { return combine_domain()(l,r); }
mini_vector<size_t,2> shape() const { return this->domain().lengths();}
size_t dim0() const { return this->domain().lengths()[0];}
size_t dim1() const { return this->domain().lengths()[1];}
//template<typename KeyType> value_type operator[](KeyType && key) const { return operation<Tag>()(l[std::forward<KeyType>(key)] , r[std::forward<KeyType>(key)]);}
template<typename ... Args> value_type operator()(Args && ... args) const { return operation<Tag>()(l(std::forward<Args>(args)...) , r(std::forward<Args>(args)...));}
@ -57,9 +54,6 @@ namespace triqs { namespace arrays {
template<typename LL> matrix_unary_m_expr(LL && l_) : l(std::forward<LL>(l_)) {}
domain_type domain() const { return l.domain(); }
mini_vector<size_t,2> shape() const { return this->domain().lengths();}
size_t dim0() const { return this->domain().lengths()[0];}
size_t dim1() const { return this->domain().lengths()[1];}
//template<typename KeyType> value_type operator[](KeyType&& key) const {return -l[key];}
template<typename ... Args> value_type operator()(Args && ... args) const { return -l(std::forward<Args>(args)...);}

View File

@ -36,7 +36,6 @@ namespace triqs { namespace arrays {
domain_type domain() const { return combine_domain()(l,r); }
mini_vector<size_t,1> shape() const { return this->domain().lengths();}
//size_t dim0() const { return this->domain().lengths()[0];}
size_t size() const { return this->domain().lengths()[0];}
//template<typename KeyType> value_type operator[](KeyType && key) const { return operation<Tag>()(l[std::forward<KeyType>(key)] , r[std::forward<KeyType>(key)]);}
@ -54,7 +53,6 @@ namespace triqs { namespace arrays {
domain_type domain() const { return l.domain(); }
mini_vector<size_t,1> shape() const { return this->domain().lengths();}
//size_t dim0() const { return this->domain().lengths()[0];}
size_t size() const { return this->domain().lengths()[0];}
//template<typename KeyType> value_type operator[](KeyType&& key) const {return -l[key];}

View File

@ -67,8 +67,6 @@ namespace triqs { namespace arrays {
A const & a; F f;
m_result(F const & f_, A const & a_):a(a_),f(f_) {}
domain_type domain() const { return a.domain(); }
size_t dim0() const { return a.dim0();}
size_t dim1() const { return a.dim1();}
template<typename ... Args> value_type operator() (Args const & ... args) const { return f(a(args...)); }
//value_type operator[] ( typename domain_type::index_value_type const & key) const { return f(a[key]); }
friend std::ostream & operator<<(std::ostream & out, m_result const & x){ return out<<"lazy matrix resulting of a mapping";}

View File

@ -38,6 +38,16 @@
namespace triqs { namespace arrays {
template<typename A> auto get_shape (A const & x) DECL_AND_RETURN(x.domain().lengths());
template<typename A> size_t first_dim (A const & x) { return x.domain().lengths()[0];}
template<typename A> size_t second_dim (A const & x) { return x.domain().lengths()[1];}
template<typename A> size_t third_dim (A const & x) { return x.domain().lengths()[2];}
template<typename A> size_t fourth_dim (A const & x) { return x.domain().lengths()[3];}
template<typename A> size_t fifth_dim (A const & x) { return x.domain().lengths()[4];}
template<typename A> size_t sixth_dim (A const & x) { return x.domain().lengths()[5];}
template<typename A> size_t seventh_dim (A const & x) { return x.domain().lengths()[6];}
template <bool Const, typename IndexMapIterator, typename StorageType > class iterator_adapter;
template <class V, int R, ull_t OptionFlags, ull_t TraversalOrder, class ViewTag, bool Borrowed > struct ISPViewType;

View File

@ -55,15 +55,13 @@ namespace triqs { namespace arrays {
a_x_ty_lazy( ScalarType a_, VectorType1 const & x_, VectorType2 const & y_):a(a_),x(x_),y(y_){}
domain_type domain() const { return domain_type(mini_vector<size_t,2>(x.size(), y.size()));}
size_t dim0() const { return x.size();}
size_t dim1() const { return y.size();}
template<typename K0, typename K1> value_type operator() (K0 const & k0, K1 const & k1) const { return a * x(k0) * y(k1); }
// Optimized implementation of =
template<typename LHS>
friend void triqs_arrays_assign_delegation (LHS & lhs, a_x_ty_lazy const & rhs) {
resize_or_check_if_view(lhs,make_shape(rhs.dim0(),rhs.dim1()));
resize_or_check_if_view(lhs,make_shape( first_dim(rhs),second_dim(rhs) ));
lhs()=0;
blas::ger(rhs.a,rhs.x, rhs.y, lhs);
}

View File

@ -68,9 +68,9 @@ namespace triqs { namespace arrays {
short step;
public:
det_and_inverse_worker (ViewType const & a): V(a), dim(a.dim0()), ipiv(dim), step(0) {
if (a.dim0()!=a.dim1())
TRIQS_RUNTIME_ERROR<<"Inverse/Det error : non-square matrix. Dimensions are : ("<<a.dim0()<<","<<a.dim1()<<")"<<"\n ";
det_and_inverse_worker (ViewType const & a): V(a), dim(first_dim(a)), ipiv(dim), step(0) {
if (first_dim(a)!=second_dim(a))
TRIQS_RUNTIME_ERROR<<"Inverse/Det error : non-square matrix. Dimensions are : ("<<first_dim(a)<<","<<second_dim(a)<<")"<<"\n ";
if (!(has_contiguous_data(a))) TRIQS_RUNTIME_ERROR<<"det_and_inverse_worker only takes a contiguous view";
}
VT det() { V_type W = fortran_view(V); _step1(W); _compute_det(W); return _det;}
@ -117,11 +117,10 @@ namespace triqs { namespace arrays {
typedef typename const_view_type_if_exists_else_type<A>::type A_type;
const A_type a;
inverse_lazy_impl(A const & a_):a (a_) {
if (a.dim0() != a.dim1()) TRIQS_RUNTIME_ERROR<< "Inverse : matrix is not square but of size "<< a.dim0()<<" x "<< a.dim1();
if (first_dim(a) != second_dim(a)) TRIQS_RUNTIME_ERROR<< "Inverse : matrix is not square but of size "<< first_dim(a)<<" x "<< second_dim(a);
}
//typename A::shape_type shape() const { return a.shape();}
domain_type domain() const { return a.domain(); }
size_t dim0() const { return a.dim0();}
size_t dim1() const { return a.dim1();}
template<typename K0, typename K1> value_type operator() (K0 const & k0, K1 const & k1) const { activate(); return _id->M(k0,k1); }
friend std::ostream & operator<<(std::ostream & out,inverse_lazy_impl const&x){return out<<"inverse("<<x.a<<")";}
protected:

View File

@ -59,7 +59,7 @@ namespace triqs { namespace arrays { namespace linalg {
if (mat.is_empty()) TRIQS_RUNTIME_ERROR<<"eigenelements_worker : the matrix is empty : matrix = "<<mat<<" ";
if (!mat.is_square()) TRIQS_RUNTIME_ERROR<<"eigenelements_worker : the matrix "<<mat<<" is not square ";
if (!mat.indexmap().is_contiguous()) TRIQS_RUNTIME_ERROR<<"eigenelements_worker : the matrix "<<mat<<" is not contiguous in memory";
dim = mat.dim0();
dim = first_dim(mat);
ev.resize(dim);
lwork = 64*dim;
work.resize(lwork);

View File

@ -55,7 +55,7 @@ namespace triqs { namespace arrays {
struct internal_data {
vector_type R;
internal_data(mat_vec_mul_lazy const & P): R(P.M.dim0()) { blas::gemv(1,P.M,P.V,0,R); }
internal_data(mat_vec_mul_lazy const & P): R( first_dim(P.M) ) { blas::gemv(1,P.M,P.V,0,R); }
};
friend struct internal_data;
mutable std::shared_ptr<internal_data> _id;
@ -63,12 +63,12 @@ namespace triqs { namespace arrays {
public:
mat_vec_mul_lazy( MT const & M_, VT const & V_):M(M_),V(V_){
if (M.dim1() != V.size()) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in Matrix*Vector "<< M<<" "<< V;
if (second_dim(M) != V.size()) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in Matrix*Vector "<< M<<" "<< V;
}
domain_type domain() const { return mini_vector<size_t,1>(size());}
//domain_type domain() const { return indexmaps::cuboid::domain_t<1>(mini_vector<size_t,1>(size()));}
size_t size() const { return M.dim0();}
size_t size() const { return first_dim(M);}
template<typename KeyType> value_type operator() (KeyType const & key) const { activate(); return _id->R (key); }

View File

@ -49,7 +49,7 @@ namespace triqs { namespace arrays {
struct internal_data { // implementing the pattern LazyPreCompute
matrix_type R;
internal_data(matmul_lazy const & P): R( P.a.dim0(), P.b.dim1()) { blas::gemm(1.0,P.a, P.b, 0.0, R); }
internal_data(matmul_lazy const & P): R( first_dim(P.a), second_dim(P.b)) { blas::gemm(1.0,P.a, P.b, 0.0, R); }
};
friend struct internal_data;
mutable std::shared_ptr<internal_data> _id;
@ -57,13 +57,11 @@ namespace triqs { namespace arrays {
public:
matmul_lazy( A const & a_, B const & b_):a(a_),b(b_){
if (a.dim1() != b.dim0()) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in A*B "<< a<<" "<< b;
if (second_dim(a) != first_dim(b)) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in A*B "<< a<<" "<< b;
}
domain_type domain() const { return mini_vector<size_t,2>(a.dim0(), b.dim1());}
//domain_type domain() const { return indexmaps::cuboid::domain_t<2>(mini_vector<size_t,2>(a.dim0(), b.dim1()));}
size_t dim0() const { return a.dim0();}
size_t dim1() const { return b.dim1();}
domain_type domain() const { return mini_vector<size_t,2>(first_dim(a), second_dim(b));}
//domain_type domain() const { return indexmaps::cuboid::domain_t<2>(mini_vector<size_t,2>(first_dim(a), second_dim(b)));}
template<typename K0, typename K1> value_type operator() (K0 const & k0, K1 const & k1) const { activate(); return _id->R(k0,k1); }
@ -72,7 +70,7 @@ namespace triqs { namespace arrays {
template<typename LHS>
friend void triqs_arrays_assign_delegation (LHS & lhs, matmul_lazy const & rhs) {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix");
resize_or_check_if_view(lhs,make_shape(rhs.dim0(),rhs.dim1()));
resize_or_check_if_view(lhs,make_shape(first_dim(rhs),second_dim(rhs)));
blas::gemm(1.0,rhs.a, rhs.b, 0.0, lhs);
}
@ -84,10 +82,10 @@ namespace triqs { namespace arrays {
private:
template<typename LHS> void assign_comp_impl (LHS & lhs, double S) const {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix");
if (lhs.dim0() != dim0())
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< lhs.dim0()<<" vs "<< dim0();
if (lhs.dim1() != dim1())
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< lhs.dim1()<<" vs "<< dim1();
if (first_dim(lhs) != first_dim(*this))
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< first_dim(lhs)<<" vs "<< first_dim(*this);
if (second_dim(lhs) != second_dim(*this))
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< second_dim(lhs)<<" vs "<< second_dim(*this);
blas::gemm(S,a, b, 1.0, lhs);
}

View File

@ -33,9 +33,7 @@ namespace triqs { namespace arrays {
// ---------------------- matrix --------------------------------
//
#define _IMPL_MATRIX_COMMON \
size_t dim0() const { return this->shape()[0];}\
size_t dim1() const { return this->shape()[1];}\
bool is_square() const { return dim0() == dim1();}\
bool is_square() const { return this->shape()[0] == this->shape()[1];}\
\
view_type transpose() const {\
typename indexmap_type::lengths_type l; l[0] = this->indexmap().lengths()[1];l[1] = this->indexmap().lengths()[0];\

View File

@ -45,8 +45,9 @@ namespace triqs { namespace arrays {
matrix_view<T> view(size_t i) const { return a(i,range(),range());}
size_t size() const { return a.shape(0);}
size_t dim0() const { return a.shape(1);}
size_t dim1() const { return a.shape(2);}
// BE CAREFUL to the shift : it is 1, 2, not 0,1, because of the stack !
friend size_t first_dim (matrix_stack_view const & m) { return second_dim(m.a);}
friend size_t second_dim (matrix_stack_view const & m) { return third_dim(m.a);}
matrix_stack_view & operator +=(matrix_stack_view const & arg) { a += arg.a; return *this; }
matrix_stack_view & operator -=(matrix_stack_view const & arg) { a -= arg.a; return *this; }
@ -65,13 +66,13 @@ namespace triqs { namespace arrays {
void invert() {for (size_t i=0; i<size(); ++i) { auto v = view(i); v = inverse(v);} }
friend matrix_stack_view matmul_L_R ( matrix_view<T> const & L, matrix_stack_view const & M, matrix_view<T> const & R) {
matrix_stack_view res (typename array_view_t::regular_type (M.size(), L.dim0(), R.dim1()));
matrix_stack_view res (typename array_view_t::regular_type (M.size(), first_dim(L), second_dim(R)));
for (size_t i=0; i<M.size(); ++i) { res.view(i) = L * M.view(i) * R; }
return res;
}
void onsite_matmul_L_R ( matrix_view<T> const & L, matrix_stack_view const & M, matrix_view<T> const & R) {
if ((dim0() != L.dim0()) || (dim1() != R.dim1()) || (L.dim1() != R.dim0()))
if ((first_dim(*this) != first_dim(L)) || (second_dim(*this) != second_dim(R)) || (second_dim(L) != first_dim(R)))
TRIQS_RUNTIME_ERROR << "dimensions do not match!";
for (size_t i=0; i<M.size(); ++i) { view(i) = L * M.view(i) * R; }
}

View File

@ -32,9 +32,7 @@ namespace triqs { namespace arrays {
template<typename ArrayType,int Pos > class const_matrix_view_proxy;
// to do : separate the array and the matrix case.
// generalize with preprocessor (draft below)
// write concept mutable down and clean it (dim0, dim1, shape(i), ...)
#ifdef DO_NOT_DEFINE_ME
// human version of the class, the preprocessor generalisation is next..
template<typename ArrayType > class const_matrix_view_proxy<ArrayType,2> : TRIQS_MODEL_CONCEPT(ImmutableMatrix) {
@ -47,9 +45,6 @@ namespace triqs { namespace arrays {
typedef typename indexmap_type::domain_type domain_type;
indexmap_type indexmap() const { return slicer_t::invoke(A->indexmap() , range() , range(),n, ellipsis()); }
domain_type domain() const { return indexmap().domain();}
size_t shape(int i) const { return A->shape(i);}
size_t dim0() const { return A->shape(0);}
size_t dim1() const { return A->shape(1);}
typename ArrayType::storage_type const & storage() const { return A->storage();}
TRIQS_DELETE_COMPOUND_OPERATORS(const_matrix_view_proxy);
template< typename A0 , typename A1 , typename ... Args> value_type const & operator() ( A0 &&a0 , A1 &&a1 , Args && ... args) const
@ -66,9 +61,6 @@ namespace triqs { namespace arrays {
typedef typename indexmap_type::domain_type domain_type;
indexmap_type indexmap() const { return slicer_t::invoke(A->indexmap() , range() , range(),n, ellipsis()); }
domain_type domain() const { return indexmap().domain();}
size_t shape(int i) const { return A->shape(i);}
size_t dim0() const { return A->shape(0);}
size_t dim1() const { return A->shape(1);}
typename ArrayType::storage_type const & storage() const { return A->storage();}
template<typename RHS> matrix_view_proxy & operator=(const RHS & X) {triqs_arrays_assign_delegation(*this,X); return *this; }
TRIQS_DEFINE_COMPOUND_OPERATORS(matrix_view_proxy);
@ -92,9 +84,9 @@ namespace triqs { namespace arrays {
typedef typename indexmap_type::domain_type domain_type;\
indexmap_type indexmap() const { return slicer_t::invoke(A->indexmap() BOOST_PP_ENUM_TRAILING(POS, TEXT, range()),n, ellipsis()); }\
domain_type domain() const { return indexmap().domain();}\
size_t shape(int i) const { return A->shape(i);}\
size_t dim0() const { return A->shape((POS+1)%3);}\
size_t dim1() const { return A->shape((POS+2)%3);}\
friend size_t get_shape (const_matrix_view_proxy const & x) { return get_shape(*x.A);}\
friend size_t first_dim (const_matrix_view_proxy const & x) { return get_shape(*x.A)[(POS+1)%3];}\
friend size_t second_dim(const_matrix_view_proxy const & x) { return get_shape(*x.A)[(POS+2)%3];}\
typename ArrayType::storage_type const & storage() const { return A->storage();}\
value_type const * restrict data_start() const { return &storage()[indexmap().start_shift()];}\
value_type * restrict data_start() { return &storage()[indexmap().start_shift()];}\
@ -117,9 +109,9 @@ namespace triqs { namespace arrays {
typedef typename indexmap_type::domain_type domain_type;\
indexmap_type indexmap() const { return slicer_t::invoke(A->indexmap() BOOST_PP_ENUM_TRAILING(POS, TEXT, range()),n, ellipsis()); }\
domain_type domain() const { return indexmap().domain();}\
size_t shape(int i) const { return A->shape(i);}\
size_t dim0() const { return A->shape((POS+1)%3);}\
size_t dim1() const { return A->shape((POS+2)%3);}\
friend size_t get_shape (matrix_view_proxy const & x) { return get_shape(*x.A);}\
friend size_t first_dim (matrix_view_proxy const & x) { return get_shape(*x.A)[(POS+1)%3];}\
friend size_t second_dim(matrix_view_proxy const & x) { return get_shape(*x.A)[(POS+2)%3];}\
typename ArrayType::storage_type const & storage() const { return A->storage();}\
value_type const * restrict data_start() const { return &storage()[indexmap().start_shift()];}\
value_type * restrict data_start() { return &storage()[indexmap().start_shift()];}\

View File

@ -106,7 +106,7 @@ namespace triqs { namespace det_manip {
auto gr = fg.open_group(subgroup_name);
h5_read(gr,"N",g.N);
h5_read(gr,"mat_inv",g.mat_inv);
g.Nmax = g.mat_inv.dim0(); // restore Nmax
g.Nmax = first_dim(g.mat_inv); // restore Nmax
g.last_try = 0;
h5_read(gr,"det",g.det);
h5_read(gr,"sign",g.sign);

View File

@ -55,8 +55,8 @@ namespace triqs { namespace lattice_tools {
std::vector<long> V(std::forward<VectorIntType>(v));
if (v.size() != bl_.dim()) TRIQS_RUNTIME_ERROR<<"tight_binding : displacement of incorrect size : got "<< v.size() << "instead of "<< bl_.dim();
matrix<dcomplex> M(std::forward<MatrixDComplexType>(m));
if (M.shape(0) != n_bands()) TRIQS_RUNTIME_ERROR<<"tight_binding : the first dim matrix is of size "<< M.shape(0) <<" instead of "<< n_bands();
if (M.shape(1) != n_bands()) TRIQS_RUNTIME_ERROR<<"tight_binding : the first dim matrix is of size "<< M.shape(1) <<" instead of "<< n_bands();
if (first_dim(M) != n_bands()) TRIQS_RUNTIME_ERROR<<"tight_binding : the first dim matrix is of size "<< first_dim(M) <<" instead of "<< n_bands();
if (second_dim(M) != n_bands()) TRIQS_RUNTIME_ERROR<<"tight_binding : the first dim matrix is of size "<< second_dim(M) <<" instead of "<< n_bands();
displ_value_stack.push_back(std::make_pair(std::move(V), std::move(M)));
}