3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-25 05:43:40 +01:00

[mpi] draft of gf support

- done Matsubara freq for testing and rereading.
- TODO: generalize to other meshes.
- draft for multi var gf
This commit is contained in:
Olivier Parcollet 2014-09-03 14:03:31 +02:00
parent ebbb2f0b25
commit 9c129cb224
13 changed files with 403 additions and 46 deletions

View File

@ -45,7 +45,10 @@ int main(int argc, char* argv[]) {
A(i_, j_) << i_ + 10 * j_;
//std::cerr << "B0 "<< B <<std::endl;
B = mpi::scatter(A, world);
std::cerr << "B "<< B <<std::endl;
ARR C = mpi::scatter(A, world);
std::ofstream out("node" + std::to_string(world.rank()));
@ -66,5 +69,14 @@ int main(int argc, char* argv[]) {
AA = mpi::allgather(B, world);
out << " AA = " << AA << std::endl;
ARR r1 = mpi::reduce(A, world);
out <<" Reduce "<< std::endl;
out << " r1 = " << r1 << std::endl;
ARR r2 = mpi::allreduce(A, world);
out <<" AllReduce "<< std::endl;
out << " r2 = " << r2 << std::endl;
}

View File

@ -49,7 +49,7 @@ struct my_object {
// assigment is almost done already...
template <typename Tag> my_object &operator=(mpi_lazy<Tag, my_object> x) {
return mpi_impl_tuple<my_object>::complete_operation(*this, x);
return mpi_impl<my_object>::complete_operation(*this, x);
}
};

104
test/triqs/mpi/mpi_gf.cpp Normal file
View File

@ -0,0 +1,104 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2013 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#define TRIQS_ARRAYS_ENFORCE_BOUNDCHECK
#include <iostream>
#include <type_traits>
#include <triqs/gfs.hpp>
#include <triqs/mpi.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
using namespace triqs;
using namespace triqs::arrays;
using namespace triqs::gfs;
using namespace triqs::clef;
int main(int argc, char* argv[]) {
mpi::environment env(argc, argv);
mpi::communicator world;
std::ofstream out("node" + std::to_string(world.rank()));
double beta = 10;
int Nfreq = 8;
placeholder<0> w_;
auto g1 = gf<imfreq>{{beta, Fermion, Nfreq}, {1, 1}}; // using ARR = array<double,2>;
g1(w_) << 1 / (w_ + 1);
out << "g1.data" << g1.data() << std::endl;
{
out<< "reduction "<< std::endl;
gf<imfreq> g2 = mpi::reduce(g1, world);
out << g2.data()<<std::endl;
out << g2.singularity() << std::endl;
}
{
out<< "all reduction "<< std::endl;
gf<imfreq> g2 = mpi::allreduce(g1, world);
out << g2.data()<<std::endl;
out << g2.singularity() << std::endl;
}
{
out << "scatter-gather test with =" << std::endl;
auto g2 = g1;
auto g2b = g1;
g2 = mpi::scatter(g1);
g2(w_) << g2(w_) * (1 + world.rank());
g2b = mpi::gather(g2);
out << g2b.data() << std::endl;
}
{
out << "scatter-allgather test with construction" << std::endl;
gf<imfreq> g2 = mpi::scatter(g1);
g2(w_) << g2(w_) * (1 + world.rank());
g1 = mpi::allgather(g2);
out << g1.data() << std::endl;
}
{
out << "Building directly scattered, and gather" << std::endl;
auto m = mpi_scatter(gf_mesh<imfreq>{beta, Fermion, Nfreq}, world, 0);
auto g3 = gf<imfreq>{m, {1, 1}};
g3(w_) << 1 / (w_ + 1);
auto g4 = g3;
out<< "chunk ..."<<std::endl;
out << g3.data() << std::endl;
out<< "gather"<<std::endl;
g4 = mpi::gather(g3);
out << g4.data() << std::endl;
out<< "allgather"<<std::endl;
g4 = mpi::allgather(g3);
out << g4.data() << std::endl;
}
}

View File

@ -26,6 +26,7 @@
#include <triqs/utility/tuple_tools.hpp>
#include <triqs/utility/c14.hpp>
#include <triqs/arrays/h5.hpp>
#include <triqs/mpi/gf.hpp>
#include <vector>
#include "./tools.hpp"
#include "./data_proxies.hpp"
@ -425,6 +426,9 @@ namespace gfs {
*this = x;
}
// mpi lazy
template <typename Tag> gf(mpi::mpi_lazy<Tag, gf> x) : gf() { operator=(x); }
gf(typename B::mesh_t m, typename B::data_t dat, typename B::singularity_view_t const &si, typename B::symmetry_t const &s,
typename B::indices_t const &ind, std::string name = "")
: B(std::move(m), std::move(dat), si, s, ind, name, typename B::evaluator_t{}) {}
@ -453,6 +457,13 @@ namespace gfs {
return *this;
}
friend struct mpi::mpi_impl_triqs_gfs<gf>; //allowed to modify mesh
//
template <typename Tag> void operator=(mpi::mpi_lazy<Tag, gf> x) {
mpi::mpi_impl_triqs_gfs<gf>::complete_operation(*this, x);
}
template <typename RHS> void operator=(RHS &&rhs) {
this->_mesh = rhs.mesh();
this->_data.resize(get_gf_data_shape(rhs));
@ -841,6 +852,17 @@ namespace gfs {
};
} // gfs_implementation
}
namespace mpi {
template <typename Variable, typename Target, typename Opt>
struct mpi_impl<gfs::gf<Variable, Target, Opt>, void> : mpi_impl_triqs_gfs<gfs::gf<Variable, Target, Opt>> {};
template <typename Variable, typename Target, typename Opt, bool IsConst>
struct mpi_impl<gfs::gf_view<Variable, Target, Opt, IsConst>, void> : mpi_impl_triqs_gfs<gfs::gf_view<Variable, Target, Opt, IsConst>> {};
}
}
// same as for arrays : views cannot be swapped by the std::swap. Delete it

View File

@ -18,11 +18,12 @@
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_GF_LOCAL_TAIL_H
#define TRIQS_GF_LOCAL_TAIL_H
#pragma once
#include <triqs/arrays.hpp>
#include <triqs/arrays/algorithms.hpp>
#include <triqs/gfs/tools.hpp>
#include <triqs/mpi/boost.hpp>
#include <boost/serialization/complex.hpp>
namespace triqs { namespace gfs { namespace local {
@ -50,6 +51,7 @@ namespace triqs { namespace gfs { namespace local {
/// A common implementation class. Idiom: ValueView
template<bool IsView> class tail_impl {
public:
TRIQS_MPI_IMPLEMENTED_VIA_BOOST;
typedef tail_view view_type;
typedef tail regular_type;
@ -171,6 +173,7 @@ namespace triqs { namespace gfs { namespace local {
}
friend std::ostream & operator << (std::ostream & out, tail_impl const & x) {
if (x.data().is_empty()) return out << "empty tail"<<std::endl;
out <<"tail/tail_view: min/smallest/max = "<< x.order_min() << " " << x.smallest_nonzero() << " "<< x.order_max();
for (long u = x.order_min(); u <= x.order_max(); ++u) out <<"\n ... Order "<<u << " = " << x(u);
return out;
@ -283,7 +286,7 @@ namespace triqs { namespace gfs { namespace local {
inline tail transpose(tail_view t) { return {transposed_view(t.data(),0,2,1), transposed_view(t.mask_view(),1,0),t.order_min()};}
/// Slice in orbital space
//template<bool V> tail_view slice_target(tail_impl<V> const & t, tqa::range R1, tqa::range R2) {
//template<bool V> tail_view slice_target(tail_impl<V> const & t, tqa::range R1, tqa::range R2)
inline tail_view slice_target(tail_view t, tqa::range R1, tqa::range R2) {
return tail_view(t.data()(tqa::range(),R1,R2), t.mask_view()(R1,R2), t.order_min());
}
@ -407,7 +410,5 @@ namespace triqs { namespace gfs { namespace local {
#undef DEFINE_OPERATOR
}}}
#endif
}}
}

View File

@ -35,11 +35,21 @@ namespace gfs {
using domain_pt_t = typename domain_t::point_t;
/// Constructor
matsubara_freq_mesh() : _dom(), _n_pts(0), _positive_only(true) {}
matsubara_freq_mesh(domain_t dom, long n_pts = 1025, bool positive_only = true)
: _dom(std::move(dom)), _n_pts(n_pts), _positive_only(positive_only) {
if (_positive_only) {
_first_index = 0;
_last_index = n_pts - 1; // CORRECTION
} else {
_last_index = (_n_pts - (_dom.statistic == Boson ? 1 : 2)) / 2;
_first_index = -(_last_index + (_dom.statistic == Fermion));
}
_first_index_window = _first_index;
_last_index_window = _last_index;
}
/// Constructor
matsubara_freq_mesh(domain_t dom, int n_pts=1025, bool positive_only = true)
: _dom(std::move(dom)), _n_pts(n_pts), _positive_only(positive_only) {}
matsubara_freq_mesh() : matsubara_freq_mesh(domain_t(), 0, true){}
/// Constructor
matsubara_freq_mesh(double beta, statistic_enum S, int n_pts = 1025, bool positive_only = true)
@ -48,6 +58,17 @@ namespace gfs {
/// Copy constructor
matsubara_freq_mesh(matsubara_freq_mesh const &) = default;
/// Scatter a mesh over the communicator c
friend matsubara_freq_mesh mpi_scatter(matsubara_freq_mesh m, mpi::communicator c, int root) {
auto m2 = matsubara_freq_mesh{m.domain(), m.size(), m.positive_only()};
std::tie(m2._first_index_window, m2._last_index_window) = mpi::slice_range(m2._first_index, m2._last_index, c.size(), c.rank());
return m2;
}
friend matsubara_freq_mesh mpi_gather(matsubara_freq_mesh m, mpi::communicator c, int root) {
return matsubara_freq_mesh{m.domain(), m.size(), m.positive_only()};
}
/// The corresponding domain
domain_t const &domain() const { return _dom; }
@ -60,20 +81,29 @@ namespace gfs {
**/
/// last Matsubara index
int last_index() const { return (_positive_only ? _n_pts : (_n_pts - (_dom.statistic == Boson ? 1 : 2))/2);}
int last_index() const { return _last_index;}
/// first Matsubara index
int first_index() const { return -(_positive_only ? 0 : last_index() + (_dom.statistic == Fermion)); }
int first_index() const { return _first_index;}
/// last Matsubara index of the window
int last_index_window() const { return _last_index_window;}
/// first Matsubara index of the window
int first_index_window() const { return _first_index_window;}
/// Size (linear) of the mesh
long size() const { return _n_pts;}
//long size() const { return _n_pts;}
/// Size (linear) of the mesh of the window
long size() const { return _last_index_window - _first_index_window + 1; }
/// From an index of a point in the mesh, returns the corresponding point in the domain
domain_pt_t index_to_point(index_t ind) const { return 1_j * M_PI * (2 * ind + (_dom.statistic == Fermion)) / _dom.beta; }
/// Flatten the index in the positive linear index for memory storage (almost trivial here).
long index_to_linear(index_t ind) const { return ind - first_index(); }
index_t linear_to_index(long lind) const { return lind + first_index(); }
long index_to_linear(index_t ind) const { return ind - first_index_window(); }
index_t linear_to_index(long lind) const { return lind + first_index_window(); }
/// Is the mesh only for positive omega_n (G(tau) real))
bool positive_only() const { return _positive_only;}
@ -86,18 +116,18 @@ namespace gfs {
struct mesh_point_t : tag::mesh_point, matsubara_freq {
mesh_point_t() = default;
mesh_point_t(matsubara_freq_mesh const &mesh, index_t const &index_)
: matsubara_freq(index_, mesh.domain().beta, mesh.domain().statistic),
first_index(mesh.first_index()),
index_stop(mesh.first_index() + mesh.size() - 1) {}
mesh_point_t(matsubara_freq_mesh const &mesh) : mesh_point_t(mesh, mesh.first_index()) {}
: matsubara_freq(index_, mesh.domain().beta, mesh.domain().statistic)
, first_index_window(mesh.first_index_window())
, last_index_window(mesh.last_index_window()) {}
mesh_point_t(matsubara_freq_mesh const &mesh) : mesh_point_t(mesh, mesh.first_index_window()) {}
void advance() { ++n; }
long linear_index() const { return n - first_index; }
long linear_index() const { return n - first_index_window; }
long index() const { return n; }
bool at_end() const { return (n == index_stop + 1); } // at_end means " one after the last one", as in STL
void reset() { n = first_index; }
bool at_end() const { return (n == last_index_window + 1); } // at_end means " one after the last one", as in STL
void reset() { n = first_index_window; }
private:
index_t first_index, index_stop;
index_t first_index_window, last_index_window;
};
/// Accessing a point of the mesh from its index
@ -164,6 +194,7 @@ namespace gfs {
domain_t _dom;
int _n_pts;
bool _positive_only;
long _first_index, _last_index, _first_index_window, _last_index_window;
};
//-------------------------------------------------------

View File

@ -39,6 +39,7 @@ namespace gfs {
mesh_product() {}
mesh_product(Meshes const &... meshes) : m_tuple(meshes...), _dom(meshes.domain()...) {}
mesh_product(mesh_product const &) = default;
domain_t const &domain() const { return _dom; }
m_tuple_t const &components() const { return m_tuple; }
@ -49,6 +50,20 @@ namespace gfs {
return triqs::tuple::fold([](auto const &m, size_t R) { return R * m.size(); }, m_tuple, 1);
}
/// Scatter the first mesh over the communicator c
friend mesh_product mpi_scatter(mesh_product const &m, mpi::communicator c, int root) {
auto r = m; // same domain, but mesh with a window. Ok ?
std::get<0>(r.m_tuple) = mpi_scatter(std::get<0>(r.m_tuple), c, root);
return r;
}
/// Opposite of scatter : rebuild the original mesh, without a window
friend matsubara_freq_mesh mpi_gather(matsubara_freq_mesh m, mpi::communicator c, int root) {
auto r = m; // same domain, but mesh with a window. Ok ?
std::get<0>(r.m_tuple) = mpi_gather(std::get<0>(r.m_tuple), c, root);
return r;
}
/// Conversions point <-> index <-> linear_index
typename domain_t::point_t index_to_point(index_t const &ind) const {
domain_pt_t res;

View File

@ -40,8 +40,12 @@ namespace mpi {
auto dims = ref.shape();
long slow_size = first_dim(ref);
if (std::is_same<Tag, tag::reduce>::value) {
// optionally check all dims are the same ?
}
if (std::is_same<Tag, tag::scatter>::value) {
dims[0] = slice_length(slow_size - 1, c, c.rank());
dims[0] = mpi::slice_length(slow_size - 1, c.size(), c.rank());
}
if (std::is_same<Tag, tag::gather>::value) {
@ -87,7 +91,7 @@ namespace mpi {
static void allreduce_in_place(communicator c, A &a, int root) {
check_is_contiguous(a);
// assume arrays have the same size on all nodes...
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, c.get());
}
//---------
@ -137,6 +141,18 @@ namespace arrays {
private:
static MPI_Datatype D() { return mpi::mpi_datatype<typename A::value_type>::invoke(); }
//---------------------------------
void _invoke(triqs::mpi::tag::reduce) {
lhs.resize(laz.domain());
MPI_Reduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.root, laz.c.get());
}
//---------------------------------
void _invoke(triqs::mpi::tag::allreduce) {
lhs.resize(laz.domain());
MPI_Allreduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.c.get());
}
//---------------------------------
void _invoke(triqs::mpi::tag::scatter) {
lhs.resize(laz.domain());
@ -146,10 +162,10 @@ namespace arrays {
auto slow_stride = laz.ref.indexmap().strides()[0];
auto sendcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int recvcount = slice_length(slow_size - 1, c, c.rank()) * slow_stride;
int recvcount = mpi::slice_length(slow_size - 1, c.size(), c.rank()) * slow_stride;
for (int r = 0; r < c.size(); ++r) {
sendcounts[r] = slice_length(slow_size - 1, c, r) * slow_stride;
sendcounts[r] = mpi::slice_length(slow_size - 1, c.size(), r) * slow_stride;
displs[r + 1] = sendcounts[r] + displs[r];
}

View File

@ -23,6 +23,12 @@
//#include <triqs/utility/tuple_tools.hpp>
#include <mpi.h>
namespace boost { // forward declare in case we do not include boost.
namespace mpi {
class communicator;
}
}
namespace triqs {
namespace mpi {
@ -41,6 +47,11 @@ namespace mpi {
MPI_Comm get() const { return _com; }
inline communicator(boost::mpi::communicator);
/// Cast to the boost mpi communicator
inline operator boost::mpi::communicator () const;
int rank() const {
int num;
MPI_Comm_rank(_com, &num);
@ -68,6 +79,13 @@ namespace mpi {
/// The implementation of mpi ops for each type
template <typename T, typename Enable = void> struct mpi_impl;
/// A small lazy tagged class
template <typename Tag, typename T> struct mpi_lazy {
T const &ref;
int root;
communicator c;
};
// ----------------------------------------
// ------- top level functions -------
// ----------------------------------------
@ -136,6 +154,26 @@ namespace mpi {
struct mpi_impl<T, std14::enable_if_t<std::is_arithmetic<T>::value || triqs::is_complex<T>::value>> : mpi_impl_basic<T> {};
//------------ Some helper function
// Given a range [first, last], slice it regularly for a node of rank 'rank' among n_nodes.
// If the range is not dividable in n_nodes equal parts,
// the first nodes have one more elements than the last ones.
inline std::pair<long, long> slice_range(long first, long last, int n_nodes, int rank) {
long chunk = (last - first + 1) / n_nodes;
long n_large_nodes = (last - first + 1) - n_nodes * chunk;
if (rank <= n_large_nodes - 1) // first, larger nodes, use chunk + 1
return {first + rank * (chunk + 1), first + (rank + 1) * (chunk + 1) - 1};
else // others nodes : shift the first by 1*n_large_nodes, used chunk
return {first + n_large_nodes + rank * chunk, first + n_large_nodes + (rank + 1) * chunk - 1};
}
// TODO RECHECK TEST
inline long slice_length(long imax, int n_nodes, int rank) {
auto r = slice_range(0, imax, n_nodes, rank);
return r.second - r.first + 1;
}
/*
inline long slice_length(size_t imax, communicator c, int r) {
auto imin = 0;
long j = (imax - imin + 1) / c.size();
@ -143,6 +181,7 @@ namespace mpi {
auto r_min = (r <= i - 1 ? imin + r * (j + 1) : imin + r * j + i);
auto r_max = (r <= i - 1 ? imin + (r + 1) * (j + 1) - 1 : imin + (r + 1) * j + i - 1);
return r_max - r_min + 1;
};
}
*/
}
}

View File

@ -22,9 +22,20 @@
#include "./base.hpp"
#include <boost/mpi.hpp>
#define TRIQS_MPI_IMPLEMENTED_VIA_BOOST using triqs_mpi_via_boost = void;
namespace triqs {
namespace mpi {
// implement the communicator cast
inline communicator::operator boost::mpi::communicator() const {
return boost::mpi::communicator(_com, boost::mpi::comm_duplicate);
// duplicate policy : cf http://www.boost.org/doc/libs/1_56_0/doc/html/boost/mpi/comm_create_kind.html
}
// reverse : construct (implicit) the communicator from the boost one.
inline communicator::communicator(boost::mpi::communicator c) :_com(c) {}
/** ------------------------------------------------------------
* Type which we use boost::mpi
* ---------------------------------------------------------- **/
@ -39,7 +50,7 @@ namespace mpi {
static T invoke(tag::allreduce, communicator c, T const &a, int root) {
T b;
boost::mpi::all_reduce(c, a, b, std::c14::plus<>(), root);
boost::mpi::all_reduce(c, a, b, std::c14::plus<>());
return b;
}
@ -51,8 +62,8 @@ namespace mpi {
static void allgather(communicator c, T const &, int root) = delete;
};
// default
//template <typename T> struct mpi_impl<T> : mpi_impl_boost_mpi<T> {};
// If type T has a mpi_implementation nested struct, then it is mpi_impl<T>.
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_via_boost> : mpi_impl_boost_mpi<T> {};
}}//namespace

View File

@ -23,26 +23,31 @@
#include <triqs/utility/tuple_tools.hpp>
#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW using triqs_mpi_as_tuple = void;
#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW_NO_LAZY using triqs_mpi_as_tuple_no_lazy = void;
namespace triqs {
namespace mpi {
template <typename Tag, typename T> struct mpi_lazy {
T const &ref;
int root;
communicator c;
};
/** ------------------------------------------------------------
* Type which are recursively treated by reducing them to a tuple
* of smaller objects.
* ---------------------------------------------------------- **/
template <typename T> struct mpi_impl_tuple {
template <typename T, bool with_lazy> struct mpi_impl_tuple {
mpi_impl_tuple() = default;
template <typename Tag> static mpi_lazy<Tag, T> invoke(Tag, communicator c, T const &a, int root) {
/// invoke
template <typename Tag> static mpi_lazy<Tag, T> invoke_impl(std::true_type, Tag, communicator c, T const &a, int root) {
return {a, root, c};
}
template <typename Tag> static T &invoke_impl(std::false_type, Tag, communicator c, T const &a, int root) {
return complete_operation(a, {a, root, c});
}
template <typename Tag> static mpi_lazy<Tag, T> invoke(Tag, communicator c, T const &a, int root) {
return invoke_impl(std::integral_constant<bool, with_lazy>(), Tag(), c, a, root);
}
#ifdef __cpp_generic_lambdas
static void reduce_in_place(communicator c, T &a, int root) {
tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::reduce_in_place(x, c, root); });
@ -57,6 +62,7 @@ namespace mpi {
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
return target;
}
#else
struct aux1 {
@ -89,15 +95,17 @@ namespace mpi {
}
};
template <typename Tag> static void complete_operation(T &target, mpi_lazy<Tag, T> laz) {
template <typename Tag> static T& complete_operation(T &target, mpi_lazy<Tag, T> laz) {
auto l = aux3<Tag>{laz};
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
return target;
}
#endif
};
// If type T has a mpi_implementation nested struct, then it is mpi_impl<T>.
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple> : mpi_impl_tuple<T> {};
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple> : mpi_impl_tuple<T, true> {};
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple_no_lazy> : mpi_impl_tuple<T, false> {};
}
} // namespace

98
triqs/mpi/gf.hpp Normal file
View File

@ -0,0 +1,98 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include "./base.hpp"
#include <triqs/mpi/generic.hpp>
namespace triqs {
namespace mpi {
//--------------------------------------------------------------------------------------------------------
// When value_type is a basic type, we can directly call the C API
template <typename G> struct mpi_impl_triqs_gfs {
//---------
static void reduce_in_place(communicator c, G &g, int root) {
triqs::mpi::reduce_in_place(c, g.data(), root);
triqs::mpi::reduce_in_place(c, g.singularity(), root);
}
//---------
/*static void allreduce_in_place(communicator c, G &g, int root) {
triqs::mpi::allreduce_in_place(c, g.data(), root);
triqs::mpi::allreduce_in_place(c, g.singularity(), root);
}
*/
//---------
static void broadcast(communicator c, G &g, int root) {
triqs::mpi::broadcast(c, g.data(), root);
triqs::mpi::broadcast(c, g.singularity(), root);
}
//---------
template <typename Tag> static mpi_lazy<Tag, G> invoke(Tag, communicator c, G const &g, int root) {
return {g, root, c};
}
//---- reduce ----
static G &complete_operation(G &target, mpi_lazy<tag::reduce, G> laz) {
target._data = mpi::reduce(laz.ref.data(), laz.c, laz.root);
target._singularity = mpi::reduce(laz.ref.singularity(), laz.c, laz.root);
return target;
}
//---- allreduce ----
static G &complete_operation(G &target, mpi_lazy<tag::allreduce, G> laz) {
target._data = mpi::allreduce(laz.ref.data(), laz.c, laz.root);
target._singularity = mpi::allreduce(laz.ref.singularity(), laz.c, laz.root);
return target;
}
//---- scatter ----
static G &complete_operation(G &target, mpi_lazy<tag::scatter, G> laz) {
target._mesh = mpi_scatter(laz.ref.mesh(), laz.c, laz.root);
target._data = mpi::scatter(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
target._singularity = laz.ref.singularity();
//mpi::broadcast(target._singularity, laz.c, laz.root);
return target;
}
//---- gather ----
static G &complete_operation(G &target, mpi_lazy<tag::gather, G> laz) {
target._mesh = mpi_gather(laz.ref.mesh(), laz.c, laz.root);
target._data = mpi::gather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
// do nothing for singularity
return target;
}
//---- allgather ----
static G &complete_operation(G &target, mpi_lazy<tag::allgather, G> laz) {
target._data = mpi::allgather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
// do nothing for singularity
return target;
}
};
} // mpi namespace
} // namespace triqs

View File

@ -64,11 +64,11 @@ namespace mpi {
auto slow_size = a.size();
auto sendcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int recvcount = slice_length(slow_size - 1, c, c.rank());
int recvcount = slice_length(slow_size - 1, c.size(), c.rank());
std::vector<T> b(recvcount);
for (int r = 0; r < c.size(); ++r) {
sendcounts[r] = slice_length(slow_size - 1, c, r);
sendcounts[r] = slice_length(slow_size - 1, c.size(), r);
displs[r + 1] = sendcounts[r] + displs[r];
}