3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-25 05:43:40 +01:00

Implement mpi lib (1). Array, generic, base, vector

- Implement the basic structure of the mpi lib
  and specialization for arrays, basic types, std::vector
- adapted the array class for the lazy mpi mechanism
- pass tests on arrays :
   - scatter, gather on array<long,2> array<complex,2>, etc...
   - broadcast
- several files for readibility
- the std::vector coded but not tested.
- generic mecanism implemented and tested (mpi_generic test)
- added several tests for the mpi lib.
- TODO : more tests, doc...
This commit is contained in:
Olivier Parcollet 2014-06-03 14:54:11 +02:00
parent 56820a9493
commit 38cfef4e9f
13 changed files with 914 additions and 2 deletions

View File

@ -0,0 +1,2 @@
all_tests()

View File

@ -0,0 +1,70 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2013 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <iostream>
#include <type_traits>
#include <triqs/arrays.hpp>
#include <triqs/mpi.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
using namespace triqs;
using namespace triqs::arrays;
using namespace triqs::mpi;
int main(int argc, char* argv[]) {
mpi::environment env(argc, argv);
mpi::communicator world;
// using ARR = array<double,2>;
using ARR = array<std::complex<double>, 2>;
ARR A(7, 3), B, AA;
clef::placeholder<0> i_;
clef::placeholder<1> j_;
A(i_, j_) << i_ + 10 * j_;
B = mpi::scatter(A, world);
ARR C = mpi::scatter(A, world);
std::ofstream out("node" + std::to_string(world.rank()));
out << " A = " << A << std::endl;
out << " B = " << B << std::endl;
out << " C = " << C << std::endl;
B *= -1;
AA() = 0;
AA = mpi::gather(B, world);
out << " AA = " << AA << std::endl;
mpi::broadcast(AA, world);
out << " cast AA = " << AA << std::endl;
AA() = 0;
AA = mpi::allgather(B, world);
out << " AA = " << AA << std::endl;
}

View File

@ -0,0 +1,97 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2013 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <iostream>
#include <type_traits>
#include <triqs/arrays.hpp>
#include <triqs/mpi.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
using namespace triqs;
using namespace triqs::arrays;
using namespace triqs::mpi;
struct my_object {
array<double, 1> a, b;
TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW;
my_object() = default;
my_object(int s) : a(s), b(s) {
clef::placeholder<0> i_;
a(i_) << i_;
b(i_) << -i_;
}
// construction from the lazy is delegated to =
template <typename Tag> my_object(mpi_lazy<Tag, my_object> x) : my_object() { operator=(x); }
// assigment is almost done already...
template <typename Tag> my_object &operator=(mpi_lazy<Tag, my_object> x) {
return mpi_impl_tuple<my_object>::complete_operation(*this, x);
}
};
// non intrusive
auto view_as_tuple(my_object const &x) RETURN(std::tie(x.a, x.b));
auto view_as_tuple(my_object &x) RETURN(std::tie(x.a, x.b));
// --------------------------------------
int main(int argc, char *argv[]) {
mpi::environment env(argc, argv);
mpi::communicator world;
std::ofstream out("t2_node" + std::to_string(world.rank()));
auto ob = my_object(10);
mpi::broadcast(ob);
out << " a = " << ob.a << std::endl;
out << " b = " << ob.b << std::endl;
auto ob2 = ob;
// ok scatter all components
ob2 = mpi::scatter(ob);
out << " scattered a = " << ob2.a << std::endl;
out << " scattered b = " << ob2.b << std::endl;
ob2.a *= world.rank()+1; // change it a bit
// now regroup...
ob = mpi::gather(ob2);
out << " gather a = " << ob.a << std::endl;
out << " gather b = " << ob.b << std::endl;
// allgather
ob = mpi::allgather(ob2);
out << " allgather a = " << ob.a << std::endl;
out << " allgather b = " << ob.b << std::endl;
out << "----------------------------"<< std::endl;
}

View File

@ -0,0 +1,86 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#include <iostream>
#include <type_traits>
#include <triqs/arrays.hpp>
#include <triqs/mpi.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
using namespace triqs;
using namespace triqs::arrays;
using namespace triqs::mpi;
int main(int argc, char *argv[]) {
mpi::environment env(argc, argv);
mpi::communicator world;
// TO BE RETESTED:
/*
mpi::environment env(argc, argv);
mpi::communicator world;
array<long,2> A {{1,2}, {3,4}}, C(2,2);
// boost mpi
boost::mpi::reduce (world, A,C, std::c14::plus<>(),0);
int s= world.size();
if (world.rank() ==0) std::cout<<" C = "<<C<< " should be "<< std::endl << array<long,2>(s*A) <<std::endl;
// triqs mpi
C = A;
mpi::reduce_in_place (world, C);
if (world.rank() ==0) std::cout<<" C = "<<C<< " should be "<< std::endl << array<long,2>(s*A) <<std::endl;
// test rvalue views
C = A;
mpi::reduce_in_place (world, C());
if (world.rank() ==0) std::cout<<" C = "<<C<< " should be "<< std::endl << array<long,2>(s*A) <<std::endl;
// more complex class
auto x = S { { {1,2},{3,4}}, { 1,2,3,4}};
mpi::reduce_in_place (world, x);
if (world.rank() ==0) std::cout<<" S.x = "<<x.x<<" S.y = "<<x.y<<std::endl;
// a simple number
double y = 1+world.rank(), z=0;
mpi::reduce(world,y,z);
if (world.rank() ==0) std::cout<<" y = "<<y<< " should be "<< 1+world.rank()<<std::endl;
if (world.rank() ==0) std::cout<<" z = "<<z<< " should be "<< s*(s+1)/2 <<std::endl;
mpi::reduce_in_place(world,y);
if (world.rank() ==0) std::cout<<" y = "<<y<< " should be "<< s*(s+1)/2 <<std::endl;
mpi::broadcast(world,C);
// reduced x,y,C, .... a variadic form
mpi::reduce_in_place_v (world, x,y,C);
// more complex object
auto ca = array< array<int,1>, 1 > { array<int,1>{1,2}, array<int,1>{3,4}};
auto cC = ca;
mpi::reduce_in_place (world, cC);
if (world.rank() ==0) std::cout<<" cC = "<<cC<< std::endl;
*/
return 0;
}

View File

@ -21,7 +21,7 @@
#include <iostream>
#include <type_traits>
#include <triqs/arrays.hpp>
#include <triqs/utility/mpi.hpp>
#include <triqs/utility/mpi1.hpp>
#include <iostream>
#include <sstream>

View File

@ -72,6 +72,9 @@ namespace triqs { namespace arrays {
template<class RHS,class LHS> struct is_isp :
std::integral_constant<bool, std::is_base_of<Tag::indexmap_storage_pair,RHS>::value && (!is_scalar_for<RHS,LHS>::value) > {};
/// RHS is special type that defines its own specialization of assign
template<class RHS,class LHS> struct is_special : std::false_type {};
#define TRIQS_REJECT_ASSIGN_TO_CONST \
static_assert( (!std::is_const<typename LHS::value_type>::value ), "Assignment : The value type of the LHS is const and cannot be assigned to !");
#define TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR\
@ -106,7 +109,7 @@ namespace triqs { namespace arrays {
// ----------------- assignment for expressions RHS --------------------------------------------------
template<typename LHS, typename RHS, char OP>
struct impl<LHS,RHS,OP, ENABLE_IFC( ImmutableCuboidArray<RHS>::value && (!is_scalar_for<RHS,LHS>::value) && (!is_isp<RHS,LHS>::value)) > {
struct impl<LHS,RHS,OP, ENABLE_IFC( ImmutableCuboidArray<RHS>::value && (!is_scalar_for<RHS,LHS>::value) && (!is_isp<RHS,LHS>::value)&& (!is_special<LHS,RHS>::value)) > {
TRIQS_REJECT_ASSIGN_TO_CONST;
TRIQS_REJECT_MATRIX_COMPOUND_MUL_DIV_NON_SCALAR;
typedef typename LHS::value_type value_type;

29
triqs/mpi.hpp Normal file
View File

@ -0,0 +1,29 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#ifndef TRIQS_MPI_H
#define TRIQS_MPI_H
#include "./mpi/arrays.hpp"
#include "./mpi/vector.hpp"
#include "./mpi/generic.hpp"
#endif

197
triqs/mpi/arrays.hpp Normal file
View File

@ -0,0 +1,197 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include <triqs/arrays.hpp>
#include "./base.hpp"
namespace triqs {
namespace mpi {
//--------------------------------------------------------------------------------------------------------
// The lazy ref made by scatter and co.
// Differs from the generic one in that it can make a domain of the (target) array
template <typename Tag, typename A> struct mpi_lazy_array {
A const &ref;
int root;
communicator c;
using domain_type = typename A::domain_type;
/// compute the array domain of the target array
domain_type domain() const {
auto dims = ref.shape();
long slow_size = first_dim(ref);
if (std::is_same<Tag, tag::scatter>::value) {
dims[0] = slice_length(slow_size - 1, c, c.rank());
}
if (std::is_same<Tag, tag::gather>::value) {
long slow_size_total = 0;
MPI_Reduce(&slow_size, &slow_size_total, 1, mpi_datatype<long>::invoke(), MPI_SUM, root, c.get());
dims[0] = slow_size_total;
// valid only on root
}
if (std::is_same<Tag, tag::allgather>::value) {
long slow_size_total = 0;
MPI_Allreduce(&slow_size, &slow_size_total, 1, mpi_datatype<long>::invoke(), MPI_SUM, c.get());
dims[0] = slow_size_total;
// in this case, it is valid on all nodes
}
return domain_type{dims};
}
};
//--------------------------------------------------------------------------------------------------------
// When value_type is a basic type, we can directly call the C API
template <typename A> class mpi_impl_triqs_arrays {
static MPI_Datatype D() { return mpi_datatype<typename A::value_type>::invoke(); }
static void check_is_contiguous(A const &a) {
if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
}
public:
//---------
static void reduce_in_place(communicator c, A &a, int root) {
check_is_contiguous(a);
// assume arrays have the same size on all nodes...
MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : a), a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
}
//---------
static void allreduce_in_place(communicator c, A &a, int root) {
check_is_contiguous(a);
// assume arrays have the same size on all nodes...
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
}
//---------
static void broadcast(communicator c, A &a, int root) {
check_is_contiguous(a);
auto sh = a.shape();
MPI_Bcast(&sh[0], sh.size(), mpi_datatype<typename decltype(sh)::value_type>::invoke(), root, c.get());
if (c.rank() != root) a.resize(sh);
MPI_Bcast(a.data_start(), a.domain().number_of_elements(), D(), root, c.get());
}
//---------
template <typename Tag> static mpi_lazy_array<Tag, A> invoke(Tag, communicator c, A const &a, int root) {
check_is_contiguous(a);
return {a, root, c};
}
};
template <typename A>
struct mpi_impl<A, std14::enable_if_t<triqs::arrays::is_amv_value_or_view_class<A>::value>> : mpi_impl_triqs_arrays<A> {};
} // mpi namespace
//------------------------------- Delegation of the assign operator of the array class -------------
namespace arrays {
// mpi_lazy_array model ImmutableCuboidArray
template <typename Tag, typename A> struct ImmutableCuboidArray<mpi::mpi_lazy_array<Tag, A>> : ImmutableCuboidArray<A> {};
namespace assignment {
template <typename LHS, typename Tag, typename A> struct is_special<LHS, mpi::mpi_lazy_array<Tag, A>> : std::true_type {};
// assignment delegation
template <typename LHS, typename A, typename Tag> struct impl<LHS, mpi::mpi_lazy_array<Tag, A>, 'E', void> {
using laz_t = mpi::mpi_lazy_array<Tag, A>;
LHS &lhs;
laz_t laz;
impl(LHS &lhs_, laz_t laz_) : lhs(lhs_), laz(laz_) {}
void invoke() { _invoke(Tag()); }
private:
static MPI_Datatype D() { return mpi::mpi_datatype<typename A::value_type>::invoke(); }
//---------------------------------
void _invoke(triqs::mpi::tag::scatter) {
lhs.resize(laz.domain());
auto c = laz.c;
auto slow_size = first_dim(laz.ref);
auto slow_stride = laz.ref.indexmap().strides()[0];
auto sendcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int recvcount = slice_length(slow_size - 1, c, c.rank()) * slow_stride;
for (int r = 0; r < c.size(); ++r) {
sendcounts[r] = slice_length(slow_size - 1, c, r) * slow_stride;
displs[r + 1] = sendcounts[r] + displs[r];
}
MPI_Scatterv((void *)laz.ref.data_start(), &sendcounts[0], &displs[0], D(), (void *)lhs.data_start(), recvcount, D(),
laz.root, c.get());
}
//---------------------------------
void _invoke(triqs::mpi::tag::gather) {
lhs.resize(laz.domain());
auto c = laz.c;
auto recvcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int sendcount = laz.ref.domain().number_of_elements();
auto mpi_ty = mpi::mpi_datatype<int>::invoke();
MPI_Gather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, laz.root, c.get());
for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
MPI_Gatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(), laz.root,
c.get());
}
//---------------------------------
void _invoke(triqs::mpi::tag::allgather) {
lhs.resize(laz.domain());
// almost the same preparation as gather, except that the recvcounts are ALL gathered...
auto c = laz.c;
auto recvcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int sendcount = laz.ref.domain().number_of_elements();
auto mpi_ty = mpi::mpi_datatype<int>::invoke();
MPI_Allgather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, c.get());
for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
MPI_Allgatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(),
c.get());
}
};
}
} //namespace arrays
} // namespace triqs

148
triqs/mpi/base.hpp Normal file
View File

@ -0,0 +1,148 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include <triqs/utility/c14.hpp>
//#include <triqs/utility/tuple_tools.hpp>
#include <mpi.h>
namespace triqs {
namespace mpi {
/// Environment
struct environment {
environment(int argc, char *argv[]) { MPI_Init(&argc, &argv); }
~environment() { MPI_Finalize(); }
};
/// The communicator. Todo : add more constructors.
class communicator {
MPI_Comm _com = MPI_COMM_WORLD;
public:
communicator() = default;
MPI_Comm get() const { return _com; }
int rank() const {
int num;
MPI_Comm_rank(_com, &num);
return num;
}
int size() const {
int num;
MPI_Comm_size(_com, &num);
return num;
}
void barrier() const { MPI_Barrier(_com); }
};
/// a tag for each operation
namespace tag {
struct reduce {};
struct allreduce {};
struct scatter {};
struct gather {};
struct allgather {};
}
/// The implementation of mpi ops for each type
template <typename T, typename Enable = void> struct mpi_impl;
// ----------------------------------------
// ------- top level functions -------
// ----------------------------------------
// ----- functions that can be lazy -------
template <typename T>
AUTO_DECL reduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl<T>::invoke(tag::reduce(), c, x, root));
template <typename T>
AUTO_DECL scatter(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl<T>::invoke(tag::scatter(), c, x, root));
template <typename T>
AUTO_DECL gather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl<T>::invoke(tag::gather(), c, x, root));
template <typename T>
AUTO_DECL allreduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl<T>::invoke(tag::allreduce(), c, x, root));
template <typename T>
AUTO_DECL allgather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl<T>::invoke(tag::allgather(), c, x, root));
// ----- functions that cannot be lazy -------
template <typename T> void reduce_in_place(T &x, communicator c = {}, int root = 0) { mpi_impl<T>::reduce_in_place(c, x, root); }
template <typename T> void broadcast(T &x, communicator c = {}, int root = 0) { mpi_impl<T>::broadcast(c, x, root); }
// transformation type -> mpi types
template <class T> struct mpi_datatype;
#define D(T, MPI_TY) \
template <> struct mpi_datatype<T> { \
static MPI_Datatype invoke() { return MPI_TY; } \
};
D(int, MPI_INT) D(long, MPI_LONG) D(double, MPI_DOUBLE) D(float, MPI_FLOAT) D(std::complex<double>, MPI_DOUBLE_COMPLEX);
D(unsigned long, MPI_UNSIGNED_LONG);
#undef D
/** ------------------------------------------------------------
* basic types
* ---------------------------------------------------------- **/
template <typename T> struct mpi_impl_basic {
static MPI_Datatype D() { return mpi_datatype<T>::invoke(); }
static T invoke(tag::reduce, communicator c, T a, int root) {
T b;
MPI_Reduce(&a, &b, 1, D(), MPI_SUM, root, c.get());
return b;
}
static T invoke(tag::allreduce, communicator c, T a, int root) {
T b;
MPI_Allreduce(&a, &b, 1, D(), MPI_SUM, root, c.get());
return b;
}
static void reduce_in_place(communicator c, T &a, int root) {
MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : &a), &a, 1, D(), MPI_SUM, root, c.get());
}
static void allreduce_in_place(communicator c, T &a, int root) {
MPI_Allreduce(MPI_IN_PLACE, &a, 1, D(), MPI_SUM, root, c.get());
}
static void broadcast(communicator c, T &a, int root) { MPI_Bcast(&a, 1, D(), root, c.get()); }
};
// mpl_impl_basic is the mpi_impl<T> is T is a number (including complex)
template <typename T>
struct mpi_impl<T, std14::enable_if_t<std::is_arithmetic<T>::value || triqs::is_complex<T>::value>> : mpi_impl_basic<T> {};
//------------ Some helper function
inline long slice_length(size_t imax, communicator c, int r) {
auto imin = 0;
long j = (imax - imin + 1) / c.size();
long i = imax - imin + 1 - c.size() * j;
auto r_min = (r <= i - 1 ? imin + r * (j + 1) : imin + r * j + i);
auto r_max = (r <= i - 1 ? imin + (r + 1) * (j + 1) - 1 : imin + (r + 1) * j + i - 1);
return r_max - r_min + 1;
};
}
}

61
triqs/mpi/boost.hpp Normal file
View File

@ -0,0 +1,61 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include "./base.hpp"
#include <boost/mpi.hpp>
namespace triqs {
namespace mpi {
/** ------------------------------------------------------------
* Type which we use boost::mpi
* ---------------------------------------------------------- **/
template <typename T> struct mpi_impl_boost_mpi {
static T invoke(tag::reduce, communicator c, T const &a, int root) {
T b;
boost::mpi::reduce(c, a, b, std::c14::plus<>(), root);
return b;
}
static T invoke(tag::allreduce, communicator c, T const &a, int root) {
T b;
boost::mpi::all_reduce(c, a, b, std::c14::plus<>(), root);
return b;
}
static void reduce_in_place(communicator c, T &a, int root) { boost::mpi::reduce(c, a, a, std::c14::plus<>(), root); }
static void broadcast(communicator c, T &a, int root) { boost::mpi::broadcast(c, a, root); }
static void scatter(communicator c, T const &, int root) = delete;
static void gather(communicator c, T const &, int root) = delete;
static void allgather(communicator c, T const &, int root) = delete;
};
// default
//template <typename T> struct mpi_impl<T> : mpi_impl_boost_mpi<T> {};
}}//namespace

103
triqs/mpi/generic.hpp Normal file
View File

@ -0,0 +1,103 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include "./base.hpp"
#include <triqs/utility/tuple_tools.hpp>
#define TRIQS_MPI_IMPLEMENTED_AS_TUPLEVIEW using triqs_mpi_as_tuple = void;
namespace triqs {
namespace mpi {
template <typename Tag, typename T> struct mpi_lazy {
T const &ref;
int root;
communicator c;
};
/** ------------------------------------------------------------
* Type which are recursively treated by reducing them to a tuple
* of smaller objects.
* ---------------------------------------------------------- **/
template <typename T> struct mpi_impl_tuple {
mpi_impl_tuple() = default;
template <typename Tag> static mpi_lazy<Tag, T> invoke(Tag, communicator c, T const &a, int root) {
return {a, root, c};
}
#ifdef __cpp_generic_lambdas
static void reduce_in_place(communicator c, T &a, int root) {
tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::reduce_in_place(x, c, root); });
}
static void broadcast(communicator c, T &a, int root) {
tuple::for_each(view_as_tuple(a), [c, root](auto &x) { triqs::mpi::broadcast(x, c, root); });
}
template <typename Tag> static T &complete_operation(T &target, mpi_lazy<Tag, T> laz) {
auto l = [laz](auto &t, auto &s) { t = triqs::mpi::mpi_impl<std::decay_t<decltype(s)>>::invoke(Tag(), laz.c, s, laz.root); };
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
return target;
}
#else
struct aux1 {
communicator c;
int root;
template <typename T1> void operator()(T1 &x) const { triqs::mpi::reduce_in_place(c, x, root); }
};
static void reduce_in_place(communicator c, T &a, int root) {
tuple::for_each(aux1{c, root}, view_as_tuple(a));
}
struct aux2 {
communicator c;
int root;
template <typename T2> void operator()(T2 &x) const { triqs::mpi::broadcast(c, x, root); }
};
static void broadcast(communicator c, T &a, int root) {
tuple::for_each(aux2{c, root}, view_as_tuple(a));
}
template <typename Tag> struct aux3 {
mpi_lazy<Tag, T> laz;
template <typename T1, typename T2> void operator()(T1 &t, T2 &s) const {
t = triqs::mpi::mpi_impl<T2>::invoke(Tag(), laz.c, laz.s);
}
};
template <typename Tag> static void complete_operation(T &target, mpi_lazy<Tag, T> laz) {
auto l = aux3<Tag>{laz};
triqs::tuple::for_each_zip(l, view_as_tuple(target), view_as_tuple(laz.ref));
}
#endif
};
// If type T has a mpi_implementation nested struct, then it is mpi_impl<T>.
template <typename T> struct mpi_impl<T, typename T::triqs_mpi_as_tuple> : mpi_impl_tuple<T> {};
}
} // namespace

116
triqs/mpi/vector.hpp Normal file
View File

@ -0,0 +1,116 @@
/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
namespace triqs {
namespace mpi {
// When value_type is a basic type, we can directly call the C API
template <typename T> struct mpi_impl_std_vector_basic {
static MPI_Datatype D() { return mpi_datatype<T>::invoke(); }
// -----------
static void reduce_in_place(communicator c, std::vector<T> &a, int root) {
MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : a.data()), a.data(), a.size(), D(), MPI_SUM, root, c.get());
}
static void allreduce_in_place(communicator c, std::vector<T> &a, int root) {
MPI_Allreduce(MPI_IN_PLACE, a.data(), a.size(), D(), MPI_SUM, root, c.get());
}
// -----------
static void broadcast(communicator c, std::vector<T> &a, int root) { MPI_Bcast(a.data(), a.size(), D(), root, c.get()); }
// -----------
static std::vector<T> invoke(tag::reduce, communicator c, T const &a, int root) {
std::vector<T> b(a.size());
MPI_Reduce(a.data(), b.data(), a.size(), D(), MPI_SUM, root, c.get());
return b;
}
// -----------
static std::vector<T> invoke(tag::allreduce, communicator c, std::vector<T> const &a, int root) {
std::vector<T> b(a.size());
MPI_Allreduce(a.data(), b.data(), a.size(), D(), MPI_SUM, root, c.get());
return b;
}
// -----------
static std::vector<T> invoke(tag::scatter, communicator c, std::vector<T> const &a, int root) {
auto slow_size = a.size();
auto sendcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int recvcount = slice_length(slow_size - 1, c, c.rank());
std::vector<T> b(recvcount);
for (int r = 0; r < c.size(); ++r) {
sendcounts[r] = slice_length(slow_size - 1, c, r);
displs[r + 1] = sendcounts[r] + displs[r];
}
MPI_Scatterv((void *)a.data(), &sendcounts[0], &displs[0], D(), (void *)b.data(), recvcount, D(), root, c.get());
return b;
}
// -----------
static std::vector<T> invoke(tag::gather, communicator c, std::vector<T> const &a, int root) {
long size = reduce(a.size(), c, root);
std::vector<T> b(size);
auto recvcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int sendcount = a.size();
auto mpi_ty = mpi::mpi_datatype<int>::invoke();
MPI_Gather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, root, c.get());
for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
MPI_Gatherv((void *)a.data(), sendcount, D(), (void *)b.data(), &recvcounts[0], &displs[0], D(), root, c.get());
return b;
}
// -----------
static std::vector<T> invoke(tag::allgather, communicator c, std::vector<T> const &a, int root) {
long size = reduce(a.size(), c, root);
std::vector<T> b(size);
auto recvcounts = std::vector<int>(c.size());
auto displs = std::vector<int>(c.size() + 1, 0);
int sendcount = a.size();
auto mpi_ty = mpi::mpi_datatype<int>::invoke();
MPI_Allgather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, c.get());
for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
MPI_Allgatherv((void *)a.data(), sendcount, D(), (void *)b.data(), &recvcounts[0], &displs[0], D(), c.get());
return b;
}
};
template <typename T>
struct mpi_impl<std::vector<T>, std14::enable_if_t<std::is_arithmetic<T>::value ||
triqs::is_complex<T>::value>> : mpi_impl_std_vector_basic<T> {};
// vector <T> for T non basic
}
} // namespace