diff --git a/triqs/mc_tools/mc_generic.hpp b/triqs/mc_tools/mc_generic.hpp index a2c0b316..6c2fad48 100644 --- a/triqs/mc_tools/mc_generic.hpp +++ b/triqs/mc_tools/mc_generic.hpp @@ -24,6 +24,7 @@ #include #include #include +#include "./mc_measure_aux_set.hpp" #include "./mc_measure_set.hpp" #include "./mc_move_set.hpp" #include "./mc_basic_step.hpp" @@ -46,7 +47,7 @@ namespace triqs { namespace mc_tools { std::function AfterCycleDuty = std::function() ) : RandomGenerator(Random_Name, Random_Seed), AllMoves(RandomGenerator), - AllMeasures(), + AllMeasures(),AllMeasuresAux(), report(&std::cout, Verbosity), Length_MC_Cycle(Length_Cycle), NWarmIterations(N_Warmup_Cycles), @@ -64,7 +65,7 @@ namespace triqs { namespace mc_tools { //RandomGenerator(P["Random_Generator_Name"]), P.value_or_default("Random_Seed",1)), report(&std::cout,int(P["Verbosity"])), AllMoves(RandomGenerator), - AllMeasures(), + AllMeasures(),AllMeasuresAux(), Length_MC_Cycle(long(P["Length_Cycle"])), /// NOT NICE THIS EXPLICIT CAST : no unsigned in parameters, really ?? NWarmIterations(long(P["N_Warmup_Cycles"])), NCycles(long(P["N_Cycles"])), @@ -92,7 +93,16 @@ namespace triqs { namespace mc_tools { AllMeasures.insert(std::forward(M), name); } - /// get the average sign (to be called after collect_results) + /** + * Register the precomputation + */ + /*template + MeasureAuxType * add_measure_aux(MeasureAuxType && M, std::string name) { + static_assert( !std::is_pointer::value, "add_measure_aux in mc_generic takes ONLY values !"); + AllMeasuresAux.insert(std::forward(M), name); + } +*/ + /// get the average sign (to be called after collect_results) MCSignType average_sign() const { return sign_av; } /// get the current percents done @@ -116,6 +126,7 @@ namespace triqs { namespace mc_tools { if (thermalized()) { nmeasures++; sum_sign += sign; + AllMeasuresAux.compute_all(); AllMeasures.accumulate(sign); } // recompute fraction done @@ -156,11 +167,14 @@ namespace triqs { namespace mc_tools { } // do not use direcly, use the free function it is simpler to call... - template MeasureType & get_measure(std::string const & name) { return AllMeasures.template get (name); } - template MeasureType const & get_measure(std::string const & name) const { return AllMeasures.template get (name); } + template MeasureType & get_measure(std::string const & name) { return AllMeasures.template get_measure (name); } + template MeasureType const & get_measure(std::string const & name) const { return AllMeasures.template get_measure (name); } + + template MeasureAuxType * get_measure_aux(std::string const & name) { return AllMeasuresAux.template get_measure_aux (name); } + template MeasureAuxType const * get_measure_aux(std::string const & name) const { return AllMeasuresAux.template get_measure_aux (name); } - template MoveType & get_move (std::string const & name) { return AllMoves.template get (name); } - template MoveType const & get_move (std::string const & name) const { return AllMoves.template get (name); } + template MoveType & get_move (std::string const & name) { return AllMoves.template get_move (name); } + template MoveType const & get_move (std::string const & name) const { return AllMoves.template get_move (name); } /// HDF5 interface friend void h5_write (h5::group g, std::string const & name, mc_generic const & mc){ @@ -194,6 +208,7 @@ namespace triqs { namespace mc_tools { random_generator RandomGenerator; move_set AllMoves; measure_set AllMeasures; + measure_aux_set AllMeasuresAux; utility::report_stream report; uint64_t Length_MC_Cycle;/// Length of one Monte-Carlo cycle between 2 measures uint64_t NWarmIterations, NCycles; @@ -212,6 +227,10 @@ namespace triqs { namespace mc_tools { /// Retrieve a Measure given name and type. NB : the type is checked at runtime template M & get_measure(mc_generic & s, std::string const & name) { return s.template get_measure (name); } template M const & get_measure(mc_generic const & s, std::string const & name) { return s.template get_measure (name); } + + /// Retrieve a Measure given name and type. NB : the type is checked at runtime + template M * get_measure_aux(mc_generic & s, std::string const & name) { return s.template get_measure_aux (name); } + template M const * get_measure_aux(mc_generic const & s, std::string const & name) { return s.template get_measure_aux (name); } /// Retrieve a Move given name and type. NB : the type is checked at runtime template M & get_move(mc_generic & s, std::string const & name) { return s.template get_move (name); } diff --git a/triqs/mc_tools/mc_measure_aux_set.hpp b/triqs/mc_tools/mc_measure_aux_set.hpp new file mode 100644 index 00000000..d2147323 --- /dev/null +++ b/triqs/mc_tools/mc_measure_aux_set.hpp @@ -0,0 +1,160 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2011-2013 by M. Ferrero, O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#ifndef TRIQS_TOOLS_MC_PRECOMPUTATION_H +#define TRIQS_TOOLS_MC_PRECOMPUTATION_H + +#include +#include +#include +#include +#include "./impl_tools.hpp" + +namespace triqs { namespace mc_tools { + + // mini concept checking + template struct is_callable: std::false_type {}; + template struct is_callable ()())> : std::true_type {}; + + + //-------------------------------------------------------------------- + + class measure_aux { + + std::shared_ptr impl_; + std::function clone_; + size_t hash_; + std::string type_name_; + + std::function call_; + + public : + + template + measure_aux (MeasureAuxType && p_in, bool) { + static_assert( is_callable::value, "This measure_aux is not callable"); + MeasureAuxType *p = new typename std::remove_reference::type(std::forward(p_in)); + impl_= std::shared_ptr (p); + clone_ = [p]() { return MeasureAuxType(*p);} ; + hash_ = typeid(MeasureAuxType).hash_code(); + type_name_ = typeid(MeasureAuxType).name(); + call_ = [p]() { (*p)();}; + } + + // Value semantics. Everyone at the end call move = ... + measure_aux(measure_aux const &rhs) = default; //{*this = rhs;} + //measure_aux(measure_aux &rhs) {*this = rhs;} // or it will use the template = bug + measure_aux(measure_aux && rhs) = default ; //{ *this = std::move(rhs);} + measure_aux & operator = (measure_aux const & rhs) { *this = rhs.clone_(); return *this;} +#ifndef TRIQS_WORKAROUND_INTEL_COMPILER_BUGS + measure_aux & operator = (measure_aux && rhs) =default; +#else + measure_aux & operator = (measure_aux && rhs) noexcept { + using std::swap; +#define SW(X) swap(X,rhs.X) + SW(impl_); SW(hash_); SW(type_name_); SW(clone_); + SW(call_); +#undef SW + return *this; + } +#endif + + void operator()(){ call_();} + + template bool has_type() const { return (typeid(MeasureAuxType).hash_code() == hash_); }; + template void check_type() const { + if (!(has_type())) + TRIQS_RUNTIME_ERROR << "Trying to retrieve a measure_aux of type "<< typeid(MeasureAuxType).name() << " from a measure_aux of type "<< type_name_; + }; + + template MeasureAuxType * get() { check_type(); return (static_cast(impl_.get())); } + template MeasureAuxType const * get() const { check_type(); return (static_cast(impl_.get())); } + + }; + + //-------------------------------------------------------------------- + + class measure_aux_set { + typedef measure_aux measure_aux_type; + std::map m_map; + public : + + measure_aux_set(){} + + measure_aux_set(measure_aux_set const &) = default; + measure_aux_set(measure_aux_set &&) = default; + + measure_aux_set& operator = (measure_aux_set const &) = default; +#ifndef TRIQS_WORKAROUND_INTEL_COMPILER_BUGS + measure_aux_set& operator = (measure_aux_set &&) = default; +#else + measure_aux_set& operator = (measure_aux_set && rhs) { + using std::swap; + swap(m_map,rhs.m_map); + return *this; + } +#endif + + /** + * Register the auxiliary M with a name + */ + /*template + MeasureAuxType * insert (MeasureAuxType && M, std::string const & name) { + if (has(name)) TRIQS_RUNTIME_ERROR << "Auxiliary measure "<(M)))); + return r.first; + } + + template + MeasureAuxType * retrieve (std::string const & name) { + if (!has(name)) TRIQS_RUNTIME_ERROR << "Auxiliary measure "<(name)); + } +*/ + bool has(std::string const & name) const { return m_map.find(name) != m_map.end(); } + + void compute_all () { for (auto & nmp : m_map) nmp.second(); } + + std::vector names() const { + std::vector res; + for (auto const & nmp : m_map) res.push_back(nmp.first); + return res; + } + + // access to the measure_aux, given its type, with dynamical type check + template + MeasureAuxType * get_measure_aux(std::string const & name) { + auto it = m_map.find (name); + if (it == m_map.end()) { auto r = m_map.insert(std::make_pair(name, measure_aux_type(MeasureAuxType()))); return get_measure_aux(name);} + return it->second.template get(); + } + + /*template + MeasureAuxType const & get_measure_aux(std::string const & name) const { + auto it = m_map.find (name); + if (it == m_map.end()) TRIQS_RUNTIME_ERROR << " Measure " << name << " unknown"; + return it->template get(); + } + */ + }; + +}}// end namespace +#endif +