diff --git a/test/triqs/mpi/mpi_array.cpp b/test/triqs/mpi/mpi_array.cpp
index e02bf7e9..cea0b1b1 100644
--- a/test/triqs/mpi/mpi_array.cpp
+++ b/test/triqs/mpi/mpi_array.cpp
@@ -45,7 +45,10 @@ int main(int argc, char* argv[]) {
  A(i_, j_) << i_ + 10 * j_;
+ //std::cerr << "B0 "<< B <
+ w_;
+
+ auto g1 = gf{{beta, Fermion, Nfreq}, {1, 1}}; // using ARR = array;
+ g1(w_) << 1 / (w_ + 1);
+
+ out << "g1.data" << g1.data() << std::endl;
+
+ {
+  out<< "reduction "<< std::endl;
+  gf g2 = mpi::reduce(g1, world);
+  out << g2.data()<
+ g2 = mpi::allreduce(g1, world);
+  out << g2.data()<
+ g2 = mpi::scatter(g1);
+  g2(w_) << g2(w_) * (1 + world.rank());
+  g1 = mpi::allgather(g2);
+
+  out << g1.data() << std::endl;
+ }
+
+ {
+  out << "Building directly scattered, and gather" << std::endl;
+  auto m = mpi_scatter(gf_mesh{beta, Fermion, Nfreq}, world, 0);
+  auto g3 = gf{m, {1, 1}};
+  g3(w_) << 1 / (w_ + 1);
+  auto g4 = g3;
+  out<< "chunk ..."<
 #include
 #include
+#include
 #include
 #include "./tools.hpp"
 #include "./data_proxies.hpp"
@@ -424,6 +425,9 @@ namespace gfs {
     : B() {
    *this = x;
   }
+
+  // mpi lazy
+  template gf(mpi::mpi_lazy x) : gf() { operator=(x); }
 
   gf(typename B::mesh_t m, typename B::data_t dat, typename B::singularity_view_t const &si, typename B::symmetry_t const &s,
      typename B::indices_t const &ind, std::string name = "")
@@ -453,6 +457,13 @@ namespace gfs {
    return *this;
   }
 
+  friend struct mpi::mpi_impl_triqs_gfs; //allowed to modify mesh
+
+  //
+  template void operator=(mpi::mpi_lazy x) {
+   mpi::mpi_impl_triqs_gfs::complete_operation(*this, x);
+  }
+
   template void operator=(RHS &&rhs) {
    this->_mesh = rhs.mesh();
    this->_data.resize(get_gf_data_shape(rhs));
@@ -841,6 +852,17 @@ namespace gfs {
  };
 } // gfs_implementation
 }
+
+namespace mpi {
+
+ template
+ struct mpi_impl, void> : mpi_impl_triqs_gfs> {};
+
+ template
+ struct mpi_impl, void> : mpi_impl_triqs_gfs> {};
+
+}
+
 }
 
 // same as for arrays : views cannot be swapped by the std::swap. Delete it
diff --git a/triqs/gfs/local/tail.hpp b/triqs/gfs/local/tail.hpp
index 68c2fcb8..5de8e8e4 100644
--- a/triqs/gfs/local/tail.hpp
+++ b/triqs/gfs/local/tail.hpp
@@ -18,11 +18,12 @@
  * TRIQS. If not, see .
  *
  ******************************************************************************/
-#ifndef TRIQS_GF_LOCAL_TAIL_H
-#define TRIQS_GF_LOCAL_TAIL_H
+#pragma once
 #include
 #include
 #include
+#include
+#include
 
 namespace triqs { namespace gfs { namespace local {
@@ -50,6 +51,7 @@ namespace triqs { namespace gfs { namespace local {
  /// A common implementation class. Idiom: ValueView
  template class tail_impl {
   public:
+  TRIQS_MPI_IMPLEMENTED_VIA_BOOST;
 
  typedef tail_view view_type;
  typedef tail regular_type;
@@ -171,8 +173,9 @@ namespace triqs { namespace gfs { namespace local {
  }
 
  friend std::ostream & operator << (std::ostream & out, tail_impl const & x) {
+  if (x.data().is_empty()) return out << "empty tail"<
+  std::get<0>(r.m_tuple) = mpi_scatter(std::get<0>(r.m_tuple), c, root);
+  return r;
+ }
+
+ /// Opposite of scatter : rebuild the original mesh, without a window
+ friend matsubara_freq_mesh mpi_gather(matsubara_freq_mesh m, mpi::communicator c, int root) {
+  auto r = m; // same domain, but mesh with a window. Ok ?
+  std::get<0>(r.m_tuple) = mpi_gather(std::get<0>(r.m_tuple), c, root);
+  return r;
+ }
+
 /// Conversions point <-> index <-> linear_index
 typename domain_t::point_t index_to_point(index_t const &ind) const {
  domain_pt_t res;
diff --git a/triqs/mpi/arrays.hpp b/triqs/mpi/arrays.hpp
index 17b79624..4dde987b 100644
--- a/triqs/mpi/arrays.hpp
+++ b/triqs/mpi/arrays.hpp
@@ -40,8 +40,12 @@ namespace mpi {
   auto dims = ref.shape();
   long slow_size = first_dim(ref);
 
+  if (std::is_same::value) {
+   // optionally check all dims are the same ?
+  }
+
   if (std::is_same::value) {
-   dims[0] = slice_length(slow_size - 1, c, c.rank());
+   dims[0] = mpi::slice_length(slow_size - 1, c.size(), c.rank());
   }
 
   if (std::is_same::value) {
@@ -87,7 +91,7 @@ namespace mpi {
   static void allreduce_in_place(communicator c, A &a, int root) {
    check_is_contiguous(a);
    // assume arrays have the same size on all nodes...
-   MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
+   MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, c.get());
   }
 
   //---------
@@ -137,6 +141,18 @@ namespace arrays {
   private:
   static MPI_Datatype D() { return mpi::mpi_datatype::invoke(); }
 
+  //---------------------------------
+  void _invoke(triqs::mpi::tag::reduce) {
+   lhs.resize(laz.domain());
+   MPI_Reduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.root, laz.c.get());
+  }
+
+  //---------------------------------
+  void _invoke(triqs::mpi::tag::allreduce) {
+   lhs.resize(laz.domain());
+   MPI_Allreduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.c.get());
+  }
+
   //---------------------------------
   void _invoke(triqs::mpi::tag::scatter) {
    lhs.resize(laz.domain());
@@ -146,10 +162,10 @@
   auto slow_stride = laz.ref.indexmap().strides()[0];
   auto sendcounts = std::vector(c.size());
   auto displs = std::vector(c.size() + 1, 0);
-  int recvcount = slice_length(slow_size - 1, c, c.rank()) * slow_stride;
+  int recvcount = mpi::slice_length(slow_size - 1, c.size(), c.rank()) * slow_stride;
 
   for (int r = 0; r < c.size(); ++r) {
-   sendcounts[r] = slice_length(slow_size - 1, c, r) * slow_stride;
+   sendcounts[r] = mpi::slice_length(slow_size - 1, c.size(), r) * slow_stride;
    displs[r + 1] = sendcounts[r] + displs[r];
   }
diff --git a/triqs/mpi/base.hpp b/triqs/mpi/base.hpp
index 862f9537..40e81d4f 100644
--- a/triqs/mpi/base.hpp
+++ b/triqs/mpi/base.hpp
@@ -23,6 +23,12 @@
 //#include
 #include
 
+namespace boost { // forward declare in case we do not include boost.
+namespace mpi {
+ class communicator;
+}
+}
+
 namespace triqs {
 namespace mpi {
@@ -41,6 +47,11 @@ namespace mpi {
   MPI_Comm get() const { return _com; }
 
+  inline communicator(boost::mpi::communicator);
+
+  /// Cast to the boost mpi communicator
+  inline operator boost::mpi::communicator () const;
+
   int rank() const {
    int num;
    MPI_Comm_rank(_com, &num);
@@ -67,6 +78,13 @@ namespace mpi {
  /// The implementation of mpi ops for each type
  template struct mpi_impl;
+
+ /// A small lazy tagged class
+ template struct mpi_lazy {
+  T const &ref;
+  int root;
+  communicator c;
+ };
 
 // ----------------------------------------
 // ------- top level functions -------
@@ -136,6 +154,26 @@ namespace mpi {
  struct mpi_impl::value || triqs::is_complex::value>> : mpi_impl_basic {};
 
 //------------ Some helper function
+
+ // Given a range [first, last], slice it regularly for a node of rank 'rank' among n_nodes.
+ // If the range is not dividable in n_nodes equal parts,
+ // the first nodes have one more elements than the last ones.
+ inline std::pair slice_range(long first, long last, int n_nodes, int rank) {
+  long chunk = (last - first + 1) / n_nodes;
+  long n_large_nodes = (last - first + 1) - n_nodes * chunk;
+  if (rank <= n_large_nodes - 1) // first, larger nodes, use chunk + 1
+   return {first + rank * (chunk + 1), first + (rank + 1) * (chunk + 1) - 1};
+  else // others nodes : shift the first by 1*n_large_nodes, used chunk
+   return {first + n_large_nodes + rank * chunk, first + n_large_nodes + (rank + 1) * chunk - 1};
+ }
+
+ // TODO RECHECK TEST
+ inline long slice_length(long imax, int n_nodes, int rank) {
+  auto r = slice_range(0, imax, n_nodes, rank);
+  return r.second - r.first + 1;
+ }
+
+ /*
 inline long slice_length(size_t imax, communicator c, int r) {
   auto imin = 0;
   long j = (imax - imin + 1) / c.size();
@@ -143,6 +181,7 @@ namespace mpi {
   auto r_min = (r <= i - 1 ? imin + r * (j + 1) : imin + r * j + i);
   auto r_max = (r <= i - 1 ? imin + (r + 1) * (j + 1) - 1 : imin + (r + 1) * j + i - 1);
   return r_max - r_min + 1;
-  };
+  }
+ */
 }
 }
diff --git a/triqs/mpi/boost.hpp b/triqs/mpi/boost.hpp
index d8a83676..74a76e6a 100644
--- a/triqs/mpi/boost.hpp
+++ b/triqs/mpi/boost.hpp
@@ -22,9 +22,20 @@
 #include "./base.hpp"
 #include
 
+#define TRIQS_MPI_IMPLEMENTED_VIA_BOOST using triqs_mpi_via_boost = void;
+
 namespace triqs {
 namespace mpi {
 
+ // implement the communicator cast
+ inline communicator::operator boost::mpi::communicator() const {
+  return boost::mpi::communicator(_com, boost::mpi::comm_duplicate);
+  // duplicate policy : cf http://www.boost.org/doc/libs/1_56_0/doc/html/boost/mpi/comm_create_kind.html
+ }
+
+ // reverse : construct (implicit) the communicator from the boost one.
+ inline communicator::communicator(boost::mpi::communicator c) :_com(c) {}
+
 /** ------------------------------------------------------------
   * Type which we use boost::mpi
   * ---------------------------------------------------------- **/
@@ -39,7 +50,7 @@ namespace mpi {
   static T invoke(tag::allreduce, communicator c, T const &a, int root) {
    T b;
-   boost::mpi::all_reduce(c, a, b, std::c14::plus<>(), root);
+   boost::mpi::all_reduce(c, a, b, std::c14::plus<>());
    return b;
   }
 
@@ -51,8 +62,8 @@ namespace mpi {
   static void allgather(communicator c, T const &, int root) = delete;
  };
 
- // default
- //template struct mpi_impl : mpi_impl_boost_mpi {};
+ // If type T has a mpi_implementation nested struct, then it is mpi_impl.
+ template struct mpi_impl : mpi_impl_boost_mpi {};
 
 }}//namespace
diff --git a/triqs/mpi/generic.hpp b/triqs/mpi/generic.hpp
index fb3523f8..c64764a6 100644
--- a/triqs/mpi/generic.hpp
+++ b/triqs/mpi/generic.hpp
@@ -23,25 +23,30 @@
 #include
 
 #define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW using triqs_mpi_as_tuple = void;
+#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW_NO_LAZY using triqs_mpi_as_tuple_no_lazy = void;
 
 namespace triqs {
 namespace mpi {
 
- template struct mpi_lazy {
-  T const &ref;
-  int root;
-  communicator c;
- };
-
 /** ------------------------------------------------------------
   * Type which are recursively treated by reducing them to a tuple
   * of smaller objects.
   * ---------------------------------------------------------- **/
- template struct mpi_impl_tuple {
+ template struct mpi_impl_tuple {
 
  mpi_impl_tuple() = default;
- template static mpi_lazy invoke(Tag, communicator c, T const &a, int root) { return {a, root, c}; }
+
+ /// invoke
+ template static mpi_lazy invoke_impl(std::true_type, Tag, communicator c, T const &a, int root) { return {a, root, c}; }
+
+ template static T &invoke_impl(std::false_type, Tag, communicator c, T const &a, int root) {
+  return complete_operation(a, {a, root, c});
+ }
+
+ template static mpi_lazy invoke(Tag, communicator c, T const &a, int root) {
+  return invoke_impl(std::integral_constant(), Tag(), c, a, root);
+ }
 
 #ifdef __cpp_generic_lambdas
 
  static void reduce_in_place(communicator c, T &a, int root) {
@@ -57,6 +62,7 @@ namespace mpi {
   triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
   return target;
  }
+
 #else
 
  struct aux1 {
@@ -89,15 +95,17 @@ namespace mpi {
   }
  };
 
- template static void complete_operation(T &target, mpi_lazy laz) {
+ template static T& complete_operation(T &target, mpi_lazy laz) {
   auto l = aux3{laz};
   triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
+  return target;
  }
 
 #endif
 };
 
 // If type T has a mpi_implementation nested struct, then it is mpi_impl.
- template struct mpi_impl : mpi_impl_tuple {};
+ template struct mpi_impl : mpi_impl_tuple {};
+ template struct mpi_impl : mpi_impl_tuple {};
 }
 } // namespace
diff --git a/triqs/mpi/gf.hpp b/triqs/mpi/gf.hpp
new file mode 100644
index 00000000..f11a8c1e
--- /dev/null
+++ b/triqs/mpi/gf.hpp
@@ -0,0 +1,98 @@
+/*******************************************************************************
+ *
+ * TRIQS: a Toolbox for Research in Interacting Quantum Systems
+ *
+ * Copyright (C) 2014 by O. Parcollet
+ *
+ * TRIQS is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU General Public License as published by the Free Software
+ * Foundation, either version 3 of the License, or (at your option) any later
+ * version.
+ *
+ * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * TRIQS. If not, see .
+ *
+ ******************************************************************************/
+#pragma once
+#include "./base.hpp"
+#include
+
+namespace triqs {
+namespace mpi {
+
+ //--------------------------------------------------------------------------------------------------------
+
+ // When value_type is a basic type, we can directly call the C API
+ template struct mpi_impl_triqs_gfs {
+
+  //---------
+  static void reduce_in_place(communicator c, G &g, int root) {
+   triqs::mpi::reduce_in_place(c, g.data(), root);
+   triqs::mpi::reduce_in_place(c, g.singularity(), root);
+  }
+
+  //---------
+  /*static void allreduce_in_place(communicator c, G &g, int root) {
+   triqs::mpi::allreduce_in_place(c, g.data(), root);
+   triqs::mpi::allreduce_in_place(c, g.singularity(), root);
+  }
+*/
+
+  //---------
+  static void broadcast(communicator c, G &g, int root) {
+   triqs::mpi::broadcast(c, g.data(), root);
+   triqs::mpi::broadcast(c, g.singularity(), root);
+  }
+
+  //---------
+  template static mpi_lazy invoke(Tag, communicator c, G const &g, int root) {
+   return {g, root, c};
+  }
+
+  //---- reduce ----
+  static G &complete_operation(G &target, mpi_lazy laz) {
+   target._data = mpi::reduce(laz.ref.data(), laz.c, laz.root);
+   target._singularity = mpi::reduce(laz.ref.singularity(), laz.c, laz.root);
+   return target;
+  }
+
+  //---- allreduce ----
+  static G &complete_operation(G &target, mpi_lazy laz) {
+   target._data = mpi::allreduce(laz.ref.data(), laz.c, laz.root);
+   target._singularity = mpi::allreduce(laz.ref.singularity(), laz.c, laz.root);
+   return target;
+  }
+
+  //---- scatter ----
+  static G &complete_operation(G &target, mpi_lazy laz) {
+   target._mesh = mpi_scatter(laz.ref.mesh(), laz.c, laz.root);
+   target._data = mpi::scatter(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
+   target._singularity = laz.ref.singularity();
+   //mpi::broadcast(target._singularity, laz.c, laz.root);
+   return target;
+  }
+
+  //---- gather ----
+  static G &complete_operation(G &target, mpi_lazy laz) {
+   target._mesh = mpi_gather(laz.ref.mesh(), laz.c, laz.root);
+   target._data = mpi::gather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
+   // do nothing for singularity
+   return target;
+  }
+
+  //---- allgather ----
+  static G &complete_operation(G &target, mpi_lazy laz) {
+   target._data = mpi::allgather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
+   // do nothing for singularity
+   return target;
+  }
+
+ };
+
+} // mpi namespace
+} // namespace triqs
diff --git a/triqs/mpi/vector.hpp b/triqs/mpi/vector.hpp
index 3f1b1b9c..e1ae5007 100644
--- a/triqs/mpi/vector.hpp
+++ b/triqs/mpi/vector.hpp
@@ -64,11 +64,11 @@ namespace mpi {
   auto slow_size = a.size();
   auto sendcounts = std::vector(c.size());
   auto displs = std::vector(c.size() + 1, 0);
-  int recvcount = slice_length(slow_size - 1, c, c.rank());
+  int recvcount = slice_length(slow_size - 1, c.size(), c.rank());
 
   std::vector b(recvcount);
 
   for (int r = 0; r < c.size(); ++r) {
-   sendcounts[r] = slice_length(slow_size - 1, c, r);
+   sendcounts[r] = slice_length(slow_size - 1, c.size(), r);
    displs[r + 1] = sendcounts[r] + displs[r];
   }
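Note, not part of the patch: the new slice_range/slice_length helpers in triqs/mpi/base.hpp carry a "TODO RECHECK TEST" comment, and their comment states that when the range is not divisible into n_nodes equal parts the first nodes receive one extra element. The standalone program below copies that logic verbatim and prints the slices for 10 elements over 3 nodes (expected lengths 4, 3, 3); the main() driver and the printed labels are added here for illustration only.

// sketch: check the partition rule of slice_range / slice_length from the patch
#include <iostream>
#include <utility>

std::pair<long, long> slice_range(long first, long last, int n_nodes, int rank) {
  long chunk = (last - first + 1) / n_nodes;                  // size of the small chunks
  long n_large_nodes = (last - first + 1) - n_nodes * chunk;  // how many nodes get chunk + 1
  if (rank <= n_large_nodes - 1)                              // first, larger nodes
    return {first + rank * (chunk + 1), first + (rank + 1) * (chunk + 1) - 1};
  else                                                        // remaining nodes, shifted by n_large_nodes
    return {first + n_large_nodes + rank * chunk, first + n_large_nodes + (rank + 1) * chunk - 1};
}

long slice_length(long imax, int n_nodes, int rank) {
  auto r = slice_range(0, imax, n_nodes, rank);
  return r.second - r.first + 1;
}

int main() {
  // 10 elements (indices 0..9) split over 3 nodes: expect [0,3], [4,6], [7,9]
  for (int rank = 0; rank < 3; ++rank) {
    auto r = slice_range(0, 9, 3, rank);
    std::cout << "rank " << rank << ": [" << r.first << ", " << r.second
              << "], length " << slice_length(9, 3, rank) << "\n";
  }
}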
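A second note, also not part of the patch: the core idiom the patch extends is that mpi::reduce, mpi::scatter, etc. return a small mpi_lazy object holding {ref, root, communicator}, and the actual communication only happens when that object is assigned into a target (the new gf constructor and operator= taking mpi::mpi_lazy). The minimal sketch below illustrates that flow with local stand-in types; none of these names are the TRIQS declarations, and the "operation" is just a copy standing in for the MPI call.

// sketch of the lazy-proxy-assignment idiom used by the patch
#include <iostream>

struct communicator {}; // stand-in for triqs::mpi::communicator

template <typename Tag, typename T> struct mpi_lazy {
  T const &ref;
  int root;
  communicator c;
};

namespace tag { struct reduce {}; }

struct my_type {
  int value = 0;

  // assigning a lazy proxy performs the deferred operation
  template <typename Tag> my_type &operator=(mpi_lazy<Tag, my_type> laz) {
    value = laz.ref.value; // a real implementation would call MPI here
    std::cout << "complete_operation, root = " << laz.root << "\n";
    return *this;
  }
};

mpi_lazy<tag::reduce, my_type> reduce(my_type const &x, communicator c, int root = 0) {
  return {x, root, c}; // no communication yet: just capture the arguments
}

int main() {
  communicator world;
  my_type a; a.value = 42;
  my_type b;
  b = reduce(a, world); // the operation runs here, inside the assignment
  std::cout << b.value << "\n";
}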