/*******************************************************************************
*
* TRIQS: a Toolbox for Research in Interacting Quantum Systems
*
* Copyright (C) 2014 by O. Parcollet
*
* TRIQS is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* TRIQS. If not, see .
*
******************************************************************************/
#pragma once
#include
#include "./base.hpp"
namespace triqs {
namespace mpi {
//--------------------------------------------------------------------------------------------------------
// The lazy ref made by scatter and co.
// Differs from the generic one in that it can make a domain of the (target) array
//
// The object only records the operands of a collective: the local source array, the root rank
// and the communicator. No communication happens when it is built; the collective runs when the
// object is assigned to an array (see the assignment delegation in triqs::arrays::assignment).
// NOTE: `ref` is stored by reference, so the source array must outlive this lazy object.
template struct mpi_lazy_array {
// local source array (not copied)
A const &ref;
// root rank of the collective
int root;
// communicator on which the collective runs
communicator c;
using domain_type = typename A::domain_type;
/// compute the array domain of the target array
///
/// Starts from the shape of the local source array and, depending on the collective, replaces
/// its first dimension. The gather-style branches call MPI_Reduce / MPI_Allreduce, which are
/// collective operations: every rank of `c` must call domain() when one of them is taken.
domain_type domain() const {
auto dims = ref.shape();
// extent of the first (slowest) dimension of the local source
long slow_size = first_dim(ref);
if (std::is_same::value) {
// optionally check all dims are the same ?
// (nothing done: the target keeps the shape of the local source)
}
if (std::is_same::value) {
// Scatter-style: this rank gets slice_length(slow_size - 1, c.size(), c.rank()) entries of the
// first dimension (slow_size - 1 presumably being the last index of that dimension).
dims[0] = mpi::slice_length(slow_size - 1, c.size(), c.rank());
}
if (std::is_same::value) {
// Gather-style: first dimension = sum of the local first dimensions over all ranks.
// MPI_Reduce only writes the receive buffer on root, so elsewhere slow_size_total stays 0.
long slow_size_total = 0;
MPI_Reduce(&slow_size, &slow_size_total, 1, mpi_datatype::invoke(), MPI_SUM, root, c.get());
dims[0] = slow_size_total;
// valid only on root
}
if (std::is_same::value) {
// Allgather-style: same sum, but MPI_Allreduce makes it available on every rank.
long slow_size_total = 0;
MPI_Allreduce(&slow_size, &slow_size_total, 1, mpi_datatype::invoke(), MPI_SUM, c.get());
dims[0] = slow_size_total;
// in this case, it is valid on all nodes
}
return domain_type{dims};
}
};
//--------------------------------------------------------------------------------------------------------
// When value_type is a basic type, we can directly call the C API
template class mpi_impl_triqs_arrays {
static MPI_Datatype D() { return mpi_datatype::invoke(); }
static void check_is_contiguous(A const &a) {
if (!has_contiguous_data(a)) TRIQS_RUNTIME_ERROR << "Non contiguous view in mpi_reduce_in_place";
}
public:
//---------
static void reduce_in_place(communicator c, A &a, int root) {
check_is_contiguous(a);
// assume arrays have the same size on all nodes...
MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : a), a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, root, c.get());
}
//---------
static void allreduce_in_place(communicator c, A &a, int root) {
check_is_contiguous(a);
// assume arrays have the same size on all nodes...
MPI_Allreduce(MPI_IN_PLACE, a.data_start(), a.domain().number_of_elements(), D(), MPI_SUM, c.get());
}
//---------
static void broadcast(communicator c, A &a, int root) {
check_is_contiguous(a);
auto sh = a.shape();
MPI_Bcast(&sh[0], sh.size(), mpi_datatype::invoke(), root, c.get());
if (c.rank() != root) a.resize(sh);
MPI_Bcast(a.data_start(), a.domain().number_of_elements(), D(), root, c.get());
}
//---------
template static mpi_lazy_array invoke(Tag, communicator c, A const &a, int root) {
check_is_contiguous(a);
return {a, root, c};
}
};
// Partial specialization of the generic triqs::mpi dispatch (mpi_impl).
// For the types matched by its condition, every operation is forwarded to the C-API based
// mpi_impl_triqs_arrays defined above.
template
struct mpi_impl::value>> : mpi_impl_triqs_arrays {};
} // mpi namespace
//------------------------------- Delegation of the assign operator of the array class -------------
namespace arrays {
// mpi_lazy_array model ImmutableCuboidArray
// The specialization inherits the concept trait, presumably from the wrapped array type.
// This lets a lazy MPI object appear where an immutable cuboid array is expected, e.g. on the
// right-hand side of an array assignment.
template struct ImmutableCuboidArray> : ImmutableCuboidArray {};
namespace assignment {
// Flag mpi_lazy_array as a "special" right-hand side of an assignment. Presumably this routes the
// assignment to the dedicated impl specialization below rather than to the generic element-wise
// path; the generic dispatch is not visible here.
template struct is_special> : std::true_type {};
// assignment delegation
// `lhs = laz` (laz being an mpi_lazy_array) runs the MPI collective recorded in `laz` and stores
// the result in `lhs`. The collective is chosen by overload resolution on Tag in invoke().
template struct impl, 'E', void> {
  using laz_t = mpi::mpi_lazy_array;
  LHS &lhs;
  laz_t laz;
  impl(LHS &lhs_, laz_t laz_) : lhs(lhs_), laz(laz_) {}
  void invoke() { _invoke(Tag()); }

  private:
  static MPI_Datatype D() { return mpi::mpi_datatype::invoke(); }

  /// True when the target and the source start at the same address, i.e. an array is being
  /// assigned the result of a collective on itself. MPI forbids aliased send and receive buffers
  /// in a collective call, so that case has to go through MPI_IN_PLACE.
  bool _in_place() {
    return static_cast<void const *>(lhs.data_start()) == static_cast<void const *>(laz.ref.data_start());
  }

  //---------------------------------
  /// Sum over all ranks; the result is stored in lhs on root.
  void _invoke(triqs::mpi::tag::reduce) {
    if (_in_place()) {
      // lhs is the source itself, so no resize is done. This assumes both have the same number of
      // elements (TODO confirm for distinct views sharing a data pointer with different shapes).
      // On root, MPI_IN_PLACE makes MPI read root's contribution from buf and write the sum there.
      // On other ranks, buf is the send buffer; the receive buffer is only significant on root.
      void *buf = (void *)lhs.data_start();
      MPI_Reduce((laz.c.rank() == laz.root ? MPI_IN_PLACE : buf), buf, laz.ref.domain().number_of_elements(), D(), MPI_SUM,
                 laz.root, laz.c.get());
      return;
    }
    lhs.resize(laz.domain());
    MPI_Reduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.root, laz.c.get());
  }

  //---------------------------------
  /// Sum over all ranks; every rank stores the result in lhs.
  void _invoke(triqs::mpi::tag::allreduce) {
    if (_in_place()) {
      // MPI requires MPI_IN_PLACE on all ranks or on none. This assumes every rank performs the
      // same in-place assignment.
      MPI_Allreduce(MPI_IN_PLACE, (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.c.get());
      return;
    }
    lhs.resize(laz.domain());
    MPI_Allreduce((void *)laz.ref.data_start(), (void *)lhs.data_start(), laz.ref.domain().number_of_elements(), D(), MPI_SUM, laz.c.get());
  }

  //---------------------------------
  /// Split the first dimension of the source held by root across the ranks.
  /// Rank r receives slice_length(slow_size - 1, c.size(), r) entries of that dimension.
  void _invoke(triqs::mpi::tag::scatter) {
    lhs.resize(laz.domain());
    auto c = laz.c;
    auto slow_size = first_dim(laz.ref);
    // Elements per index of the first dimension. This assumes a contiguous source whose first
    // dimension varies slowest (TODO confirm).
    auto slow_stride = laz.ref.indexmap().strides()[0];
    auto sendcounts = std::vector(c.size());
    auto displs = std::vector(c.size() + 1, 0);
    // NOTE(review): recvcount comes from the local source's first dimension. This assumes the
    // source has the same first dimension on every rank; MPI only uses root's sendcounts/displs.
    int recvcount = mpi::slice_length(slow_size - 1, c.size(), c.rank()) * slow_stride;
    for (int r = 0; r < c.size(); ++r) {
      sendcounts[r] = mpi::slice_length(slow_size - 1, c.size(), r) * slow_stride;
      displs[r + 1] = sendcounts[r] + displs[r];
    }
    MPI_Scatterv((void *)laz.ref.data_start(), &sendcounts[0], &displs[0], D(), (void *)lhs.data_start(), recvcount, D(),
                 laz.root, c.get());
  }

  //---------------------------------
  /// Concatenate the sources of all ranks, in rank order, into lhs on root.
  /// laz.domain() performs a collective reduction of the first dimensions. Its result is the full
  /// size only on root.
  void _invoke(triqs::mpi::tag::gather) {
    lhs.resize(laz.domain());
    auto c = laz.c;
    auto recvcounts = std::vector(c.size());
    auto displs = std::vector(c.size() + 1, 0);
    int sendcount = laz.ref.domain().number_of_elements();
    auto mpi_ty = mpi::mpi_datatype::invoke();
    // Root collects each rank's element count; the displacements are their prefix sums.
    // They are only meaningful on root, where recvcounts is filled.
    MPI_Gather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, laz.root, c.get());
    for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
    MPI_Gatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(), laz.root,
                c.get());
  }

  //---------------------------------
  /// Concatenate the sources of all ranks, in rank order, into lhs on every rank.
  void _invoke(triqs::mpi::tag::allgather) {
    lhs.resize(laz.domain());
    // almost the same preparation as gather, except that the recvcounts are ALL gathered...
    auto c = laz.c;
    auto recvcounts = std::vector(c.size());
    auto displs = std::vector(c.size() + 1, 0);
    int sendcount = laz.ref.domain().number_of_elements();
    auto mpi_ty = mpi::mpi_datatype::invoke();
    MPI_Allgather(&sendcount, 1, mpi_ty, &recvcounts[0], 1, mpi_ty, c.get());
    for (int r = 0; r < c.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
    MPI_Allgatherv((void *)laz.ref.data_start(), sendcount, D(), (void *)lhs.data_start(), &recvcounts[0], &displs[0], D(),
                   c.get());
  }
};
}
} //namespace arrays
} // namespace triqs