diff --git a/test/triqs/utility/mpi.cpp b/test/triqs/utility/mpi.cpp
new file mode 100644
index 00000000..61fb34e3
--- /dev/null
+++ b/test/triqs/utility/mpi.cpp
@@ -0,0 +1,67 @@
+/*******************************************************************************
+ *
+ * TRIQS: a Toolbox for Research in Interacting Quantum Systems
+ *
+ * Copyright (C) 2013 by O. Parcollet
+ *
+ * TRIQS is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU General Public License as published by the Free Software
+ * Foundation, either version 3 of the License, or (at your option) any later
+ * version.
+ *
+ * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * TRIQS. If not, see <http://www.gnu.org/licenses/>.
+ *
+ ******************************************************************************/
+#include <triqs/utility/mpi.hpp>
+#include <triqs/arrays/array.hpp>
+#include <triqs/arrays/asserts.hpp>
+#include <boost/mpi.hpp>
+#include <iostream>
+#include <sstream>
+
+using namespace triqs;
+using namespace triqs::arrays;
+using namespace triqs::mpi;
+
+// a struct containing arrays. I only allow here boost serialization form, no mpi stuff here.
+struct S {
+ array<double,2> x; array<double,1> y;
+ template<class Ar> friend void serialize( Ar & ar, S & s) { ar & s.x; ar & s.y;}
+};
+
+int main(int argc, char* argv[]) {
+
+ mpi::environment env(argc, argv);
+ mpi::communicator world;
+
+ array<long,2> A {{1,2}, {3,4}}, C(2,2);
+
+ // boost mpi : reference behaviour. Reduce A over all ranks into C on root 0.
+ boost::mpi::reduce (world, A, C, std::c14::plus<>(), 0);
+ int s = world.size();
+ if (world.rank() ==0) std::cout<<" C = "<<C<<" should be "<< array<long,2>(s*A) <<std::endl;
+
+ // triqs mpi : regular reduce (result written into C)
+ C() = 0;
+ mpi::reduce (world, A, C);
+ if (world.rank() ==0) std::cout<<" C = "<<C<<" should be "<< array<long,2>(s*A) <<std::endl;
+
+ // triqs mpi : in-place reduce (C is both input and output on root)
+ C = A;
+ mpi::reduce_in_place (world, C);
+ if (world.rank() ==0) std::cout<<" C = "<<C<<" should be "<< array<long,2>(s*A) <<std::endl;
+
+ // an array of arrays : exercises the generic (serialization-based) implementation
+ auto ca = array< array<int,1>, 1 > { array<int,1>{1,2}, {3,4}};
+ auto cC = ca;
+ mpi::reduce_in_place (world, cC);
+ if (world.rank() ==0) std::cout<<" cC = "<<cC<<std::endl;
+
+ return 0;
+}
diff --git a/triqs/utility/mpi.hpp b/triqs/utility/mpi.hpp
new file mode 100644
--- /dev/null
+++ b/triqs/utility/mpi.hpp
@@ -0,0 +1,168 @@
+/*******************************************************************************
+ *
+ * TRIQS: a Toolbox for Research in Interacting Quantum Systems
+ *
+ * Copyright (C) 2013 by O. Parcollet
+ *
+ * TRIQS is free software: you can redistribute it and/or modify it under the
+ * terms of the GNU General Public License as published by the Free Software
+ * Foundation, either version 3 of the License, or (at your option) any later
+ * version.
+ *
+ * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
+ * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+ * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+ * details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * TRIQS. If not, see <http://www.gnu.org/licenses/>.
+ *
+ ******************************************************************************/
+#ifndef TRIQS_UTILITY_MPI_H
+#define TRIQS_UTILITY_MPI_H
+#include <triqs/arrays/array.hpp>
+#include <boost/mpi.hpp>
+
+namespace triqs { namespace mpi {
+
+ using boost::mpi::communicator;
+ using boost::mpi::environment;
+
+ // transformation type -> mpi types
+ template <class T> struct mpi_datatype { static constexpr bool ok=false;};
+#define D(T,MPI_TY) template <> struct mpi_datatype<T> { static MPI_Datatype invoke() { return MPI_TY;} static constexpr bool ok=true;};
+ D(int,MPI_INT) D(long,MPI_LONG) D(double,MPI_DOUBLE) D(float,MPI_FLOAT) D(std::complex<double>, MPI_DOUBLE_COMPLEX);
+#undef D
+
+ // ok that is simple ... (inline : defined in a header, must not violate the ODR)
+ inline void barrier(communicator _c) { MPI_Barrier(_c);}
+
+ // a struct to specialize for the implementation for various types...
+ template <class T, class Enable = void> struct mpi_impl;
+
+ // ------------------------------
+ // the final function for users
+ // ------------------------------
+
+ // reduce : first the in_place version
+ template <class T> void reduce_in_place(communicator _c, T & a, int root=0) { mpi_impl<T>::reduce_in_place(_c,a,root); }
+
+ inline void reduce_in_place_v(communicator _c) {}
+
+ // try a variadic one. Does not cost much more to code...
+ template <class T0, class... T> void reduce_in_place_v(communicator _c, T0 & a0, T & ... a) {
+  reduce_in_place(_c,a0,0);
+  reduce_in_place_v(_c, a...);
+ }
+
+ // reduce : the regular version in term of the in place one (accept views on the fly from b).
+ template <class T, class U> void reduce (communicator _c, T & a, U && b, int root =0) {
+  b = a; reduce_in_place(_c,b,root);
+ }
+
+ // all_reduce : first the in_place version. NB : dispatch to all_reduce_in_place, not to the root-only reduce.
+ template <class T> void all_reduce_in_place(communicator _c, T & a) { mpi_impl<T>::all_reduce_in_place(_c,a,0); }
+
+ // all_reduce : the regular version in term of the in place one (accept views on the fly from b).
+ template <class T, class U> void all_reduce (communicator _c, T & a, U && b) { b = a; all_reduce_in_place(_c,b); }
+
+ // BroadCast
+ template <class T> void broadcast(communicator _c, T & a, int root =0) { mpi_impl<T>::broadcast(_c,a,root); }
+
+ // ----------------------------------------------------------------------
+ // the generic implementation : using serialization for recursive action
+ // ----------------------------------------------------------------------
+ template <class T, class Enable> struct mpi_impl {
+
+#define MAKE_ADAPTOR_AND_FNT(FNT)\
+ struct adaptor_##FNT {\
+  communicator _c; int root;\
+  template <class RHS> adaptor_##FNT & operator & (RHS & rhs) { mpi_impl<RHS>::FNT(_c, rhs, root); return *this; }\
+ };\
+ static void FNT (communicator _c, T & a, int root) {\
+  auto ad = adaptor_##FNT{_c,root};\
+  serialize(ad, a);\
+ }
+
+  MAKE_ADAPTOR_AND_FNT(reduce_in_place);
+  MAKE_ADAPTOR_AND_FNT(all_reduce_in_place);
+  MAKE_ADAPTOR_AND_FNT(broadcast);
+
+#undef MAKE_ADAPTOR_AND_FNT
+ };
+
+ // ------------------------------
+ // overload for basic types
+ // ------------------------------
+ template <class A> struct mpi_impl<A, typename std::enable_if<std::is_arithmetic<A>::value || boost::is_complex<A>::value>::type> {
+
+  static void reduce_in_place (communicator _c, A & a, int root) {
+   MPI_Reduce ((_c.rank()==root ? MPI_IN_PLACE : &a), &a, 1, mpi_datatype<A>::invoke(), MPI_SUM, root, _c);
+  }
+
+  static void all_reduce_in_place (communicator _c, A & a, int) {
+   // MPI_Allreduce takes (sendbuf, recvbuf, count, type, op, comm) : MPI_IN_PLACE as sendbuf
+   MPI_Allreduce (MPI_IN_PLACE, &a, 1, mpi_datatype<A>::invoke(), MPI_SUM, _c);
+  }
+
+  static void broadcast (communicator _c, A & a, int root) { MPI_Bcast (&a, 1, mpi_datatype<A>::invoke(), root, _c); }
+ };
+
+ // ------------------------------
+ // a boost::mpi implementation
+ // ------------------------------
+ template <class A> struct boost_mpi_impl {
+
+  static void reduce_in_place (communicator _c, A & a, int root) {
+   boost::mpi::reduce(_c, a, a, std::c14::plus<>(), root);
+  }
+
+  static void all_reduce_in_place (communicator _c, A & a, int) {
+   // boost::mpi::all_reduce has no root argument : the result ends up on every rank
+   boost::mpi::all_reduce(_c, a, a, std::c14::plus<>());
+  }
+
+  static void broadcast (communicator _c, A & a, int root) { boost::mpi::broadcast(_c,a,root);}
+ };
+
+ // ------------------------------
+ // overload for arrays
+ // Strategy : if not contiguous, we fail with an error (we could also revert to boost::mpi)
+ // ------------------------------
+ // When value_type is a basic type, we can directly call the C API
+ template <class A> struct mpi_impl<A, typename std::enable_if<mpi_datatype<typename A::value_type>::ok && arrays::is_amv_value_or_view_class<A>::value>::type> {
+
+  typedef typename A::value_type a_t;
+
+  static void reduce_in_place (communicator _c, A & a, int root) {
+   if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
+   auto p = a.data_start();
+   MPI_Reduce ((_c.rank()==root ? MPI_IN_PLACE : p), p, a.domain().number_of_elements(), mpi_datatype<a_t>::invoke(), MPI_SUM, root, _c);
+  }
+
+  static void all_reduce_in_place (communicator _c, A & a, int) {
+   if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_all_reduce_in_place";
+   MPI_Allreduce (MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), mpi_datatype<a_t>::invoke(), MPI_SUM, _c);
+  }
+
+  static void broadcast (communicator _c, A & a, int root) {
+   if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_broadcast";
+   MPI_Bcast (a.data_start(), a.domain().number_of_elements(), mpi_datatype<a_t>::invoke(), root, _c);
+  }
+
+ };
+
+ // When value_type is NOT a basic type, we revert to boost::mpi
+ template <class A> struct mpi_impl<A, typename std::enable_if<(!mpi_datatype<typename A::value_type>::ok) && arrays::is_amv_value_or_view_class<A>::value>::type> : boost_mpi_impl<A> {};
+
+ // overload for views rvalues (created on the fly)
+ template <class T, int R>
+ void reduce_in_place( communicator _c, arrays::array_view<T,R> && a, int root =0) { reduce_in_place(_c, a, root);}
+
+ template <class A, class T, int R>
+ void reduce( communicator _c, A const & a, arrays::array_view<T,R> && b, int root =0) { reduce(_c, a, b, root);}
+
+ // to be implemented : scatter, gather for arrays
+
+}}
+#endif