/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see <http://www.gnu.org/licenses/>.
*
******************************************************************************/
#pragma once
#include
//#include
#include
namespace triqs {
namespace mpi {
/// Environment
struct environment {
environment(int argc, char *argv[]) { MPI_Init(&argc, &argv); }
~environment() { MPI_Finalize(); }
};
/// The communicator. Todo : add more constructors.
class communicator {
MPI_Comm _com = MPI_COMM_WORLD;
public:
communicator() = default;
MPI_Comm get() const { return _com; }
int rank() const {
int num;
MPI_Comm_rank(_com, &num);
return num;
}
int size() const {
int num;
MPI_Comm_size(_com, &num);
return num;
}
void barrier() const { MPI_Barrier(_com); }
};
/// One empty tag type per MPI operation. An instance is passed as the first
/// argument of mpi_impl's invoke(...) to select the operation by overload.
namespace tag {
  /// selects the reduce overload
  struct reduce {};
  /// selects the allreduce overload
  struct allreduce {};
  /// selects the scatter overload
  struct scatter {};
  /// selects the gather overload
  struct gather {};
  /// selects the allgather overload
  struct allgather {};
} // namespace tag
/// The implementation of mpi ops for each type.
/// Primary template: declared only, so every supported T needs a specialization.
/// The second parameter defaults to void so that mpi_impl<T> can be written with
/// a single argument while partial specializations select on a condition over T.
template <typename T, typename Enable = void> struct mpi_impl;
// ----------------------------------------
// ------- top level functions -------
// ----------------------------------------
// ----- functions that can be lazy -------
template
AUTO_DECL reduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::reduce(), c, x, root));
template
AUTO_DECL scatter(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::scatter(), c, x, root));
template
AUTO_DECL gather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::gather(), c, x, root));
template
AUTO_DECL allreduce(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::allreduce(), c, x, root));
template
AUTO_DECL allgather(T const &x, communicator c = {}, int root = 0) RETURN(mpi_impl::invoke(tag::allgather(), c, x, root));
// ----- functions that cannot be lazy -------
template void reduce_in_place(T &x, communicator c = {}, int root = 0) { mpi_impl::reduce_in_place(c, x, root); }
template void broadcast(T &x, communicator c = {}, int root = 0) { mpi_impl::broadcast(c, x, root); }
// transformation type -> mpi types
template struct mpi_datatype;
#define D(T, MPI_TY) \
template <> struct mpi_datatype { \
static MPI_Datatype invoke() { return MPI_TY; } \
};
D(int, MPI_INT) D(long, MPI_LONG) D(double, MPI_DOUBLE) D(float, MPI_FLOAT) D(std::complex, MPI_DOUBLE_COMPLEX);
D(unsigned long, MPI_UNSIGNED_LONG);
#undef D
/** ------------------------------------------------------------
* basic types
* ---------------------------------------------------------- **/
template struct mpi_impl_basic {
static MPI_Datatype D() { return mpi_datatype::invoke(); }
static T invoke(tag::reduce, communicator c, T a, int root) {
T b;
MPI_Reduce(&a, &b, 1, D(), MPI_SUM, root, c.get());
return b;
}
static T invoke(tag::allreduce, communicator c, T a, int root) {
T b;
MPI_Allreduce(&a, &b, 1, D(), MPI_SUM, c.get());
return b;
}
static void reduce_in_place(communicator c, T &a, int root) {
MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : &a), &a, 1, D(), MPI_SUM, root, c.get());
}
static void allreduce_in_place(communicator c, T &a, int root) {
MPI_Allreduce(MPI_IN_PLACE, &a, 1, D(), MPI_SUM, root, c.get());
}
static void broadcast(communicator c, T &a, int root) { MPI_Bcast(&a, 1, D(), root, c.get()); }
};
// mpi_impl_basic is the mpi_impl if T is a number (including complex)
// NOTE(review): the template parameter list and template argument lists on the
// two lines below appear to have been lost (the text between '<' and '>' is
// missing), so the declaration does not parse as written. Presumably this is a
// partial specialization of mpi_impl on an enable_if-style condition over T
// that inherits from mpi_impl_basic of the same T -- restore it from the
// upstream source rather than guessing the exact traits.
template
struct mpi_impl::value || triqs::is_complex::value>> : mpi_impl_basic {};
//------------ Some helper function
/// Number of indices given to rank r when the imax + 1 indices 0..imax
/// (inclusive) are divided over the c.size() ranks: every rank gets the integer
/// quotient, and the ranks below the remainder get one more.
/// @param imax  last index of the range (the range starts at 0)
/// @param c     communicator whose size is the number of slices
/// @param r     rank whose slice length is requested
inline long slice_length(size_t imax, communicator c, int r) {
  const int nproc = c.size();
  const long base = (imax + 1) / nproc;
  const long extra = imax + 1 - nproc * base;
  return (r < extra) ? base + 1 : base;
}
}
}