diff --git a/test/triqs/mpi/CMakeLists.txt b/test/triqs/mpi/CMakeLists.txt new file mode 100644 index 00000000..917a5baf --- /dev/null +++ b/test/triqs/mpi/CMakeLists.txt @@ -0,0 +1,2 @@ +all_tests() + diff --git a/test/triqs/mpi/mpi_array.cpp b/test/triqs/mpi/mpi_array.cpp new file mode 100644 index 00000000..e02bf7e9 --- /dev/null +++ b/test/triqs/mpi/mpi_array.cpp @@ -0,0 +1,70 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2013 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +using namespace triqs; +using namespace triqs::arrays; +using namespace triqs::mpi; + +int main(int argc, char* argv[]) { + + mpi::environment env(argc, argv); + mpi::communicator world; + + // using ARR = array; + using ARR = array, 2>; + + ARR A(7, 3), B, AA; + + clef::placeholder<0> i_; + clef::placeholder<1> j_; + + A(i_, j_) << i_ + 10 * j_; + + B = mpi::scatter(A, world); + ARR C = mpi::scatter(A, world); + + std::ofstream out("node" + std::to_string(world.rank())); + out << " A = " << A << std::endl; + out << " B = " << B << std::endl; + out << " C = " << C << std::endl; + + B *= -1; + AA() = 0; + + AA = mpi::gather(B, world); + out << " AA = " << AA << std::endl; + + mpi::broadcast(AA, world); + out << " cast AA = " << AA << std::endl; + + AA() = 0; + + AA = mpi::allgather(B, world); + out << " AA = " << AA << std::endl; +} + diff --git a/test/triqs/mpi/mpi_generic.cpp b/test/triqs/mpi/mpi_generic.cpp new file mode 100644 index 00000000..2aae06c5 --- /dev/null +++ b/test/triqs/mpi/mpi_generic.cpp @@ -0,0 +1,97 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2013 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +using namespace triqs; +using namespace triqs::arrays; +using namespace triqs::mpi; + +struct my_object { + + array a, b; + + TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW; + + my_object() = default; + + my_object(int s) : a(s), b(s) { + clef::placeholder<0> i_; + a(i_) << i_; + b(i_) << -i_; + } + + // construction from the lazy is delegated to = + template my_object(mpi_lazy x) : my_object() { operator=(x); } + + // assigment is almost done already... + template my_object &operator=(mpi_lazy x) { + return mpi_impl_tuple::complete_operation(*this, x); + } +}; + +// non intrusive +auto view_as_tuple(my_object const &x) RETURN(std::tie(x.a, x.b)); +auto view_as_tuple(my_object &x) RETURN(std::tie(x.a, x.b)); + +// -------------------------------------- + +int main(int argc, char *argv[]) { + + mpi::environment env(argc, argv); + mpi::communicator world; + + std::ofstream out("t2_node" + std::to_string(world.rank())); + + auto ob = my_object(10); + mpi::broadcast(ob); + + out << " a = " << ob.a << std::endl; + out << " b = " << ob.b << std::endl; + + auto ob2 = ob; + + // ok scatter all components + ob2 = mpi::scatter(ob); + + out << " scattered a = " << ob2.a << std::endl; + out << " scattered b = " << ob2.b << std::endl; + + ob2.a *= world.rank()+1; // change it a bit + + // now regroup... + ob = mpi::gather(ob2); + out << " gather a = " << ob.a << std::endl; + out << " gather b = " << ob.b << std::endl; + + // allgather + ob = mpi::allgather(ob2); + out << " allgather a = " << ob.a << std::endl; + out << " allgather b = " << ob.b << std::endl; + + out << "----------------------------"<< std::endl; +} + diff --git a/test/triqs/mpi/mpi_old_test.cpp b/test/triqs/mpi/mpi_old_test.cpp new file mode 100644 index 00000000..489352cd --- /dev/null +++ b/test/triqs/mpi/mpi_old_test.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +using namespace triqs; +using namespace triqs::arrays; +using namespace triqs::mpi; + +int main(int argc, char *argv[]) { + + mpi::environment env(argc, argv); + mpi::communicator world; + + // TO BE RETESTED: + /* + mpi::environment env(argc, argv); + mpi::communicator world; + + array A {{1,2}, {3,4}}, C(2,2); + + // boost mpi + boost::mpi::reduce (world, A,C, std::c14::plus<>(),0); + int s= world.size(); + if (world.rank() ==0) std::cout<<" C = "<(s*A) <(s*A) <(s*A) <, 1 > { array{1,2}, array{3,4}}; + auto cC = ca; + mpi::reduce_in_place (world, cC); + if (world.rank() ==0) std::cout<<" cC = "< #include #include -#include +#include #include #include diff --git a/triqs/arrays/impl/assignment.hpp b/triqs/arrays/impl/assignment.hpp index 4a5e15b2..dfaafd7f 100644 --- a/triqs/arrays/impl/assignment.hpp +++ b/triqs/arrays/impl/assignment.hpp @@ -72,6 +72,9 @@ namespace triqs { namespace arrays { template struct is_isp : std::integral_constant::value && (!is_scalar_for::value) > {}; + /// RHS is special type that defines its own specialization of assign + template struct is_special : std::false_type {}; + #define TRIQS_REJECT_ASSIGN_TO_CONST \ static_assert( (!std::is_const::value ), "Assignment : The value type of the LHS is const and cannot be assigned to !"); #define TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR\ @@ -106,7 +109,7 @@ namespace triqs { namespace arrays { // ----------------- assignment for expressions RHS -------------------------------------------------- template - struct impl::value && (!is_scalar_for::value) && (!is_isp::value)) > { + struct impl::value && (!is_scalar_for::value) && (!is_isp::value)&& (!is_special::value)) > { TRIQS_REJECT_ASSIGN_TO_CONST; TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR; typedef typename LHS::value_type value_type; diff --git a/triqs/mpi.hpp b/triqs/mpi.hpp new file mode 100644 index 00000000..43625d92 --- /dev/null +++ b/triqs/mpi.hpp @@ -0,0 +1,29 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#ifndef TRIQS_MPI_H +#define TRIQS_MPI_H + +#include "./mpi/arrays.hpp" +#include "./mpi/vector.hpp" +#include "./mpi/generic.hpp" + +#endif + diff --git a/triqs/mpi/arrays.hpp b/triqs/mpi/arrays.hpp new file mode 100644 index 00000000..17b79624 --- /dev/null +++ b/triqs/mpi/arrays.hpp @@ -0,0 +1,197 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#pragma once +#include +#include "./base.hpp" + +namespace triqs { +namespace mpi { + + //-------------------------------------------------------------------------------------------------------- + // The lazy ref made by scatter and co. + // Differs from the generic one in that it can make a domain of the (target) array + template struct mpi_lazy_array { + A const &ref; + int root; + communicator c; + + using domain_type = typename A::domain_type; + + /// compute the array domain of the target array + domain_type domain() const { + auto dims = ref.shape(); + long slow_size = first_dim(ref); + + if (std::is_same::value) { + dims[0] = slice_length(slow_size - 1, c, c.rank()); + } + + if (std::is_same::value) { + long slow_size_total = 0; + MPI_Reduce(&slow_size, &slow_size_total, 1, mpi_datatype::invoke(), MPI_SUM, root, c.get()); + dims[0] = slow_size_total; + // valid only on root + } + + if (std::is_same::value) { + long slow_size_total = 0; + MPI_Allreduce(&slow_size, &slow_size_total, 1, mpi_datatype::invoke(), MPI_SUM, c.get()); + dims[0] = slow_size_total; + // in this case, it is valid on all nodes + } + + return domain_type{dims}; + } + + }; + + //-------------------------------------------------------------------------------------------------------- + + // When value_type is a basic type, we can directly call the C API + template class mpi_impl_triqs_arrays { + + static MPI_Datatype D() { return mpi_datatype::invoke(); } + + static void check_is_contiguous(A const &a) { + if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place"; + } + + public: + + //--------- + static void reduce_in_place(communicator c, A &a, int root) { + check_is_contiguous(a); + // assume arrays have the same size on all nodes... + MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : a), a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get()); + } + + //--------- + static void allreduce_in_place(communicator c, A &a, int root) { + check_is_contiguous(a); + // assume arrays have the same size on all nodes... + MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get()); + } + + //--------- + static void broadcast(communicator c, A &a, int root) { + check_is_contiguous(a); + auto sh = a.shape(); + MPI_Bcast(&sh[0], sh.size(), mpi_datatype::invoke(), root, c.get()); + if (c.rank() != root) a.resize(sh); + MPI_Bcast(a.data_start(), a.domain().number_of_elements(), D(), root, c.get()); + } + + //--------- + template static mpi_lazy_array invoke(Tag, communicator c, A const &a, int root) { + check_is_contiguous(a); + return {a, root, c}; + } + + }; + + template + struct mpi_impl::value>> : mpi_impl_triqs_arrays {}; + +} // mpi namespace + +//------------------------------- Delegation of the assign operator of the array class ------------- + +namespace arrays { + + // mpi_lazy_array model ImmutableCuboidArray + template struct ImmutableCuboidArray> : ImmutableCuboidArray {}; + + namespace assignment { + + template struct is_special> : std::true_type {}; + + // assignment delegation + template struct impl, 'E', void> { + + using laz_t = mpi::mpi_lazy_array; + LHS &lhs; + laz_t laz; + + impl(LHS &lhs_, laz_t laz_) : lhs(lhs_), laz(laz_) {} + + void invoke() { _invoke(Tag()); } + + private: + static MPI_Datatype D() { return mpi::mpi_datatype::invoke(); } + + //--------------------------------- + void _invoke(triqs::mpi::tag::scatter) { + lhs.resize(laz.domain()); + + auto c = laz.c; + auto slow_size = first_dim(laz.ref); + auto slow_stride = laz.ref.indexmap().strides()[0]; + auto sendcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int recvcount = slice_length(slow_size - 1, c, c.rank()) * slow_stride; + + for (int r = 0; r < c.size(); ++r) { + sendcounts[r] = slice_length(slow_size - 1, c, r) * slow_stride; + displs[r + 1] = sendcounts[r] + displs[r]; + } + + MPI_Scatterv((void *)laz.ref.data_start(), &sendcounts[0], &displs[0], D(), (void *)lhs.data_start(), recvcount, D(), + laz.root, c.get()); + } + + //--------------------------------- + void _invoke(triqs::mpi::tag::gather) { + lhs.resize(laz.domain()); + + auto c = laz.c; + auto recvcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int sendcount = laz.ref.domain().number_of_elements(); + + auto mpi_ty = mpi::mpi_datatype::invoke(); + MPI_Gather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, laz.root, c.get()); + for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r]; + + MPI_Gatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(), laz.root, + c.get()); + } + + //--------------------------------- + void _invoke(triqs::mpi::tag::allgather) { + lhs.resize(laz.domain()); + + // almost the same preparation as gather, except that the recvcounts are ALL gathered... + auto c = laz.c; + auto recvcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int sendcount = laz.ref.domain().number_of_elements(); + + auto mpi_ty = mpi::mpi_datatype::invoke(); + MPI_Allgather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, c.get()); + for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r]; + + MPI_Allgatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(), + c.get()); + } + }; + } +} //namespace arrays +} // namespace triqs diff --git a/triqs/mpi/base.hpp b/triqs/mpi/base.hpp new file mode 100644 index 00000000..7a5c8237 --- /dev/null +++ b/triqs/mpi/base.hpp @@ -0,0 +1,148 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#pragma once +#include +//#include +#include + +namespace triqs { +namespace mpi { + + /// Environment + struct environment { + environment(int argc, char *argv[]) { MPI_Init(&argc, &argv); } + ~environment() { MPI_Finalize(); } + }; + + /// The communicator. Todo : add more constructors. + class communicator { + MPI_Comm _com = MPI_COMM_WORLD; + + public: + communicator() = default; + + MPI_Comm get() const { return _com; } + + int rank() const { + int num; + MPI_Comm_rank(_com, &num); + return num; + } + + int size() const { + int num; + MPI_Comm_size(_com, &num); + return num; + } + + void barrier() const { MPI_Barrier(_com); } + }; + + /// a tag for each operation + namespace tag { + struct reduce {}; + struct allreduce {}; + struct scatter {}; + struct gather {}; + struct allgather {}; + } + + /// The implementation of mpi ops for each type + template struct mpi_impl; + + // ---------------------------------------- + // ------- top level functions ------- + // ---------------------------------------- + + // ----- functions that can be lazy ------- + + template + AUTO_DECL reduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::reduce(), c, x, root)); + template + AUTO_DECL scatter(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::scatter(), c, x, root)); + template + AUTO_DECL gather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::gather(), c, x, root)); + template + AUTO_DECL allreduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::allreduce(), c, x, root)); + template + AUTO_DECL allgather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::allgather(), c, x, root)); + + // ----- functions that cannot be lazy ------- + + template void reduce_in_place(T &x, communicator c = {}, int root = 0) { mpi_impl::reduce_in_place(c, x, root); } + template void broadcast(T &x, communicator c = {}, int root = 0) { mpi_impl::broadcast(c, x, root); } + + // transformation type -> mpi types + template struct mpi_datatype; +#define D(T, MPI_TY) \ + template <> struct mpi_datatype { \ + static MPI_Datatype invoke() { return MPI_TY; } \ + }; + D(int, MPI_INT) D(long, MPI_LONG) D(double, MPI_DOUBLE) D(float, MPI_FLOAT) D(std::complex, MPI_DOUBLE_COMPLEX); + D(unsigned long, MPI_UNSIGNED_LONG); +#undef D + + /** ------------------------------------------------------------ + * basic types + * ---------------------------------------------------------- **/ + + template struct mpi_impl_basic { + + static MPI_Datatype D() { return mpi_datatype::invoke(); } + + static T invoke(tag::reduce, communicator c, T a, int root) { + T b; + MPI_Reduce(&a, &b, 1, D(), MPI_SUM, root, c.get()); + return b; + } + + static T invoke(tag::allreduce, communicator c, T a, int root) { + T b; + MPI_Allreduce(&a, &b, 1, D(), MPI_SUM, root, c.get()); + return b; + } + + static void reduce_in_place(communicator c, T &a, int root) { + MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : &a), &a, 1, D(), MPI_SUM, root, c.get()); + } + + static void allreduce_in_place(communicator c, T &a, int root) { + MPI_Allreduce(MPI_IN_PLACE, &a, 1, D(), MPI_SUM, root, c.get()); + } + + static void broadcast(communicator c, T &a, int root) { MPI_Bcast(&a, 1, D(), root, c.get()); } + }; + + // mpl_impl_basic is the mpi_impl is T is a number (including complex) + template + struct mpi_impl::value || triqs::is_complex::value>> : mpi_impl_basic {}; + + //------------ Some helper function + inline long slice_length(size_t imax, communicator c, int r) { + auto imin = 0; + long j = (imax - imin + 1) / c.size(); + long i = imax - imin + 1 - c.size() * j; + auto r_min = (r <= i - 1 ? imin + r * (j + 1) : imin + r * j + i); + auto r_max = (r <= i - 1 ? imin + (r + 1) * (j + 1) - 1 : imin + (r + 1) * j + i - 1); + return r_max - r_min + 1; + }; +} +} diff --git a/triqs/mpi/boost.hpp b/triqs/mpi/boost.hpp new file mode 100644 index 00000000..d8a83676 --- /dev/null +++ b/triqs/mpi/boost.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#pragma once +#include "./base.hpp" +#include + +namespace triqs { +namespace mpi { + + /** ------------------------------------------------------------ + * Type which we use boost::mpi + * ---------------------------------------------------------- **/ + + template struct mpi_impl_boost_mpi { + + static T invoke(tag::reduce, communicator c, T const &a, int root) { + T b; + boost::mpi::reduce(c, a, b, std::c14::plus<>(), root); + return b; + } + + static T invoke(tag::allreduce, communicator c, T const &a, int root) { + T b; + boost::mpi::all_reduce(c, a, b, std::c14::plus<>(), root); + return b; + } + + static void reduce_in_place(communicator c, T &a, int root) { boost::mpi::reduce(c, a, a, std::c14::plus<>(), root); } + static void broadcast(communicator c, T &a, int root) { boost::mpi::broadcast(c, a, root); } + + static void scatter(communicator c, T const &, int root) = delete; + static void gather(communicator c, T const &, int root) = delete; + static void allgather(communicator c, T const &, int root) = delete; + }; + + // default + //template struct mpi_impl : mpi_impl_boost_mpi {}; + +}}//namespace + + + + diff --git a/triqs/mpi/generic.hpp b/triqs/mpi/generic.hpp new file mode 100644 index 00000000..fb3523f8 --- /dev/null +++ b/triqs/mpi/generic.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#pragma once +#include "./base.hpp" +#include + +#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW using triqs_mpi_as_tuple = void; +namespace triqs { +namespace mpi { + + template struct mpi_lazy { + T const &ref; + int root; + communicator c; + }; + + /** ------------------------------------------------------------ + * Type which are recursively treated by reducing them to a tuple + * of smaller objects. + * ---------------------------------------------------------- **/ + template struct mpi_impl_tuple { + + mpi_impl_tuple() = default; + template static mpi_lazy invoke(Tag, communicator c, T const &a, int root) { + return {a, root, c}; + } + +#ifdef __cpp_generic_lambdas + static void reduce_in_place(communicator c, T &a, int root) { + tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::reduce_in_place(x, c, root); }); + } + + static void broadcast(communicator c, T &a, int root) { + tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::broadcast(x, c, root); }); + } + + template static T &complete_operation(T &target, mpi_lazy laz) { + auto l = [laz](auto &t, auto &s) { t = triqs::mpi::mpi_impl>::invoke(Tag(), laz.c, s, laz.root); }; + triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref)); + return target; + } +#else + + struct aux1 { + communicator c; + int root; + + template void operator()(T1 &x) const { triqs::mpi::reduce_in_place(c, x, root); } + }; + + static void reduce_in_place(communicator c, T &a, int root) { + tuple::for_each(aux1{c, root}, view_as_tuple(a)); + } + + struct aux2 { + communicator c; + int root; + + template void operator()(T2 &x) const { triqs::mpi::broadcast(c, x, root); } + }; + + static void broadcast(communicator c, T &a, int root) { + tuple::for_each(aux2{c, root}, view_as_tuple(a)); + } + + template struct aux3 { + mpi_lazy laz; + + template void operator()(T1 &t, T2 &s) const { + t = triqs::mpi::mpi_impl::invoke(Tag(), laz.c, laz.s); + } + }; + + template static void complete_operation(T &target, mpi_lazy laz) { + auto l = aux3{laz}; + triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref)); + } +#endif + }; + + // If type T has a mpi_implementation nested struct, then it is mpi_impl. + template struct mpi_impl : mpi_impl_tuple {}; +} +} // namespace + diff --git a/triqs/mpi/vector.hpp b/triqs/mpi/vector.hpp new file mode 100644 index 00000000..b406e99b --- /dev/null +++ b/triqs/mpi/vector.hpp @@ -0,0 +1,116 @@ +/******************************************************************************* + * + * TRIQS: a Toolbox for Research in Interacting Quantum Systems + * + * Copyright (C) 2014 by O. Parcollet + * + * TRIQS is free software: you can redistribute it and/or modify it under the + * terms of the GNU General Public License as published by the Free Software + * Foundation, either version 3 of the License, or (at your option) any later + * version. + * + * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY + * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + * details. + * + * You should have received a copy of the GNU General Public License along with + * TRIQS. If not, see . + * + ******************************************************************************/ +#pragma once + +namespace triqs { +namespace mpi { + + // When value_type is a basic type, we can directly call the C API + template struct mpi_impl_std_vector_basic { + + static MPI_Datatype D() { return mpi_datatype::invoke(); } + + // ----------- + static void reduce_in_place(communicator c, std::vector &a, int root) { + MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : a.data()), a.data(), a.size(), D(), MPI_SUM, root, c.get()); + } + + static void allreduce_in_place(communicator c, std::vector &a, int root) { + MPI_Allreduce(MPI_IN_PLACE, a.data(), a.size(), D(), MPI_SUM, root, c.get()); + } + + // ----------- + static void broadcast(communicator c, std::vector &a, int root) { MPI_Bcast(a.data(), a.size(), D(), root, c.get()); } + + // ----------- + static std::vector invoke(tag::reduce, communicator c, T const &a, int root) { + std::vector b(a.size()); + MPI_Reduce(a.data(), b.data(), a.size(), D(), MPI_SUM, root, c.get()); + return b; + } + + // ----------- + static std::vector invoke(tag::allreduce, communicator c, std::vector const &a, int root) { + std::vector b(a.size()); + MPI_Allreduce(a.data(), b.data(), a.size(), D(), MPI_SUM, root, c.get()); + return b; + } + + // ----------- + static std::vector invoke(tag::scatter, communicator c, std::vector const &a, int root) { + + auto slow_size = a.size(); + auto sendcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int recvcount = slice_length(slow_size - 1, c, c.rank()); + std::vector b(recvcount); + + for (int r = 0; r < c.size(); ++r) { + sendcounts[r] = slice_length(slow_size - 1, c, r); + displs[r + 1] = sendcounts[r] + displs[r]; + } + + MPI_Scatterv((void *)a.data(), &sendcounts[0], &displs[0], D(), (void *)b.data(), recvcount, D(), root, c.get()); + return b; + } + + // ----------- + static std::vector invoke(tag::gather, communicator c, std::vector const &a, int root) { + long size = reduce(a.size(), c, root); + std::vector b(size); + + auto recvcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int sendcount = a.size(); + auto mpi_ty = mpi::mpi_datatype::invoke(); + MPI_Gather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, root, c.get()); + for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r]; + + MPI_Gatherv((void *)a.data(), sendcount, D(), (void *)b.data(), &recvcounts[0], &displs[0], D(), root, c.get()); + return b; + } + + // ----------- + + static std::vector invoke(tag::allgather, communicator c, std::vector const &a, int root) { + long size = reduce(a.size(), c, root); + std::vector b(size); + + auto recvcounts = std::vector(c.size()); + auto displs = std::vector(c.size() + 1, 0); + int sendcount = a.size(); + auto mpi_ty = mpi::mpi_datatype::invoke(); + MPI_Allgather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, c.get()); + for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r]; + + MPI_Allgatherv((void *)a.data(), sendcount, D(), (void *)b.data(), &recvcounts[0], &displs[0], D(), c.get()); + return b; + } + }; + + template + struct mpi_impl, std14::enable_if_t::value || + triqs::is_complex::value>> : mpi_impl_std_vector_basic {}; + + // vector for T non basic +} +} // namespace + diff --git a/triqs/utility/mpi.hpp b/triqs/utility/mpi1.hpp similarity index 100% rename from triqs/utility/mpi.hpp rename to triqs/utility/mpi1.hpp