3
0
mirror of https://github.com/triqs/dft_tools synced 2024-10-31 11:13:46 +01:00

mc_tools: simplify measure_aux

- for pieces that need to be precomputed for several
measures.
- put them under shared_ptr, and register then with add_measure_aux.
- they must be callable, as void ().
- TODO : add this in the doc when tested
This commit is contained in:
Olivier Parcollet 2014-01-28 21:45:32 +01:00
parent 35a15f0f35
commit ff3de6c5e7
3 changed files with 28 additions and 126 deletions

View File

@ -30,7 +30,7 @@ SpaceAfterControlStatementKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceInEmptyParentheses: false
SpacesInParentheses: false
SpacesInAngles:false
#SpacesInAngles:false
Standard: Cpp11
TabWidth: 1
UseTab: Never

View File

@ -18,8 +18,7 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_TOOLS_MC_GENERIC_H
#define TRIQS_TOOLS_MC_GENERIC_H
#pragma once
#include <triqs/utility/first_include.hpp>
#include <math.h>
#include <triqs/utility/timer.hpp>
@ -96,13 +95,9 @@ namespace triqs { namespace mc_tools {
/**
* Register the precomputation
*/
/*template<typename MeasureAuxType>
MeasureAuxType * add_measure_aux(MeasureAuxType && M, std::string name) {
static_assert( !std::is_pointer<MeasureAuxType>::value, "add_measure_aux in mc_generic takes ONLY values !");
AllMeasuresAux.insert(std::forward<MeasureAuxType>(M), name);
}
*/
/// get the average sign (to be called after collect_results)
template <typename MeasureAuxType> void add_measure_aux(std::shared_ptr<MeasureAuxType> p) { AllMeasuresAux.emplace_back(p); }
/// get the average sign (to be called after collect_results)
MCSignType average_sign() const { return sign_av; }
/// get the current percents done
@ -126,7 +121,7 @@ namespace triqs { namespace mc_tools {
if (thermalized()) {
nmeasures++;
sum_sign += sign;
AllMeasuresAux.compute_all();
for (auto &x : AllMeasuresAux) x();
AllMeasures.accumulate(sign);
}
// recompute fraction done
@ -168,9 +163,6 @@ namespace triqs { namespace mc_tools {
template<typename MeasureType> MeasureType & get_measure(std::string const & name) { return AllMeasures.template get_measure<MeasureType> (name); }
template<typename MeasureType> MeasureType const & get_measure(std::string const & name) const { return AllMeasures.template get_measure<MeasureType> (name); }
template<typename MeasureAuxType> MeasureAuxType * get_measure_aux(std::string const & name) { return AllMeasuresAux.template get_measure_aux<MeasureAuxType> (name); }
template<typename MeasureAuxType> MeasureAuxType const * get_measure_aux(std::string const & name) const { return AllMeasuresAux.template get_measure_aux<MeasureAuxType> (name); }
template<typename MoveType> MoveType & get_move (std::string const & name) { return AllMoves.template get_move<MoveType> (name); }
template<typename MoveType> MoveType const & get_move (std::string const & name) const { return AllMoves.template get_move<MoveType> (name); }
@ -206,7 +198,7 @@ namespace triqs { namespace mc_tools {
random_generator RandomGenerator;
move_set<MCSignType> AllMoves;
measure_set<MCSignType> AllMeasures;
measure_aux_set AllMeasuresAux;
std::vector<measure_aux> AllMeasuresAux;
utility::report_stream report;
uint64_t Length_MC_Cycle;/// Length of one Monte-Carlo cycle between 2 measures
uint64_t NWarmIterations, NCycles;
@ -235,5 +227,4 @@ namespace triqs { namespace mc_tools {
template<typename M,typename T1, typename T2> M const & get_move(mc_generic<T1,T2> const & s, std::string const & name) { return s.template get_move<M> (name); }
}}// end namespace
#endif

View File

@ -2,7 +2,7 @@
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2011-2013 by M. Ferrero, O. Parcollet
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
@ -18,124 +18,35 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_TOOLS_MC_PRECOMPUTATION_H
#define TRIQS_TOOLS_MC_PRECOMPUTATION_H
#pragma once
#include <functional>
#include <boost/mpi.hpp>
#include <memory>
#include <map>
#include <triqs/utility/exceptions.hpp>
#include "./impl_tools.hpp"
namespace triqs { namespace mc_tools {
// mini concept checking
template<typename T, typename Enable=void> struct is_callable: std::false_type {};
template<typename T> struct is_callable <T, decltype(std::declval<T>()())> : std::true_type {};
namespace triqs {
namespace mc_tools {
// mini concept checking : does T have void operator()() ?
template <typename T, typename Enable = void> struct is_callable : std::false_type {};
template <typename T> struct is_callable<T, decltype(std::declval<T>()())> : std::true_type {};
//--------------------------------------------------------------------
class measure_aux {
class measure_aux {
std::shared_ptr<void> impl_;
std::function<measure_aux()> clone_;
size_t hash_;
std::string type_name_;
std::shared_ptr<void> impl_;
std::function<void()> call_;
std::function<void ( ) > call_;
public:
template <typename MeasureAuxType> measure_aux(std::shared_ptr<MeasureAuxType> p) {
static_assert(is_callable<MeasureAuxType>::value, "This measure_aux is not callable");
impl_ = p;
call_ = [p]() mutable { (*p)(); }; // without the mutable, the operator () of the lambda is const, hence p ...
}
public :
void operator()() { call_(); }
};
template<typename MeasureAuxType>
measure_aux (MeasureAuxType && p_in, bool) {
static_assert( is_callable<MeasureAuxType>::value, "This measure_aux is not callable");
MeasureAuxType *p = new typename std::remove_reference<MeasureAuxType>::type(std::forward<MeasureAuxType>(p_in));
impl_= std::shared_ptr<void> (p);
clone_ = [p]() { return MeasureAuxType(*p);} ;
hash_ = typeid(MeasureAuxType).hash_code();
type_name_ = typeid(MeasureAuxType).name();
call_ = [p]() { (*p)();};
}
// Value semantics. Everyone at the end call move = ...
measure_aux(measure_aux const &rhs) = default; //{*this = rhs;}
//measure_aux(measure_aux &rhs) {*this = rhs;} // or it will use the template = bug
measure_aux(measure_aux && rhs) = default ; //{ *this = std::move(rhs);}
measure_aux & operator = (measure_aux const & rhs) { *this = rhs.clone_(); return *this;}
measure_aux & operator = (measure_aux && rhs) =default;
void operator()(){ call_();}
template<typename MeasureAuxType> bool has_type() const { return (typeid(MeasureAuxType).hash_code() == hash_); };
template<typename MeasureAuxType> void check_type() const {
if (!(has_type<MeasureAuxType>()))
TRIQS_RUNTIME_ERROR << "Trying to retrieve a measure_aux of type "<< typeid(MeasureAuxType).name() << " from a measure_aux of type "<< type_name_;
};
template<typename MeasureAuxType> MeasureAuxType * get() { check_type<MeasureAuxType>(); return (static_cast<MeasureAuxType *>(impl_.get())); }
template<typename MeasureAuxType> MeasureAuxType const * get() const { check_type<MeasureAuxType>(); return (static_cast<MeasureAuxType const *>(impl_.get())); }
};
//--------------------------------------------------------------------
class measure_aux_set {
typedef measure_aux measure_aux_type;
std::map<std::string, measure_aux> m_map;
public :
measure_aux_set(){}
measure_aux_set(measure_aux_set const &) = default;
measure_aux_set(measure_aux_set &&) = default;
measure_aux_set& operator = (measure_aux_set const &) = default;
measure_aux_set& operator = (measure_aux_set &&) = default;
/**
* Register the auxiliary M with a name
*/
/*template<typename MeasureAuxType>
MeasureAuxType * insert (MeasureAuxType && M, std::string const & name) {
if (has(name)) TRIQS_RUNTIME_ERROR << "Auxiliary measure "<<name<<" already exists";
auto r = m_map.insert(std::make_pair(name, measure_aux_type (std::forward<MeasureAuxType>(M))));
return r.first;
}
template<typename MeasureAuxType>
MeasureAuxType * retrieve (std::string const & name) {
if (!has(name)) TRIQS_RUNTIME_ERROR << "Auxiliary measure "<<name<<" does not exist";
return (&get_measure_aux<MeasureAuxType>(name));
}
*/
bool has(std::string const & name) const { return m_map.find(name) != m_map.end(); }
void compute_all () { for (auto & nmp : m_map) nmp.second(); }
std::vector<std::string> names() const {
std::vector<std::string> res;
for (auto const & nmp : m_map) res.push_back(nmp.first);
return res;
}
// access to the measure_aux, given its type, with dynamical type check
template<typename MeasureAuxType>
MeasureAuxType * get_measure_aux(std::string const & name) {
auto it = m_map.find (name);
if (it == m_map.end()) { auto r = m_map.insert(std::make_pair(name, measure_aux_type(MeasureAuxType()))); return get_measure_aux<MeasureAuxType>(name);}
return it->second.template get<MeasureAuxType>();
}
/*template<typename MeasureAuxType>
MeasureAuxType const & get_measure_aux(std::string const & name) const {
auto it = m_map.find (name);
if (it == m_map.end()) TRIQS_RUNTIME_ERROR << " Measure " << name << " unknown";
return it->template get<MeasureAuxType>();
}
*/
};
}}// end namespace
#endif
}
} // end namespace