3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-25 13:53:40 +01:00

Fix matrix * alias issue and adapt det_manip

- The previous version of the * operator for matrix was too clever.
It was giving a lazy object and then rewriting C = A *B into gemm (a,A,B,0,C).
The pb was in case of aliasing : when e.g. C = A, or is a part of A.
gemm is not correct that case, and as a result generic code like
a = a *b
may not be correct in matrix case, which is unacceptable.

- So we revert to a simple * operator for matrix
that does immediate computation.
Same thing for matrix* vector

- we also suppress a_x_ty class.

-> for M = a * b,
when M is a matrix, there is no overhead due to move assignment
-> however, when M is a view, there is an additionnal copy.

-Correctness comes first, hence the fix.
However, if one wants more speed and one can guarantee that
there is no aliasing possible, then one has to write a direct gemm call.

-> det_manip class was adapted, since in that case, we can show there
no alias, and we want the speed gain, so the * ops where replaced
by direct blas call (using the array blas interface).

-> also gemm, gemv, ger were overloaded in the case the return
matrix/vector (i.e. last parameter of the function) is not an lvalue,
but a temporary view created on the fly.
This commit is contained in:
Olivier Parcollet 2013-09-10 21:41:17 +02:00
parent 3c2a3c51dc
commit b534936589
23 changed files with 111 additions and 455 deletions

View File

@ -18,11 +18,7 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <triqs/arrays/matrix.hpp>
#include <triqs/arrays/array.hpp>
#include <triqs/arrays/linalg/matmul.hpp>
#include <triqs/arrays/linalg/mat_vec_mul.hpp>
#include <triqs/arrays/linalg/det_and_inverse.hpp>
#include <triqs/arrays.hpp>
#include <triqs/arrays/linalg/det_and_inverse.hpp>
using namespace triqs::arrays;

View File

@ -18,11 +18,7 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <triqs/arrays/matrix.hpp>
#include <triqs/arrays/array.hpp>
#include <triqs/arrays/linalg/matmul.hpp>
#include <triqs/arrays/linalg/mat_vec_mul.hpp>
#include <triqs/arrays/linalg/det_and_inverse.hpp>
#include <triqs/arrays.hpp>
#include <triqs/arrays/linalg/det_and_inverse.hpp>
using namespace triqs::arrays;

View File

@ -18,11 +18,7 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <triqs/arrays/vector.hpp>
#include <triqs/arrays/matrix.hpp>
#include <triqs/arrays/linalg/matmul.hpp>
#include <triqs/arrays/linalg/mat_vec_mul.hpp>
#include <triqs/arrays/blas_lapack/gemv.hpp>
#include <triqs/arrays.hpp>
using namespace std;
using namespace triqs;

View File

@ -1,73 +0,0 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2011 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include "./common.hpp"
#include <iostream>
#include "./src/array.hpp"
#include "./src/vector.hpp"
#include "./src/matrix.hpp"
#include "./src/linalg/matmul.hpp"
#include "./src/linalg/mat_vec_mul.hpp"
#include "./src/linalg/det_and_inverse.hpp"
#include "./src/linalg/det_and_inverse.hpp"
#include "./src/linalg/a_x_ty.hpp"
#include "./src/blas_lapack/dot.hpp"
using namespace triqs::arrays;
int main(int argc, char **argv) {
triqs::arrays::matrix<double> A(5,5, FORTRAN_LAYOUT);
//triqs::arrays::matrix<double > A(5,5);
typedef triqs::arrays::vector<double> vector_type;
vector_type MC(5), MB(5);
A()= 0;
for (int i =0; i<5; ++i)
{ MC(i) = i;MB(i)=10*(i+1); }
range R(1,3);
std::cout<<" MC(R) = "<<MC(R)<< std::endl<<std::endl;
std::cout<<" MB(R) = "<<MB(R)<< std::endl<<std::endl;
A(R,R) += a_x_ty(1.0,MC(R),MB(R));
std::cout<<" A(R,R) = "<<A(R,R)<< std::endl<<std::endl;
A(R,R) += a_x_ty(1.0,MB(R),MC(R));
std::cout<<" A(R,R) = "<<A(R,R)<< std::endl<<std::endl;
A(R,R) = a_x_ty(1.0,MB(R),MC(R));
std::cout<<" A(R,R) = "<<A(R,R)<< std::endl<<std::endl;
std::cout<<" full A"<< A<<std::endl<<std::endl;
std::cout<< " MB, MC, dot "<< MB << MC << dot(MB,MC)<<std::endl;
std::cout<< " MC, MC, dot "<< MB << MC << dot(MC,MC)<<std::endl;
}

View File

@ -23,9 +23,6 @@
#include "./src/array.hpp"
#include "./src/vector.hpp"
#include "./src/matrix.hpp"
#include "./src/linalg/matmul.hpp"
#include "./src/linalg/mat_vec_mul.hpp"
#include "./src/linalg/det_and_inverse.hpp"
#include "./src/linalg/det_and_inverse.hpp"
#include "./src/blas_lapack/gemm.hpp"

View File

@ -21,7 +21,9 @@
******************************************************************************/
#ifndef COMMON_TEST_ARRAY_H
#define COMMON_TEST_ARRAY_H
#include<triqs/utility/first_include.hpp>
#include <iostream>
#include <triqs/arrays.hpp>
#include <triqs/arrays/asserts.hpp>
#include<sstream>
#define TEST(X) std::cout << BOOST_PP_STRINGIZE((X)) << " ---> "<< (X) <<std::endl;

View File

@ -24,14 +24,10 @@
#include "./src/vector.hpp"
#include "./src/matrix.hpp"
#include "./src/linalg/det_and_inverse.hpp"
#include "./src/linalg/matmul.hpp"
#include <iostream>
using std::cout; using std::endl;
using namespace triqs::arrays;
//using linalg::inverse;
//using linalg::inverse_and_compute_det;
//using linalg::determinant;
template<typename Expr >
matrix_view <typename Expr::value_type>
@ -39,8 +35,6 @@ eval_as_matrix( Expr const & e) { return matrix<typename Expr::value_type>(e);}
int main(int argc, char **argv) {
try {
triqs::arrays::matrix<double> W(3,3,FORTRAN_LAYOUT),Wi(3,3,FORTRAN_LAYOUT),Wkeep(3,3,FORTRAN_LAYOUT),A(FORTRAN_LAYOUT);

View File

@ -1,13 +1,6 @@
#include "./common.hpp"
#include "./src/array.hpp"
#include "./src/matrix.hpp"
#include "./src/linalg/matmul.hpp"
#include <iostream>
#include <triqs/arrays/asserts.hpp>
using std::cout; using std::endl;
using namespace triqs::arrays;
using namespace triqs::arrays;
// to be extended to more complex case
// calling lapack on view to test cache securities....

View File

@ -11,9 +11,7 @@ A(R,R) =
MC(R) = [1,1]
MB = mat_vec_mul(
[[4,6]
[5,7]],[1,1]) = [0,10,12,0,0]
MB = [10,12] = [0,10,12,0,0]
testing infix

View File

@ -19,9 +19,7 @@
*
******************************************************************************/
#include "./common.hpp"
#include "./src/matrix.hpp"
#include "./src/asserts.hpp"
#include "./src/linalg/matmul.hpp"
#include <triqs/arrays.hpp>
#include <iostream>
using namespace triqs::arrays;
@ -36,10 +34,7 @@ template<typename T, typename O1, typename O2, typename O3> void test(O1 o1, O2
for (int j=0; j<4; ++j)
{ M2(i,j) = 1 + i -j ; }
// The central instruction : note that matmul returns a lazy object
// that has ImmutableArray interface, and defines a specialized version assignment
// As a result this is equivalent to matmul_with_lapack(M1,M2,M3) : there is NO intermediate copy.
M3 = matmul(M1,M2);
M3 = M1 * M2; //matmul(M1,M2);
M4 = M3;
M4() = 0;
@ -57,15 +52,11 @@ template<typename T, typename O1, typename O2, typename O3> void test(O1 o1, O2
std::cerr<<"M1 = "<<M1<<std::endl;
std::cerr<<"M2 = "<<M2<<std::endl;
std::cerr<<"M3 = "<<M3<<std::endl;
std::cerr<<"M4 = "<< matrix<T>(matmul(M1,M2)) <<std::endl;
std::cerr<<"M5 = "<< matrix<T>(matmul(M1,M2)) <<std::endl;
std::cerr<<"M4 = "<< M1*M2 <<std::endl;
std::cerr<<"M5 = "<< M1*M2 <<std::endl;
for (int i =0; i<2; ++i)
for (int j=0; j<2; ++j)
M3(i,j) = matmul(M1,M2)(i,j); //[mini_vector<int,2>(i,j)];
}
std::cerr<<"M3 = "<<M3<<std::endl<<"----------------"<<std::endl;
}

View File

@ -19,9 +19,7 @@
*
******************************************************************************/
#include "./common.hpp"
#include "./src/matrix.hpp"
#include "./src/asserts.hpp"
#include "./src/linalg/matmul.hpp"
#include <triqs/arrays.hpp>
#include <iostream>
using namespace triqs::arrays;

View File

@ -19,11 +19,7 @@
*
******************************************************************************/
#include "./common.hpp"
#include "./src/array.hpp"
#include "./src/vector.hpp"
#include "./src/matrix.hpp"
#include "./src/linalg/matmul.hpp"
#include <triqs/arrays.hpp>
#include <iostream>
using std::cout; using std::endl;

View File

@ -19,13 +19,8 @@
*
******************************************************************************/
#include "./common.hpp"
#include "./src/matrix.hpp"
#include "./src/asserts.hpp"
#include "./src/linalg/matmul.hpp"
#include "./src/linalg/eigenelements.hpp"
#include "./src/blas_lapack/stev.hpp"
#include <triqs/arrays/asserts.hpp>
#include <iostream>
using namespace triqs::arrays;
using namespace triqs;

View File

@ -129,6 +129,12 @@ namespace triqs { namespace arrays { namespace blas {
gemm_generic(alpha,A,B,beta,C);
}
// to allow gemm (alpha, a, b, beta, M(..., ...)) i.e. a temporary view, which is not matched by previos templates
// which require an lvalue. This is the only version which takes an && as last argument
// indeed, in the routine, c is a *lvalue*, since it has a name, and hence we call *other* overload of the function
template<typename A, typename MT1, typename MT2, typename B, typename V, ull_t Opt, ull_t To, bool W>
void gemm (A alpha, MT1 const & a, MT2 const & b, B beta, matrix_view<V,Opt,To,W> && c) { gemm(alpha,a,b,beta,c);}
}}}// namespace

View File

@ -96,6 +96,11 @@ namespace triqs { namespace arrays { namespace blas {
gemv_generic(alpha,A,X,beta,Y);
}
// to allow gem (alpha, a, b, beta, M(..., ...)) i.e. a temporary view, which is not matched by previos templates
// which require an lvalue.
template<typename A, typename MT, typename VT, typename B, typename V, ull_t Opt, bool W>
void gemv (A alpha, MT const & a, VT const & b, B beta, vector_view<V,Opt,W> && c) { gemv(alpha,a,b,beta,c);}
}}}// namespace

View File

@ -72,8 +72,11 @@ namespace triqs { namespace arrays { namespace blas {
}
// to allow ger (alpha, x,y, M(..., ...)) i.e. a temporary view, which is not matched by previos templates
// which require an lvalue
template< typename A, typename VTX, typename VTY, typename V, ull_t Opt, ull_t To, bool W>
void ger (A alpha, VTX const & x, VTY const & y, matrix_view<V,Opt,To,W> && r) { ger(alpha,x,y,r);}
}}}// namespace
#endif

View File

@ -22,11 +22,48 @@
#define TRIQS_ARRAYS_EXPRESSION_MATRIX_ALGEBRA_H
#include "./vector_algebra.hpp"
#include "../matrix.hpp"
#include "../linalg/matmul.hpp"
#include "../linalg/mat_vec_mul.hpp"
#include "../linalg/det_and_inverse.hpp"
#include "../blas_lapack/gemv.hpp"
#include "../blas_lapack/gemm.hpp"
namespace triqs { namespace arrays {
// matrix * matrix
template<typename A, typename B, typename Enable = void> struct _matmul_rvalue {};
template<typename A, typename B> struct _matmul_rvalue<A,B, ENABLE_IFC(ImmutableMatrix<A>::value && ImmutableMatrix<B>::value)> {
typedef typename std::remove_const<typename A::value_type>::type V1;
typedef typename std::remove_const<typename B::value_type>::type V2;
typedef matrix<typename std::decay<decltype( V1{} * V2{})>::type> type;
};
template<typename A, typename B>
typename _matmul_rvalue<A,B>::type
operator * (A const & a, B const & b) {
if (second_dim(a) != first_dim(b)) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in A*B "<< a<<" "<< b;
auto R = typename _matmul_rvalue<A,B>::type( first_dim(a), second_dim(b));
blas::gemm(1.0,a, b, 0.0, R);
return R;
}
// matrix * vector
template<typename M, typename V, typename Enable = void> struct _mat_vec_mul_rvalue {};
template<typename M, typename V> struct _mat_vec_mul_rvalue<M,V, ENABLE_IFC(ImmutableMatrix<M>::value && ImmutableVector<V>::value)> {
typedef typename std::remove_const<typename M::value_type>::type V1;
typedef typename std::remove_const<typename V::value_type>::type V2;
typedef vector<typename std::decay<decltype(V1{} * V2{})>::type> type;
};
template<typename M, typename V>
typename _mat_vec_mul_rvalue<M,V>::type
operator * (M const & m, V const & v) {
if (second_dim(m) != v.size()) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in Matrix*Vector "<< m<<" "<< v;
auto R = typename _mat_vec_mul_rvalue<M,V>::type(first_dim(m));
blas::gemv(1.0,m,v,0.0,R);
return R;
}
// expression template
template<typename Tag, typename L, typename R, bool scalar_are_diagonal_matrices= false>
struct matrix_expr : TRIQS_CONCEPT_TAG_NAME(ImmutableMatrix) {
typedef typename keeper_type<L,scalar_are_diagonal_matrices>::type L_t;
@ -91,16 +128,6 @@ namespace triqs { namespace arrays {
template<typename A1> typename std::enable_if<ImmutableMatrix<A1>::value, matrix_unary_m_expr<A1>>::type
operator - (A1 const & a1) { return {a1};}
template<typename Expr > matrix <typename Expr::value_type>
make_matrix( Expr const & e) { return matrix<typename Expr::value_type>(e);}
template<typename M1, typename M2> // matrix * matrix
typename boost::enable_if< mpl::and_<ImmutableMatrix<M1>, ImmutableMatrix<M2> >, matmul_lazy<M1,M2> >::type
operator* (M1 const & a, M2 const & b) { return matmul_lazy<M1,M2>(a,b); }
template<typename M, typename V> // matrix * vector
typename boost::enable_if< mpl::and_<ImmutableMatrix<M>, ImmutableVector<V> >, mat_vec_mul_lazy<M,V> >::type
operator* (M const & m, V const & v) { return mat_vec_mul_lazy<M,V>(m,v); }
template<typename A, typename M> // anything / matrix ---> anything * inverse(matrix)
typename boost::lazy_enable_if< ImmutableMatrix<M>, type_of_mult<A, inverse_lazy <M> > >::type

View File

@ -1,86 +0,0 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2011 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_ARRAYS_EXPRESSION_A_X_TY_H
#define TRIQS_ARRAYS_EXPRESSION_A_X_TY_H
#include <boost/type_traits/is_same.hpp>
#include <boost/typeof/typeof.hpp>
#include "../matrix.hpp"
#include "../vector.hpp"
#include "../blas_lapack/ger.hpp"
namespace triqs { namespace arrays {
///
template<typename ScalarType, typename VectorType1, typename VectorType2> class a_x_ty_lazy;
///
template<typename ScalarType, typename VectorType1, typename VectorType2>
a_x_ty_lazy<ScalarType,VectorType1,VectorType2> a_x_ty (ScalarType a, VectorType1 const & x, VectorType2 const & y)
{ return a_x_ty_lazy<ScalarType,VectorType1,VectorType2>(a,x,y); }
//------------- IMPLEMENTATION -----------------------------------
template<typename ScalarType, typename VectorType1, typename VectorType2>
class a_x_ty_lazy : TRIQS_CONCEPT_TAG_NAME(ImmutableMatrix) {
typedef typename boost::remove_const<typename VectorType1::value_type>::type V1;
typedef typename boost::remove_const<typename VectorType2::value_type>::type V2;
static_assert((boost::is_same<V1,V2>::value),"Different values : not implemented");
public:
typedef BOOST_TYPEOF_TPL( V1() * V2() * ScalarType()) value_type;
typedef indexmaps::cuboid::domain_t<2> domain_type;
typedef typename const_view_type_if_exists_else_type<VectorType1>::type X_type;
typedef typename const_view_type_if_exists_else_type<VectorType2>::type Y_type;
const ScalarType a; const X_type x; const Y_type y;
public:
a_x_ty_lazy( ScalarType a_, VectorType1 const & x_, VectorType2 const & y_):a(a_),x(x_),y(y_){}
domain_type domain() const { return domain_type(mini_vector<size_t,2>(x.size(), y.size()));}
template<typename K0, typename K1> value_type operator() (K0 const & k0, K1 const & k1) const { return a * x(k0) * y(k1); }
// Optimized implementation of =
template<typename LHS>
friend void triqs_arrays_assign_delegation (LHS & lhs, a_x_ty_lazy const & rhs) {
resize_or_check_if_view(lhs,make_shape( first_dim(rhs),second_dim(rhs) ));
lhs()=0;
blas::ger(rhs.a,rhs.x, rhs.y, lhs);
}
//Optimized implementation of +=
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, a_x_ty_lazy const & rhs, mpl::char_<'A'>) {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix or a matrix_view"); // check that the target is indeed a matrix.
blas::ger(rhs.a, rhs.x, rhs.y, lhs);
}
//Optimized implementation of -=
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, a_x_ty_lazy const & rhs, mpl::char_<'S'>) {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix or a matrix_view"); // check that the target is indeed a matrix.
blas::ger(- rhs.a, rhs.x, rhs.y, lhs);
}
friend std::ostream & operator<<(std::ostream & out, a_x_ty_lazy const & a){ return out<<"a_x_ty("<<a.a<<","<<a.x<<","<<a.y<<")";}
};
}} // namespace triqs_arrays
#endif

View File

@ -1,4 +1,3 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
@ -19,7 +18,6 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_LINALG_CROSS_PRODUCT_H
#define TRIQS_LINALG_CROSS_PRODUCT_H
#include <triqs/utility/exceptions.hpp>
@ -28,14 +26,14 @@ namespace triqs { namespace arrays {
/** Cross product. Dim 3 only */
template<typename VectorType>
typename VectorType::view_type cross_product (VectorType const & A, VectorType const & B) {
vector<typename std::remove_const<typename VectorType::value_type >::type>
cross_product (VectorType const & A, VectorType const & B) {
if (A.shape()[0] !=3) TRIQS_RUNTIME_ERROR<<"arrays::linalg::cross_product : works only in d=3 while you gave a vector of size "<<A.shape()[0];
if (B.shape()[0] !=3) TRIQS_RUNTIME_ERROR<<"arrays::linalg::cross_product : works only in d=3 while you gave a vector of size "<<B.shape()[0];
vector<typename boost::remove_const<typename VectorType::value_type >::type > r(3);
vector<typename std::remove_const<typename VectorType::value_type >::type > r(3);
r(0) = A(1)* B(2) - B(1) * A(2);
r(1) = - A(0)* B(2) + B(0) * A(2);
r(2) = A(0)*B(1) - B(0) * A(1);
std::cout << "in cross product "<< A << B << r << std::endl;
return r;
}

View File

@ -1,97 +0,0 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2011 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_ARRAYS_EXPRESSION_MAT_VEC_MUL_H
#define TRIQS_ARRAYS_EXPRESSION_MAT_VEC_MUL_H
#include <boost/type_traits/is_same.hpp>
#include <boost/typeof/typeof.hpp>
#include "../matrix.hpp"
#include "../vector.hpp"
#include "../blas_lapack/gemv.hpp"
namespace triqs { namespace arrays {
///
template<typename MT, typename VT> class mat_vec_mul_lazy;
///
template<typename MT, typename VT> mat_vec_mul_lazy<MT,VT> mat_vec_mul (MT const & a, VT const & b) { return mat_vec_mul_lazy<MT,VT>(a,b); }
// ----------------- implementation -----------------------------------------
template<typename MT, typename VT>
class mat_vec_mul_lazy : TRIQS_CONCEPT_TAG_NAME(ImmutableVector) {
typedef typename MT::value_type V1;
typedef typename VT::value_type V2;
//static_assert((boost::is_same<V1,V2>::value),"Different values : not implemented");
public:
typedef BOOST_TYPEOF_TPL( V1() * V2()) value_type;
typedef typename VT::domain_type domain_type;
typedef typename const_view_type_if_exists_else_type<MT>::type M_type;
typedef typename const_view_type_if_exists_else_type<VT>::type V_type;
const M_type M; const V_type V;
private:
typedef vector<value_type> vector_type;
struct internal_data {
vector_type R;
internal_data(mat_vec_mul_lazy const & P): R( first_dim(P.M) ) { blas::gemv(1,P.M,P.V,0,R); }
};
friend struct internal_data;
mutable std::shared_ptr<internal_data> _id;
void activate() const { if (!_id) _id= std::make_shared<internal_data>(*this);}
public:
mat_vec_mul_lazy( MT const & M_, VT const & V_):M(M_),V(V_){
if (second_dim(M) != V.size()) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in Matrix*Vector "<< M<<" "<< V;
}
domain_type domain() const { return mini_vector<size_t,1>(size());}
//domain_type domain() const { return indexmaps::cuboid::domain_t<1>(mini_vector<size_t,1>(size()));}
size_t size() const { return first_dim(M);}
template<typename KeyType> value_type operator() (KeyType const & key) const { activate(); return _id->R (key); }
template<typename LHS> // Optimized implementation of =
friend void triqs_arrays_assign_delegation (LHS & lhs, mat_vec_mul_lazy const & rhs) {
static_assert((is_vector_or_view<LHS>::value), "LHS is not a vector or a vector_view");
resize_or_check_if_view(lhs,make_shape(rhs.size()));
blas::gemv(1,rhs.M,rhs.V,0,lhs);
}
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, mat_vec_mul_lazy const & rhs, mpl::char_<'A'>) { rhs.assign_comp_impl(lhs,1.0);}
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, mat_vec_mul_lazy const & rhs, mpl::char_<'S'>) { rhs.assign_comp_impl(lhs,-1.0);}
private:
template<typename LHS> void assign_comp_impl (LHS & lhs, double S) const {
static_assert((is_vector_or_view<LHS>::value), "LHS is not a vector or a vector_view");
if (lhs.size() != size()) TRIQS_RUNTIME_ERROR<< "mat_vec_mul : -=/-= operator : size mismatch in M*V "<< lhs.size()<<" vs "<< size();
blas::gemv(1,M,V,S,lhs);
}
friend std::ostream & operator<<(std::ostream & out, mat_vec_mul_lazy const & x){ return out<<"mat_vec_mul("<<x.M<<","<<x.V<<")";}
};
}}//namespace triqs::arrays
#endif

View File

@ -1,96 +0,0 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2011-2012 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_ARRAYS_EXPRESSION_MATMUL_H
#define TRIQS_ARRAYS_EXPRESSION_MATMUL_H
#include <boost/type_traits/is_same.hpp>
#include <boost/typeof/typeof.hpp>
#include "../blas_lapack/gemm.hpp"
namespace triqs { namespace arrays {
///
template<typename A, typename B> class matmul_lazy;
///
template<typename A, typename B> matmul_lazy<A,B> matmul (A const & a, B const & b) { return matmul_lazy<A,B>(a,b); }
// ----------------- implementation -----------------------------------------
template<typename A, typename B> class matmul_lazy : TRIQS_CONCEPT_TAG_NAME(ImmutableMatrix) {
typedef typename boost::remove_const<typename A::value_type>::type V1;
typedef typename boost::remove_const<typename B::value_type>::type V2;
//static_assert((boost::is_same<V1,V2>::value),"Different values : not implemented");
public:
typedef BOOST_TYPEOF_TPL( V1() * V2()) value_type; // what is the result of multiplying a V1 by a V2 ?
typedef typename A::domain_type domain_type;
typedef typename const_view_type_if_exists_else_type<A>::type A_type;
typedef typename const_view_type_if_exists_else_type<B>::type B_type;
const A_type a; const B_type b;
private:
typedef matrix<value_type> matrix_type;
struct internal_data { // implementing the pattern LazyPreCompute
matrix_type R;
internal_data(matmul_lazy const & P): R( first_dim(P.a), second_dim(P.b)) { blas::gemm(1.0,P.a, P.b, 0.0, R); }
};
friend struct internal_data;
mutable std::shared_ptr<internal_data> _id;
void activate() const { if (!_id) _id= std::make_shared<internal_data>(*this);}
public:
matmul_lazy( A const & a_, B const & b_):a(a_),b(b_){
if (second_dim(a) != first_dim(b)) TRIQS_RUNTIME_ERROR<< "Matrix product : dimension mismatch in A*B "<< a<<" "<< b;
}
domain_type domain() const { return mini_vector<size_t,2>(first_dim(a), second_dim(b));}
template<typename K0, typename K1> value_type operator() (K0 const & k0, K1 const & k1) const { activate(); return _id->R(k0,k1); }
// TO BE REMOVED because of the aliasing question
// Optimized implementation of =, +=, -=
template<typename LHS>
friend void triqs_arrays_assign_delegation (LHS & lhs, matmul_lazy const & rhs) {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix");
resize_or_check_if_view(lhs,make_shape(first_dim(rhs),second_dim(rhs)));
blas::gemm(1.0,rhs.a, rhs.b, 0.0, lhs);
}
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, matmul_lazy const & rhs, mpl::char_<'A'>) { rhs.assign_comp_impl(lhs,1.0);}
template<typename LHS>
friend void triqs_arrays_compound_assign_delegation (LHS & lhs, matmul_lazy const & rhs, mpl::char_<'S'>) { rhs.assign_comp_impl(lhs,-1.0);}
private:
template<typename LHS> void assign_comp_impl (LHS & lhs, double S) const {
static_assert((is_matrix_or_view<LHS>::value), "LHS is not a matrix");
if (first_dim(lhs) != first_dim(*this))
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< first_dim(lhs)<<" vs "<< first_dim(*this);
if (second_dim(lhs) != second_dim(*this))
TRIQS_RUNTIME_ERROR<< "Matmul : +=/-= operator : first dimension mismatch in A*B "<< second_dim(lhs)<<" vs "<< second_dim(*this);
blas::gemm(S,a, b, 1.0, lhs);
}
friend std::ostream & operator<<(std::ostream & out, matmul_lazy<A,B> const & x){return out<<x.a<<" * "<<x.b;}
};// class matmul_lazy
}}//namespace triqs::arrays
#endif

View File

@ -194,17 +194,19 @@ namespace triqs { namespace arrays {
#undef IMPL_TYPE
/*
template<typename ArrayType>
matrix_view<typename ArrayType::value_type, ArrayType::opt_flags, ArrayType::traversal_order>
make_matrix_view(ArrayType const & a) {
static_assert(ArrayType::rank ==2, "make_matrix_view only works for array of rank 2");
return a;
}
*/
template<typename ArrayType>
matrix<typename ArrayType::value_type, ArrayType::opt_flags, ArrayType::traversal_order>
matrix<typename ArrayType::value_type> //, ArrayType::opt_flags, ArrayType::traversal_order>
make_matrix(ArrayType const & a) {
static_assert(ArrayType::rank ==2, "make_matrix_view only works for array of rank 2");
static_assert(ArrayType::domain_type::rank ==2, "make_matrix only works for array of rank 2");
return a;
}

View File

@ -24,17 +24,18 @@
#include <vector>
#include <iterator>
#include <triqs/arrays.hpp>
//#include <triqs/arrays/mapped_functions.hpp>
#include <triqs/arrays/algorithms.hpp>
#include <triqs/arrays/linalg/det_and_inverse.hpp>
#include <triqs/arrays/linalg/a_x_ty.hpp>
#include <triqs/arrays/linalg/matmul.hpp>
#include <triqs/arrays/linalg/mat_vec_mul.hpp>
#include <triqs/arrays/blas_lapack/dot.hpp>
#include <triqs/arrays/blas_lapack/ger.hpp>
#include <triqs/arrays/blas_lapack/gemm.hpp>
#include <triqs/arrays/blas_lapack/gemv.hpp>
#include <triqs/utility/function_arg_ret_type.hpp>
namespace triqs { namespace det_manip {
namespace blas = arrays::blas;
/**
* \brief Standard matrix/det manipulations used in several QMC.
*/
@ -320,7 +321,8 @@ namespace triqs { namespace det_manip {
w1.C(k) = f(x, y_values[k]);
}
range R(0,N);
w1.MB(R) = mat_inv(R,R) * w1.B(R);// CHANGE
//w1.MB(R) = mat_inv(R,R) * w1.B(R);// OPTIMIZE BELOW
blas::gemv(1.0, mat_inv(R,R), w1.B(R),0.0,w1.MB(R));
w1.ksi = f(x,y) - arrays::dot( w1.C(R) , w1.MB(R) );
newdet = det*w1.ksi;
newsign = ((i + j)%2==0 ? sign : -sign); // since N-i0 + N-j0 = i0+j0 [2]
@ -347,7 +349,8 @@ namespace triqs { namespace det_manip {
w1.C(k) = fy(y_values[k]);
}
range R(0,N);
w1.MB(R) = mat_inv(R,R) * w1.B(R);// CHANGE
//w1.MB(R) = mat_inv(R,R) * w1.B(R);// OPTIMIZE BELOW
blas::gemv(1.0, mat_inv(R,R), w1.B(R),0.0,w1.MB(R));
w1.ksi = ksi - arrays::dot( w1.C(R) , w1.MB(R) );
newdet = det*w1.ksi;
newsign = ((i + j)%2==0 ? sign : -sign); // since N-i0 + N-j0 = i0+j0 [2]
@ -366,7 +369,8 @@ namespace triqs { namespace det_manip {
if (N==0) { N=1; mat_inv(0,0) = 1/newdet; return; }
range R1(0,N);
w1.MC(R1) = mat_inv(R1,R1).transpose() * w1.C(R1); //CHANGE
//w1.MC(R1) = mat_inv(R1,R1).transpose() * w1.C(R1); //OPTIMIZE BELOW
blas::gemv(1.0, mat_inv(R1,R1).transpose(), w1.C(R1),0.0,w1.MC(R1));
w1.MC(N) = -1;
w1.MB(N) = -1;
@ -389,7 +393,8 @@ namespace triqs { namespace det_manip {
range R(0,N);
mat_inv(R,N-1) = 0;
mat_inv(N-1,R) = 0;
mat_inv(R,R) += triqs::arrays::a_x_ty(w1.ksi, w1.MB(R) ,w1.MC(R)) ;//mat_inv(R,R) += w1.ksi* w1.MB(R) * w1.MC(R)// CHANGE
//mat_inv(R,R) += w1.ksi* w1.MB(R) * w1.MC(R)// OPTIMIZE BELOW
blas::ger(w1.ksi, w1.MB(R) ,w1.MC(R),mat_inv(R,R));
}
public :
@ -440,8 +445,10 @@ namespace triqs { namespace det_manip {
w2.C(1,k) = f(x1, y_values[k]);
}
range R(0,N), R2(0,2);
w2.MB(R,R2) = mat_inv(R,R) * w2.B(R,R2); // CHANGE
w2.ksi -= w2.C (R2, R) * w2.MB(R, R2); // CHANGE
//w2.MB(R,R2) = mat_inv(R,R) * w2.B(R,R2); // OPTIMIZE BELOW
blas::gemm(1.0, mat_inv(R,R) , w2.B(R,R2),0.0,w2.MB(R,R2));
//w2.ksi -= w2.C (R2, R) * w2.MB(R, R2); // OPTIMIZE BELOW
blas::gemm(-1.0, w2.C(R2,R), w2.MB(R, R2),1.0,w2.ksi);
newdet = det * w2.det_ksi();
newsign = ((i0 + j0 + i1 + j1)%2==0 ? sign : -sign); // since N-i0 + N-j0 + N + 1 -i1 + N+1 -j1 = i0+j0 [2]
return (newdet/det)*(newsign*sign); // sign is unity, hence 1/sign == sign
@ -460,7 +467,8 @@ namespace triqs { namespace det_manip {
if (N==0) {N=2; mat_inv(R2,R2)=inverse(w2.ksi); row_num[w2.i[1]]=1; col_num[w2.j[1]]=1; return;}
range Ri(0,N);
w2.MC(R2,Ri) = w2.C(R2,Ri) * mat_inv(Ri,Ri);// CHANGE
//w2.MC(R2,Ri) = w2.C(R2,Ri) * mat_inv(Ri,Ri);// OPTIMIZE BELOW
blas::gemm(1.0, w2.C(R2,Ri), mat_inv(Ri,Ri),0.0,w2.MC(R2,Ri));
w2.MC(R2, range(N, N+2) ) = -1; // identity matrix
w2.MB(range(N,N+2), R2 ) = -1; // identity matrix !
@ -479,7 +487,8 @@ namespace triqs { namespace det_manip {
range R(0,N);
mat_inv(R,range(N-2,N)) = 0;
mat_inv(range(N-2,N),R) = 0;
mat_inv(R,R) += w2.MB(R,R2) * (w2.ksi * w2.MC(R2,R)); // CHANGE
//mat_inv(R,R) += w2.MB(R,R2) * (w2.ksi * w2.MC(R2,R)); // OPTIMIZE BELOW
blas::gemm(1.0, w2.MB(R,R2), (w2.ksi * w2.MC(R2,R)),1.0,mat_inv(R,R) );
}
public:
@ -531,7 +540,8 @@ namespace triqs { namespace det_manip {
w1.ksi = - 1/mat_inv(N,N);
range R(0,N);
mat_inv(R,R) += arrays::a_x_ty(w1.ksi,mat_inv(R,N),mat_inv(N,R));
//mat_inv(R,R) += w1.ksi, * mat_inv(R,N) * mat_inv(N,R);
blas::ger(w1.ksi,mat_inv(R,N),mat_inv(N,R), mat_inv(R,R));
// modify the permutations
for (size_t k =w1.i; k<N; k++) {row_num[k]= row_num[k+1];}
@ -620,7 +630,8 @@ namespace triqs { namespace det_manip {
w2.ksi = inverse( mat_inv(Rl,Rl));
// write explicitely the second product on ksi for speed ?
mat_inv(Rn,Rn) -= mat_inv(Rn,Rl) * (w2.ksi * mat_inv(Rl,Rn)); // CHANGE
//mat_inv(Rn,Rn) -= mat_inv(Rn,Rl) * (w2.ksi * mat_inv(Rl,Rn)); // OPTIMIZE BELOW
blas::gemm(-1.0, mat_inv(Rn,Rl), w2.ksi * mat_inv(Rl,Rn),1.0, mat_inv(Rn,Rn) );
// modify the permutations
for (size_t k =w2.i[0]; k<w2.i[1]-1; k++) row_num[k] = row_num[k+1];
@ -654,7 +665,8 @@ namespace triqs { namespace det_manip {
// Compute the col B.
for (size_t i= 0; i<N;i++) w1.MC(i) = f(x_values[i] , w1.y) - f(x_values[i], y_values[w1.jreal]);
range R(0,N);
w1.MB(R) = mat_inv(R,R) * w1.MC(R);// CHANGE
//w1.MB(R) = mat_inv(R,R) * w1.MC(R);// OPTIMIZE BELOW
blas::gemv(1.0, mat_inv(R,R), w1.MC(R) ,0.0, w1.MB(R) );
// compute the newdet
w1.ksi = (1+w1.MB(w1.jreal));
@ -674,8 +686,9 @@ namespace triqs { namespace det_manip {
// Cf notes : simply multiply by -w1.ksi
w1.ksi = - 1/(1+ w1.MB(w1.jreal));
w1.MB(w1.jreal) = 0;
mat_inv(R,R) += triqs::arrays::a_x_ty(w1.ksi,w1.MB(R), mat_inv(w1.jreal,R)); // CHANGE
mat_inv(w1.jreal,R)*= -w1.ksi; // CHANGE
//mat_inv(R,R) += w1.ksi * w1.MB(R) * mat_inv(w1.jreal,R)); // OPTIMIZE BELOW
blas::ger(w1.ksi,w1.MB(R), mat_inv(w1.jreal,R), mat_inv(R,R));
mat_inv(w1.jreal,R)*= -w1.ksi;
}
//------------------------------------------------------------------------------------------
@ -696,7 +709,8 @@ namespace triqs { namespace det_manip {
// Compute the col B.
for (size_t i= 0; i<N;i++) w1.MB(i) = f(w1.x, y_values[i] ) - f(x_values[w1.ireal], y_values[i] );
range R(0,N);
w1.MC(R) = mat_inv(R,R).transpose() * w1.MB(R); // CHANGE
//w1.MC(R) = mat_inv(R,R).transpose() * w1.MB(R); // OPTIMIZE BELOW
blas::gemv(1.0, mat_inv(R,R).transpose(), w1.MB(R),0.0, w1.MC(R));
// compute the newdet
w1.ksi = (1+w1.MC(w1.ireal));
@ -715,8 +729,9 @@ namespace triqs { namespace det_manip {
// impl. Cf case 3
w1.ksi = - 1/(1+ w1.MC(w1.ireal));
w1.MC(w1.ireal) = 0;
mat_inv(R,R) += triqs::arrays::a_x_ty(w1.ksi,mat_inv(R,w1.ireal),w1.MC(R));
mat_inv(R,w1.ireal) *= -w1.ksi; // CHANGE
//mat_inv(R,R) += w1.ksi * mat_inv(R,w1.ireal) * w1.MC(R);
blas::ger(w1.ksi,mat_inv(R,w1.ireal),w1.MC(R), mat_inv(R,R));
mat_inv(R,w1.ireal) *= -w1.ksi;
}
//------------------------------------------------------------------------------------------
private: