/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2013 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_UTILITY_MPI_H
#define TRIQS_UTILITY_MPI_H
#include
#include
namespace triqs { namespace mpi {
using boost::mpi::communicator;
using boost::mpi::environment;
// transformation type -> mpi types
template struct mpi_datatype { static constexpr bool ok=false;};
#define D(T,MPI_TY) template <> struct mpi_datatype { static MPI_Datatype invoke() { return MPI_TY;}; static constexpr bool ok=true;};
D(int,MPI_INT) D(long,MPI_LONG) D(double,MPI_DOUBLE) D(float,MPI_FLOAT) D(std::complex, MPI_DOUBLE_COMPLEX);
#undef D
// Synchronization point: calls MPI_Barrier on the communicator _c.
// Declared inline because this non-template function is defined in a header: without
// `inline`, every translation unit including the header would emit its own definition
// and the program would fail to link (multiple definitions).
inline void barrier(communicator _c) { MPI_Barrier(_c); }
// Per-type implementation of the collective operations (reduce_in_place,
// all_reduce_in_place, broadcast), specialized below for various types.
// Declared here so the user-facing functions below can name it; the generic definition
// and the partial specializations follow later in this file.
// The defaulted second parameter is the SFINAE hook used by the partial specializations
// (selected with std::enable_if<condition>::type, which is void when the condition holds).
template <typename T, typename Enable = void> struct mpi_impl;
// ------------------------------
// the final functions for users
// ------------------------------
// reduce : first the in_place version
template void reduce_in_place(communicator _c, T & a, int root=0) { mpi_impl::reduce_in_place(_c,a,root); }
void reduce_in_place_v(communicator _c) {}
// try a variadic one. Does not cost much more to code...
template void reduce_in_place_v(communicator _c, T0 & a0, T& ... a) {
reduce_in_place(_c,a0,0);
reduce_in_place_v(_c, a...);
}
// reduce : the regular version in term of the in place one (accept views on the fly from b).
template void reduce (communicator _c, T & a, U && b, int root =0) {
b = a; reduce_in_place(_c,b,root);
}
// all_reduce : first the in_place version
template void all_reduce_in_place(communicator _c, T & a) { mpi_impl::reduce_in_place(_c,a,0); }
// all_reduce : the regular version in term of the in place one (accept views on the fly from b).
template void all_reduce (communicator _c, T & a, U && b) { b = a; reduce_in_place(_c,b); }
// BroadCast
template void broadcast(communicator _c, T & a, int root =0) { mpi_impl::broadcast(_c,a,root); }
// ----------------------------------------------------------------------
// the generic implementation : using serialization for recursive action
// ----------------------------------------------------------------------
template struct mpi_impl {
#define MAKE_ADAPTOR_AND_FNT(FNT)\
struct adaptor_##FNT {\
communicator _c; int root;\
template adaptor_##FNT & operator & (RHS & rhs) { mpi_impl::FNT(_c, rhs, root); return *this; }\
};\
static void FNT (communicator _c, T & a, int root) {\
auto ad = adaptor_##FNT{_c,root};\
serialize(ad, a);\
}
MAKE_ADAPTOR_AND_FNT(reduce_in_place);
MAKE_ADAPTOR_AND_FNT(all_reduce_in_place);
MAKE_ADAPTOR_AND_FNT(broadcast);
#undef MAKE_ADAPTOR_AND_FNT
};
// ------------------------------
// overload for basic types
// ------------------------------
template struct mpi_impl::value || boost::is_complex::value)> {
static void reduce_in_place (communicator _c, A & a, int root) {
MPI_Reduce ((_c.rank()==root ? MPI_IN_PLACE:&a),&a,1, mpi_datatype::invoke(), MPI_SUM, root, _c);
}
static void all_reduce_in_place (communicator _c, A & a, int root) {
MPI_Allreduce (&a,1, mpi_datatype::invoke(), MPI_SUM, _c);
}
static void broadcast (communicator _c, A & a, int root) { MPI_Bcast (&a,1, mpi_datatype::invoke(), root, _c); }
};
// ------------------------------
// a boost::mpi implementation
// ------------------------------
// Delegates to boost::mpi's collectives, combining values with std::c14::plus<>.
// Used below for arrays whose value_type has no registered MPI datatype.
// NOTE(review): `a` is passed as both the input and the output value of the reductions;
// confirm boost::mpi tolerates this aliasing for the value types used here.
template <typename A> struct boost_mpi_impl {
 static void reduce_in_place(communicator _c, A &a, int root) {
  boost::mpi::reduce(_c, a, a, std::c14::plus<>(), root);
 }

 static void all_reduce_in_place(communicator _c, A &a, int /*root*/) {
  // boost::mpi::all_reduce(comm, in, out, op) takes no root: every process receives the
  // result. Passing a root argument matches no overload.
  boost::mpi::all_reduce(_c, a, a, std::c14::plus<>());
 }

 static void broadcast(communicator _c, A &a, int root) { boost::mpi::broadcast(_c, a, root); }
};
// ------------------------------
// overload for arrays
// Stragey : if not contigous, we can i) revert to boost::mpi, ii) fail !??
// ------------------------------
// When value_type is a basic type, we can directly call the C API
template struct mpi_impl::ok && arrays::is_amv_value_or_view_class::value)> {
typedef typename A::value_type a_t;
static void reduce_in_place (communicator _c, A & a, int root) {
if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
auto p = a.data_start();
MPI_Reduce ((_c.rank()==root ? MPI_IN_PLACE:p),p,a.domain().number_of_elements(), mpi_datatype::invoke(), MPI_SUM, root, _c);
}
static void all_reduce_in_place (communicator _c, A & a, int root) {
if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
MPI_Allreduce (MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), mpi_datatype::invoke(), MPI_SUM, _c);
}
static void broadcast (communicator _c, A & a, int root) {
if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
MPI_Bcast (a.data_start(),a.domain().number_of_elements(), mpi_datatype::invoke(), root, _c);
}
};
// When value_type is NOT a basic type, we revert to boost::mpi
template struct mpi_impl::ok && arrays::is_amv_value_or_view_class::value)> : boost_mpi_impl{};
// overload for views rvalues (created on the fly)
// A temporary array_view cannot bind to the `T &` parameter of the generic user functions
// above; these overloads accept it as an rvalue and forward it, as a named (lvalue) view,
// to those functions. The operation therefore acts on the data the view refers to.
// NOTE(review): the template parameter lists / array_view template arguments are not
// visible in this copy — confirm they match the declaration of arrays::array_view.
template
void reduce_in_place( communicator _c, arrays::array_view && a, int root =0) { reduce_in_place(_c,a,root);}
// NOTE(review): the generic reduce above already takes `b` as a forwarding reference, so
// this overload looks redundant; kept for interface compatibility.
template
void reduce( communicator _c, A const & a, arrays::array_view && b, int root =0) { reduce(_c,a,b,root);}
// to be implemented : scatter, gather for arrays
}}
#endif