mirror of
https://github.com/triqs/dft_tools
synced 2024-12-25 05:43:40 +01:00
[mpi] draft of gf support
- done Matsubara freq for testing and rereading. - TODO: generalize to other meshes. - draft for multi var gf
This commit is contained in:
parent
ebbb2f0b25
commit
9c129cb224
@ -45,7 +45,10 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
A(i_, j_) << i_ + 10 * j_;
|
||||
|
||||
//std::cerr << "B0 "<< B <<std::endl;
|
||||
B = mpi::scatter(A, world);
|
||||
std::cerr << "B "<< B <<std::endl;
|
||||
|
||||
ARR C = mpi::scatter(A, world);
|
||||
|
||||
std::ofstream out("node" + std::to_string(world.rank()));
|
||||
@ -66,5 +69,14 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
AA = mpi::allgather(B, world);
|
||||
out << " AA = " << AA << std::endl;
|
||||
|
||||
ARR r1 = mpi::reduce(A, world);
|
||||
out <<" Reduce "<< std::endl;
|
||||
out << " r1 = " << r1 << std::endl;
|
||||
|
||||
ARR r2 = mpi::allreduce(A, world);
|
||||
out <<" AllReduce "<< std::endl;
|
||||
out << " r2 = " << r2 << std::endl;
|
||||
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ struct my_object {
|
||||
|
||||
// assigment is almost done already...
|
||||
template <typename Tag> my_object &operator=(mpi_lazy<Tag, my_object> x) {
|
||||
return mpi_impl_tuple<my_object>::complete_operation(*this, x);
|
||||
return mpi_impl<my_object>::complete_operation(*this, x);
|
||||
}
|
||||
};
|
||||
|
||||
|
104
test/triqs/mpi/mpi_gf.cpp
Normal file
104
test/triqs/mpi/mpi_gf.cpp
Normal file
@ -0,0 +1,104 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
|
||||
*
|
||||
* Copyright (C) 2013 by O. Parcollet
|
||||
*
|
||||
* TRIQS is free software: you can redistribute it and/or modify it under the
|
||||
* terms of the GNU General Public License as published by the Free Software
|
||||
* Foundation, either version 3 of the License, or (at your option) any later
|
||||
* version.
|
||||
*
|
||||
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
|
||||
* details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along with
|
||||
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
|
||||
*
|
||||
******************************************************************************/
|
||||
#define TRIQS_ARRAYS_ENFORCE_BOUNDCHECK
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <triqs/gfs.hpp>
|
||||
#include <triqs/mpi.hpp>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
using namespace triqs;
|
||||
using namespace triqs::arrays;
|
||||
using namespace triqs::gfs;
|
||||
using namespace triqs::clef;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
mpi::environment env(argc, argv);
|
||||
mpi::communicator world;
|
||||
|
||||
std::ofstream out("node" + std::to_string(world.rank()));
|
||||
|
||||
double beta = 10;
|
||||
int Nfreq = 8;
|
||||
placeholder<0> w_;
|
||||
|
||||
auto g1 = gf<imfreq>{{beta, Fermion, Nfreq}, {1, 1}}; // using ARR = array<double,2>;
|
||||
g1(w_) << 1 / (w_ + 1);
|
||||
|
||||
out << "g1.data" << g1.data() << std::endl;
|
||||
|
||||
{
|
||||
out<< "reduction "<< std::endl;
|
||||
gf<imfreq> g2 = mpi::reduce(g1, world);
|
||||
out << g2.data()<<std::endl;
|
||||
out << g2.singularity() << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
out<< "all reduction "<< std::endl;
|
||||
gf<imfreq> g2 = mpi::allreduce(g1, world);
|
||||
out << g2.data()<<std::endl;
|
||||
out << g2.singularity() << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
out << "scatter-gather test with =" << std::endl;
|
||||
auto g2 = g1;
|
||||
auto g2b = g1;
|
||||
|
||||
g2 = mpi::scatter(g1);
|
||||
g2(w_) << g2(w_) * (1 + world.rank());
|
||||
g2b = mpi::gather(g2);
|
||||
|
||||
out << g2b.data() << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
out << "scatter-allgather test with construction" << std::endl;
|
||||
|
||||
gf<imfreq> g2 = mpi::scatter(g1);
|
||||
g2(w_) << g2(w_) * (1 + world.rank());
|
||||
g1 = mpi::allgather(g2);
|
||||
|
||||
out << g1.data() << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
out << "Building directly scattered, and gather" << std::endl;
|
||||
auto m = mpi_scatter(gf_mesh<imfreq>{beta, Fermion, Nfreq}, world, 0);
|
||||
auto g3 = gf<imfreq>{m, {1, 1}};
|
||||
g3(w_) << 1 / (w_ + 1);
|
||||
auto g4 = g3;
|
||||
out<< "chunk ..."<<std::endl;
|
||||
out << g3.data() << std::endl;
|
||||
out<< "gather"<<std::endl;
|
||||
g4 = mpi::gather(g3);
|
||||
out << g4.data() << std::endl;
|
||||
out<< "allgather"<<std::endl;
|
||||
g4 = mpi::allgather(g3);
|
||||
out << g4.data() << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <triqs/utility/tuple_tools.hpp>
|
||||
#include <triqs/utility/c14.hpp>
|
||||
#include <triqs/arrays/h5.hpp>
|
||||
#include <triqs/mpi/gf.hpp>
|
||||
#include <vector>
|
||||
#include "./tools.hpp"
|
||||
#include "./data_proxies.hpp"
|
||||
@ -425,6 +426,9 @@ namespace gfs {
|
||||
*this = x;
|
||||
}
|
||||
|
||||
// mpi lazy
|
||||
template <typename Tag> gf(mpi::mpi_lazy<Tag, gf> x) : gf() { operator=(x); }
|
||||
|
||||
gf(typename B::mesh_t m, typename B::data_t dat, typename B::singularity_view_t const &si, typename B::symmetry_t const &s,
|
||||
typename B::indices_t const &ind, std::string name = "")
|
||||
: B(std::move(m), std::move(dat), si, s, ind, name, typename B::evaluator_t{}) {}
|
||||
@ -453,6 +457,13 @@ namespace gfs {
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend struct mpi::mpi_impl_triqs_gfs<gf>; //allowed to modify mesh
|
||||
|
||||
//
|
||||
template <typename Tag> void operator=(mpi::mpi_lazy<Tag, gf> x) {
|
||||
mpi::mpi_impl_triqs_gfs<gf>::complete_operation(*this, x);
|
||||
}
|
||||
|
||||
template <typename RHS> void operator=(RHS &&rhs) {
|
||||
this->_mesh = rhs.mesh();
|
||||
this->_data.resize(get_gf_data_shape(rhs));
|
||||
@ -841,6 +852,17 @@ namespace gfs {
|
||||
};
|
||||
} // gfs_implementation
|
||||
}
|
||||
|
||||
namespace mpi {
|
||||
|
||||
template <typename Variable, typename Target, typename Opt>
|
||||
struct mpi_impl<gfs::gf<Variable, Target, Opt>, void> : mpi_impl_triqs_gfs<gfs::gf<Variable, Target, Opt>> {};
|
||||
|
||||
template <typename Variable, typename Target, typename Opt, bool IsConst>
|
||||
struct mpi_impl<gfs::gf_view<Variable, Target, Opt, IsConst>, void> : mpi_impl_triqs_gfs<gfs::gf_view<Variable, Target, Opt, IsConst>> {};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// same as for arrays : views cannot be swapped by the std::swap. Delete it
|
||||
|
@ -18,11 +18,12 @@
|
||||
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
|
||||
*
|
||||
******************************************************************************/
|
||||
#ifndef TRIQS_GF_LOCAL_TAIL_H
|
||||
#define TRIQS_GF_LOCAL_TAIL_H
|
||||
#pragma once
|
||||
#include <triqs/arrays.hpp>
|
||||
#include <triqs/arrays/algorithms.hpp>
|
||||
#include <triqs/gfs/tools.hpp>
|
||||
#include <triqs/mpi/boost.hpp>
|
||||
#include <boost/serialization/complex.hpp>
|
||||
|
||||
namespace triqs { namespace gfs { namespace local {
|
||||
|
||||
@ -50,6 +51,7 @@ namespace triqs { namespace gfs { namespace local {
|
||||
/// A common implementation class. Idiom: ValueView
|
||||
template<bool IsView> class tail_impl {
|
||||
public:
|
||||
TRIQS_MPI_IMPLEMENTED_VIA_BOOST;
|
||||
typedef tail_view view_type;
|
||||
typedef tail regular_type;
|
||||
|
||||
@ -171,8 +173,9 @@ namespace triqs { namespace gfs { namespace local {
|
||||
}
|
||||
|
||||
friend std::ostream & operator << (std::ostream & out, tail_impl const & x) {
|
||||
if (x.data().is_empty()) return out << "empty tail"<<std::endl;
|
||||
out <<"tail/tail_view: min/smallest/max = "<< x.order_min() << " " << x.smallest_nonzero() << " "<< x.order_max();
|
||||
for (long u = x.order_min(); u <= x.order_max(); ++u) out <<"\n ... Order "<<u << " = " << x(u);
|
||||
for (long u = x.order_min(); u <= x.order_max(); ++u) out <<"\n ... Order "<<u << " = " << x(u);
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -283,7 +286,7 @@ namespace triqs { namespace gfs { namespace local {
|
||||
inline tail transpose(tail_view t) { return {transposed_view(t.data(),0,2,1), transposed_view(t.mask_view(),1,0),t.order_min()};}
|
||||
|
||||
/// Slice in orbital space
|
||||
//template<bool V> tail_view slice_target(tail_impl<V> const & t, tqa::range R1, tqa::range R2) {
|
||||
//template<bool V> tail_view slice_target(tail_impl<V> const & t, tqa::range R1, tqa::range R2)
|
||||
inline tail_view slice_target(tail_view t, tqa::range R1, tqa::range R2) {
|
||||
return tail_view(t.data()(tqa::range(),R1,R2), t.mask_view()(R1,R2), t.order_min());
|
||||
}
|
||||
@ -407,7 +410,5 @@ namespace triqs { namespace gfs { namespace local {
|
||||
|
||||
#undef DEFINE_OPERATOR
|
||||
|
||||
|
||||
|
||||
}}}
|
||||
#endif
|
||||
}}
|
||||
}
|
||||
|
@ -35,11 +35,21 @@ namespace gfs {
|
||||
using domain_pt_t = typename domain_t::point_t;
|
||||
|
||||
/// Constructor
|
||||
matsubara_freq_mesh() : _dom(), _n_pts(0), _positive_only(true) {}
|
||||
matsubara_freq_mesh(domain_t dom, long n_pts = 1025, bool positive_only = true)
|
||||
: _dom(std::move(dom)), _n_pts(n_pts), _positive_only(positive_only) {
|
||||
if (_positive_only) {
|
||||
_first_index = 0;
|
||||
_last_index = n_pts - 1; // CORRECTION
|
||||
} else {
|
||||
_last_index = (_n_pts - (_dom.statistic == Boson ? 1 : 2)) / 2;
|
||||
_first_index = -(_last_index + (_dom.statistic == Fermion));
|
||||
}
|
||||
_first_index_window = _first_index;
|
||||
_last_index_window = _last_index;
|
||||
}
|
||||
|
||||
/// Constructor
|
||||
matsubara_freq_mesh(domain_t dom, int n_pts=1025, bool positive_only = true)
|
||||
: _dom(std::move(dom)), _n_pts(n_pts), _positive_only(positive_only) {}
|
||||
matsubara_freq_mesh() : matsubara_freq_mesh(domain_t(), 0, true){}
|
||||
|
||||
/// Constructor
|
||||
matsubara_freq_mesh(double beta, statistic_enum S, int n_pts = 1025, bool positive_only = true)
|
||||
@ -48,6 +58,17 @@ namespace gfs {
|
||||
/// Copy constructor
|
||||
matsubara_freq_mesh(matsubara_freq_mesh const &) = default;
|
||||
|
||||
/// Scatter a mesh over the communicator c
|
||||
friend matsubara_freq_mesh mpi_scatter(matsubara_freq_mesh m, mpi::communicator c, int root) {
|
||||
auto m2 = matsubara_freq_mesh{m.domain(), m.size(), m.positive_only()};
|
||||
std::tie(m2._first_index_window, m2._last_index_window) = mpi::slice_range(m2._first_index, m2._last_index, c.size(), c.rank());
|
||||
return m2;
|
||||
}
|
||||
|
||||
friend matsubara_freq_mesh mpi_gather(matsubara_freq_mesh m, mpi::communicator c, int root) {
|
||||
return matsubara_freq_mesh{m.domain(), m.size(), m.positive_only()};
|
||||
}
|
||||
|
||||
/// The corresponding domain
|
||||
domain_t const &domain() const { return _dom; }
|
||||
|
||||
@ -60,20 +81,29 @@ namespace gfs {
|
||||
**/
|
||||
|
||||
/// last Matsubara index
|
||||
int last_index() const { return (_positive_only ? _n_pts : (_n_pts - (_dom.statistic == Boson ? 1 : 2))/2);}
|
||||
int last_index() const { return _last_index;}
|
||||
|
||||
/// first Matsubara index
|
||||
int first_index() const { return -(_positive_only ? 0 : last_index() + (_dom.statistic == Fermion)); }
|
||||
int first_index() const { return _first_index;}
|
||||
|
||||
/// last Matsubara index of the window
|
||||
int last_index_window() const { return _last_index_window;}
|
||||
|
||||
/// first Matsubara index of the window
|
||||
int first_index_window() const { return _first_index_window;}
|
||||
|
||||
/// Size (linear) of the mesh
|
||||
long size() const { return _n_pts;}
|
||||
//long size() const { return _n_pts;}
|
||||
|
||||
/// Size (linear) of the mesh of the window
|
||||
long size() const { return _last_index_window - _first_index_window + 1; }
|
||||
|
||||
/// From an index of a point in the mesh, returns the corresponding point in the domain
|
||||
domain_pt_t index_to_point(index_t ind) const { return 1_j * M_PI * (2 * ind + (_dom.statistic == Fermion)) / _dom.beta; }
|
||||
|
||||
/// Flatten the index in the positive linear index for memory storage (almost trivial here).
|
||||
long index_to_linear(index_t ind) const { return ind - first_index(); }
|
||||
index_t linear_to_index(long lind) const { return lind + first_index(); }
|
||||
long index_to_linear(index_t ind) const { return ind - first_index_window(); }
|
||||
index_t linear_to_index(long lind) const { return lind + first_index_window(); }
|
||||
|
||||
/// Is the mesh only for positive omega_n (G(tau) real))
|
||||
bool positive_only() const { return _positive_only;}
|
||||
@ -86,18 +116,18 @@ namespace gfs {
|
||||
struct mesh_point_t : tag::mesh_point, matsubara_freq {
|
||||
mesh_point_t() = default;
|
||||
mesh_point_t(matsubara_freq_mesh const &mesh, index_t const &index_)
|
||||
: matsubara_freq(index_, mesh.domain().beta, mesh.domain().statistic),
|
||||
first_index(mesh.first_index()),
|
||||
index_stop(mesh.first_index() + mesh.size() - 1) {}
|
||||
mesh_point_t(matsubara_freq_mesh const &mesh) : mesh_point_t(mesh, mesh.first_index()) {}
|
||||
: matsubara_freq(index_, mesh.domain().beta, mesh.domain().statistic)
|
||||
, first_index_window(mesh.first_index_window())
|
||||
, last_index_window(mesh.last_index_window()) {}
|
||||
mesh_point_t(matsubara_freq_mesh const &mesh) : mesh_point_t(mesh, mesh.first_index_window()) {}
|
||||
void advance() { ++n; }
|
||||
long linear_index() const { return n - first_index; }
|
||||
long linear_index() const { return n - first_index_window; }
|
||||
long index() const { return n; }
|
||||
bool at_end() const { return (n == index_stop + 1); } // at_end means " one after the last one", as in STL
|
||||
void reset() { n = first_index; }
|
||||
bool at_end() const { return (n == last_index_window + 1); } // at_end means " one after the last one", as in STL
|
||||
void reset() { n = first_index_window; }
|
||||
|
||||
private:
|
||||
index_t first_index, index_stop;
|
||||
index_t first_index_window, last_index_window;
|
||||
};
|
||||
|
||||
/// Accessing a point of the mesh from its index
|
||||
@ -164,6 +194,7 @@ namespace gfs {
|
||||
domain_t _dom;
|
||||
int _n_pts;
|
||||
bool _positive_only;
|
||||
long _first_index, _last_index, _first_index_window, _last_index_window;
|
||||
};
|
||||
|
||||
//-------------------------------------------------------
|
||||
|
@ -39,6 +39,7 @@ namespace gfs {
|
||||
|
||||
mesh_product() {}
|
||||
mesh_product(Meshes const &... meshes) : m_tuple(meshes...), _dom(meshes.domain()...) {}
|
||||
mesh_product(mesh_product const &) = default;
|
||||
|
||||
domain_t const &domain() const { return _dom; }
|
||||
m_tuple_t const &components() const { return m_tuple; }
|
||||
@ -49,6 +50,20 @@ namespace gfs {
|
||||
return triqs::tuple::fold([](auto const &m, size_t R) { return R * m.size(); }, m_tuple, 1);
|
||||
}
|
||||
|
||||
/// Scatter the first mesh over the communicator c
|
||||
friend mesh_product mpi_scatter(mesh_product const &m, mpi::communicator c, int root) {
|
||||
auto r = m; // same domain, but mesh with a window. Ok ?
|
||||
std::get<0>(r.m_tuple) = mpi_scatter(std::get<0>(r.m_tuple), c, root);
|
||||
return r;
|
||||
}
|
||||
|
||||
/// Opposite of scatter : rebuild the original mesh, without a window
|
||||
friend matsubara_freq_mesh mpi_gather(matsubara_freq_mesh m, mpi::communicator c, int root) {
|
||||
auto r = m; // same domain, but mesh with a window. Ok ?
|
||||
std::get<0>(r.m_tuple) = mpi_gather(std::get<0>(r.m_tuple), c, root);
|
||||
return r;
|
||||
}
|
||||
|
||||
/// Conversions point <-> index <-> linear_index
|
||||
typename domain_t::point_t index_to_point(index_t const &ind) const {
|
||||
domain_pt_t res;
|
||||
|
@ -40,8 +40,12 @@ namespace mpi {
|
||||
auto dims = ref.shape();
|
||||
long slow_size = first_dim(ref);
|
||||
|
||||
if (std::is_same<Tag, tag::reduce>::value) {
|
||||
// optionally check all dims are the same ?
|
||||
}
|
||||
|
||||
if (std::is_same<Tag, tag::scatter>::value) {
|
||||
dims[0] = slice_length(slow_size - 1, c, c.rank());
|
||||
dims[0] = mpi::slice_length(slow_size - 1, c.size(), c.rank());
|
||||
}
|
||||
|
||||
if (std::is_same<Tag, tag::gather>::value) {
|
||||
@ -87,7 +91,7 @@ namespace mpi {
|
||||
static void allreduce_in_place(communicator c, A &a, int root) {
|
||||
check_is_contiguous(a);
|
||||
// assume arrays have the same size on all nodes...
|
||||
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
|
||||
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, c.get());
|
||||
}
|
||||
|
||||
//---------
|
||||
@ -137,6 +141,18 @@ namespace arrays {
|
||||
private:
|
||||
static MPI_Datatype D() { return mpi::mpi_datatype<typename A::value_type>::invoke(); }
|
||||
|
||||
//---------------------------------
|
||||
void _invoke(triqs::mpi::tag::reduce) {
|
||||
lhs.resize(laz.domain());
|
||||
MPI_Reduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.root, laz.c.get());
|
||||
}
|
||||
|
||||
//---------------------------------
|
||||
void _invoke(triqs::mpi::tag::allreduce) {
|
||||
lhs.resize(laz.domain());
|
||||
MPI_Allreduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.c.get());
|
||||
}
|
||||
|
||||
//---------------------------------
|
||||
void _invoke(triqs::mpi::tag::scatter) {
|
||||
lhs.resize(laz.domain());
|
||||
@ -146,10 +162,10 @@ namespace arrays {
|
||||
auto slow_stride = laz.ref.indexmap().strides()[0];
|
||||
auto sendcounts = std::vector<int>(c.size());
|
||||
auto displs = std::vector<int>(c.size() + 1, 0);
|
||||
int recvcount = slice_length(slow_size - 1, c, c.rank()) * slow_stride;
|
||||
int recvcount = mpi::slice_length(slow_size - 1, c.size(), c.rank()) * slow_stride;
|
||||
|
||||
for (int r = 0; r < c.size(); ++r) {
|
||||
sendcounts[r] = slice_length(slow_size - 1, c, r) * slow_stride;
|
||||
sendcounts[r] = mpi::slice_length(slow_size - 1, c.size(), r) * slow_stride;
|
||||
displs[r + 1] = sendcounts[r] + displs[r];
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,12 @@
|
||||
//#include <triqs/utility/tuple_tools.hpp>
|
||||
#include <mpi.h>
|
||||
|
||||
namespace boost { // forward declare in case we do not include boost.
|
||||
namespace mpi {
|
||||
class communicator;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triqs {
|
||||
namespace mpi {
|
||||
|
||||
@ -41,6 +47,11 @@ namespace mpi {
|
||||
|
||||
MPI_Comm get() const { return _com; }
|
||||
|
||||
inline communicator(boost::mpi::communicator);
|
||||
|
||||
/// Cast to the boost mpi communicator
|
||||
inline operator boost::mpi::communicator () const;
|
||||
|
||||
int rank() const {
|
||||
int num;
|
||||
MPI_Comm_rank(_com, &num);
|
||||
@ -68,6 +79,13 @@ namespace mpi {
|
||||
/// The implementation of mpi ops for each type
|
||||
template <typename T, typename Enable = void> struct mpi_impl;
|
||||
|
||||
/// A small lazy tagged class
|
||||
template <typename Tag, typename T> struct mpi_lazy {
|
||||
T const &ref;
|
||||
int root;
|
||||
communicator c;
|
||||
};
|
||||
|
||||
// ----------------------------------------
|
||||
// ------- top level functions -------
|
||||
// ----------------------------------------
|
||||
@ -136,6 +154,26 @@ namespace mpi {
|
||||
struct mpi_impl<T, std14::enable_if_t<std::is_arithmetic<T>::value || triqs::is_complex<T>::value>> : mpi_impl_basic<T> {};
|
||||
|
||||
//------------ Some helper function
|
||||
|
||||
// Given a range [first, last], slice it regularly for a node of rank 'rank' among n_nodes.
|
||||
// If the range is not dividable in n_nodes equal parts,
|
||||
// the first nodes have one more elements than the last ones.
|
||||
inline std::pair<long, long> slice_range(long first, long last, int n_nodes, int rank) {
|
||||
long chunk = (last - first + 1) / n_nodes;
|
||||
long n_large_nodes = (last - first + 1) - n_nodes * chunk;
|
||||
if (rank <= n_large_nodes - 1) // first, larger nodes, use chunk + 1
|
||||
return {first + rank * (chunk + 1), first + (rank + 1) * (chunk + 1) - 1};
|
||||
else // others nodes : shift the first by 1*n_large_nodes, used chunk
|
||||
return {first + n_large_nodes + rank * chunk, first + n_large_nodes + (rank + 1) * chunk - 1};
|
||||
}
|
||||
|
||||
// TODO RECHECK TEST
|
||||
inline long slice_length(long imax, int n_nodes, int rank) {
|
||||
auto r = slice_range(0, imax, n_nodes, rank);
|
||||
return r.second - r.first + 1;
|
||||
}
|
||||
|
||||
/*
|
||||
inline long slice_length(size_t imax, communicator c, int r) {
|
||||
auto imin = 0;
|
||||
long j = (imax - imin + 1) / c.size();
|
||||
@ -143,6 +181,7 @@ namespace mpi {
|
||||
auto r_min = (r <= i - 1 ? imin + r * (j + 1) : imin + r * j + i);
|
||||
auto r_max = (r <= i - 1 ? imin + (r + 1) * (j + 1) - 1 : imin + (r + 1) * j + i - 1);
|
||||
return r_max - r_min + 1;
|
||||
};
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
@ -22,9 +22,20 @@
|
||||
#include "./base.hpp"
|
||||
#include <boost/mpi.hpp>
|
||||
|
||||
#define TRIQS_MPI_IMPLEMENTED_VIA_BOOST using triqs_mpi_via_boost = void;
|
||||
|
||||
namespace triqs {
|
||||
namespace mpi {
|
||||
|
||||
// implement the communicator cast
|
||||
inline communicator::operator boost::mpi::communicator() const {
|
||||
return boost::mpi::communicator(_com, boost::mpi::comm_duplicate);
|
||||
// duplicate policy : cf http://www.boost.org/doc/libs/1_56_0/doc/html/boost/mpi/comm_create_kind.html
|
||||
}
|
||||
|
||||
// reverse : construct (implicit) the communicator from the boost one.
|
||||
inline communicator::communicator(boost::mpi::communicator c) :_com(c) {}
|
||||
|
||||
/** ------------------------------------------------------------
|
||||
* Type which we use boost::mpi
|
||||
* ---------------------------------------------------------- **/
|
||||
@ -39,7 +50,7 @@ namespace mpi {
|
||||
|
||||
static T invoke(tag::allreduce, communicator c, T const &a, int root) {
|
||||
T b;
|
||||
boost::mpi::all_reduce(c, a, b, std::c14::plus<>(), root);
|
||||
boost::mpi::all_reduce(c, a, b, std::c14::plus<>());
|
||||
return b;
|
||||
}
|
||||
|
||||
@ -51,8 +62,8 @@ namespace mpi {
|
||||
static void allgather(communicator c, T const &, int root) = delete;
|
||||
};
|
||||
|
||||
// default
|
||||
//template <typename T> struct mpi_impl<T> : mpi_impl_boost_mpi<T> {};
|
||||
// If type T has a mpi_implementation nested struct, then it is mpi_impl<T>.
|
||||
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_via_boost> : mpi_impl_boost_mpi<T> {};
|
||||
|
||||
}}//namespace
|
||||
|
||||
|
@ -23,26 +23,31 @@
|
||||
#include <triqs/utility/tuple_tools.hpp>
|
||||
|
||||
#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW using triqs_mpi_as_tuple = void;
|
||||
#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW_NO_LAZY using triqs_mpi_as_tuple_no_lazy = void;
|
||||
namespace triqs {
|
||||
namespace mpi {
|
||||
|
||||
template <typename Tag, typename T> struct mpi_lazy {
|
||||
T const &ref;
|
||||
int root;
|
||||
communicator c;
|
||||
};
|
||||
|
||||
/** ------------------------------------------------------------
|
||||
* Type which are recursively treated by reducing them to a tuple
|
||||
* of smaller objects.
|
||||
* ---------------------------------------------------------- **/
|
||||
template <typename T> struct mpi_impl_tuple {
|
||||
template <typename T, bool with_lazy> struct mpi_impl_tuple {
|
||||
|
||||
mpi_impl_tuple() = default;
|
||||
template <typename Tag> static mpi_lazy<Tag, T> invoke(Tag, communicator c, T const &a, int root) {
|
||||
|
||||
/// invoke
|
||||
template <typename Tag> static mpi_lazy<Tag, T> invoke_impl(std::true_type, Tag, communicator c, T const &a, int root) {
|
||||
return {a, root, c};
|
||||
}
|
||||
|
||||
template <typename Tag> static T &invoke_impl(std::false_type, Tag, communicator c, T const &a, int root) {
|
||||
return complete_operation(a, {a, root, c});
|
||||
}
|
||||
|
||||
template <typename Tag> static mpi_lazy<Tag, T> invoke(Tag, communicator c, T const &a, int root) {
|
||||
return invoke_impl(std::integral_constant<bool, with_lazy>(), Tag(), c, a, root);
|
||||
}
|
||||
|
||||
#ifdef __cpp_generic_lambdas
|
||||
static void reduce_in_place(communicator c, T &a, int root) {
|
||||
tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::reduce_in_place(x, c, root); });
|
||||
@ -57,6 +62,7 @@ namespace mpi {
|
||||
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
|
||||
return target;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
struct aux1 {
|
||||
@ -89,15 +95,17 @@ namespace mpi {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tag> static void complete_operation(T &target, mpi_lazy<Tag, T> laz) {
|
||||
template <typename Tag> static T& complete_operation(T &target, mpi_lazy<Tag, T> laz) {
|
||||
auto l = aux3<Tag>{laz};
|
||||
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
|
||||
return target;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// If type T has a mpi_implementation nested struct, then it is mpi_impl<T>.
|
||||
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple> : mpi_impl_tuple<T> {};
|
||||
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple> : mpi_impl_tuple<T, true> {};
|
||||
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple_no_lazy> : mpi_impl_tuple<T, false> {};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
98
triqs/mpi/gf.hpp
Normal file
98
triqs/mpi/gf.hpp
Normal file
@ -0,0 +1,98 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
|
||||
*
|
||||
* Copyright (C) 2014 by O. Parcollet
|
||||
*
|
||||
* TRIQS is free software: you can redistribute it and/or modify it under the
|
||||
* terms of the GNU General Public License as published by the Free Software
|
||||
* Foundation, either version 3 of the License, or (at your option) any later
|
||||
* version.
|
||||
*
|
||||
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
|
||||
* details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along with
|
||||
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
|
||||
*
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
#include "./base.hpp"
|
||||
#include <triqs/mpi/generic.hpp>
|
||||
|
||||
namespace triqs {
|
||||
namespace mpi {
|
||||
|
||||
//--------------------------------------------------------------------------------------------------------
|
||||
|
||||
// When value_type is a basic type, we can directly call the C API
|
||||
template <typename G> struct mpi_impl_triqs_gfs {
|
||||
|
||||
//---------
|
||||
static void reduce_in_place(communicator c, G &g, int root) {
|
||||
triqs::mpi::reduce_in_place(c, g.data(), root);
|
||||
triqs::mpi::reduce_in_place(c, g.singularity(), root);
|
||||
}
|
||||
|
||||
//---------
|
||||
/*static void allreduce_in_place(communicator c, G &g, int root) {
|
||||
triqs::mpi::allreduce_in_place(c, g.data(), root);
|
||||
triqs::mpi::allreduce_in_place(c, g.singularity(), root);
|
||||
}
|
||||
*/
|
||||
|
||||
//---------
|
||||
static void broadcast(communicator c, G &g, int root) {
|
||||
triqs::mpi::broadcast(c, g.data(), root);
|
||||
triqs::mpi::broadcast(c, g.singularity(), root);
|
||||
}
|
||||
|
||||
//---------
|
||||
template <typename Tag> static mpi_lazy<Tag, G> invoke(Tag, communicator c, G const &g, int root) {
|
||||
return {g, root, c};
|
||||
}
|
||||
|
||||
//---- reduce ----
|
||||
static G &complete_operation(G &target, mpi_lazy<tag::reduce, G> laz) {
|
||||
target._data = mpi::reduce(laz.ref.data(), laz.c, laz.root);
|
||||
target._singularity = mpi::reduce(laz.ref.singularity(), laz.c, laz.root);
|
||||
return target;
|
||||
}
|
||||
|
||||
//---- allreduce ----
|
||||
static G &complete_operation(G &target, mpi_lazy<tag::allreduce, G> laz) {
|
||||
target._data = mpi::allreduce(laz.ref.data(), laz.c, laz.root);
|
||||
target._singularity = mpi::allreduce(laz.ref.singularity(), laz.c, laz.root);
|
||||
return target;
|
||||
}
|
||||
|
||||
//---- scatter ----
|
||||
static G &complete_operation(G &target, mpi_lazy<tag::scatter, G> laz) {
|
||||
target._mesh = mpi_scatter(laz.ref.mesh(), laz.c, laz.root);
|
||||
target._data = mpi::scatter(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
|
||||
target._singularity = laz.ref.singularity();
|
||||
//mpi::broadcast(target._singularity, laz.c, laz.root);
|
||||
return target;
|
||||
}
|
||||
|
||||
//---- gather ----
|
||||
static G &complete_operation(G &target, mpi_lazy<tag::gather, G> laz) {
|
||||
target._mesh = mpi_gather(laz.ref.mesh(), laz.c, laz.root);
|
||||
target._data = mpi::gather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
|
||||
// do nothing for singularity
|
||||
return target;
|
||||
}
|
||||
|
||||
//---- allgather ----
|
||||
static G &complete_operation(G &target, mpi_lazy<tag::allgather, G> laz) {
|
||||
target._data = mpi::allgather(laz.ref.data(), laz.c, laz.root); // HERE ADD OPTION FOR CHUNCK
|
||||
// do nothing for singularity
|
||||
return target;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // mpi namespace
|
||||
} // namespace triqs
|
@ -64,11 +64,11 @@ namespace mpi {
|
||||
auto slow_size = a.size();
|
||||
auto sendcounts = std::vector<int>(c.size());
|
||||
auto displs = std::vector<int>(c.size() + 1, 0);
|
||||
int recvcount = slice_length(slow_size - 1, c, c.rank());
|
||||
int recvcount = slice_length(slow_size - 1, c.size(), c.rank());
|
||||
std::vector<T> b(recvcount);
|
||||
|
||||
for (int r = 0; r < c.size(); ++r) {
|
||||
sendcounts[r] = slice_length(slow_size - 1, c, r);
|
||||
sendcounts[r] = slice_length(slow_size - 1, c.size(), r);
|
||||
displs[r + 1] = sendcounts[r] + displs[r];
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user