Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a gather() function to gather std::map into the root process; #74

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 58 additions & 0 deletions include/diy/mpi/collectives.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include <vector>
#include <map>
#include <numeric>

#include "../constants.h" // for DIY_UNUSED.
#include "../serialization.hpp" // for serialization
#include "operations.hpp"

namespace diy
Expand Down Expand Up @@ -363,6 +366,61 @@ namespace mpi
Collectives<T,void*>::gather(comm, in, root);
}

//! Gathering std::map into the root process; K and V are serialized into string
// and thus no datatype needs to be created for K and V
template <typename K, typename V> // merging an std::map to root using gather
inline
void gather(const communicator& comm, const std::map<K, V>& in, std::map<K, V> &out, int root)
{
#ifndef DIY_NO_MPI
// serialize input map
std::string serialized_in;
if (comm.rank() != root) // avoid serializing data from the root proc
serializeToString(in, serialized_in);

// gathering length of serialized data
int length_serialized_in = serialized_in.size();
std::vector<int> all_length_serialized_in(comm.size(), 0);
MPI_Gather(&length_serialized_in, 1, MPI_INT,
&all_length_serialized_in[0], 1, MPI_INT,
root, comm);

// preparing buffer and displacements for MPI_Gatherv
std::string buffer;
std::vector<int> displs;
if (comm.rank() == root) {
buffer.resize(std::accumulate( // prepare buffer
all_length_serialized_in.begin(),
all_length_serialized_in.end(), 0));

displs.resize(all_length_serialized_in.size(), 0); // prepare displs
for (int i = 1; i < comm.size(); i ++)
displs[i] = displs[i-1] + all_length_serialized_in[i-1];
}

// call MPI_Gatherv
MPI_Gatherv(serialized_in.data(), serialized_in.size(), MPI_CHAR,
&buffer[0], all_length_serialized_in.data(), displs.data(), MPI_CHAR,
root, comm);

// unserailization (root proc only)
if (comm.rank() == root) {
out = in;
StringBuffer sb(buffer);
while (sb) {
std::map<K, V> map;
load(sb, map);
for (const auto &kv : map)
out.insert(kv);
}
}
#else
DIY_UNUSED(comm);
DIY_UNUSED(root);
out = in;
#endif
}

//! all_gather from all processes in `comm`.
//! `out` is resized to `comm.size()` and filled with
//! elements from the respective ranks.
Expand Down
72 changes: 72 additions & 0 deletions include/diy/serialization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ namespace diy
size_t position;
std::vector<char> buffer;
};


//! An alternative memory buffer based on std::string
struct StringBuffer : public BinaryBuffer {
std::string &str;
size_t pos;

explicit StringBuffer(std::string& str_, size_t pos_=0) : str(str_), pos(pos_) {}
void clear() {str.clear(); pos = 0;}
void reset() {pos = 0;}

operator bool() const {return pos < str.size();}

inline void save_binary(const char *x, size_t count) {
if (pos + count > str.size()) str.resize(pos + count);
memcpy((char*)(str.data()+pos), x, count);
pos += count;
}

inline void append_binary(const char *x, size_t count) {
str.append(x, count);
}

inline void load_binary(char *x, size_t count) {
memcpy(x, str.data()+pos, count);
pos += count;
}

inline void load_binary_back(char *x, size_t count) {
memcpy(x, str.data()+str.size()-count, count);
}
};

namespace detail
{
Expand Down Expand Up @@ -126,6 +158,23 @@ namespace diy

//@}

//! serialize a data structure into string
template <typename T>
void serializeToString(const T& obj, std::string& buf)
{
buf.clear();
diy::StringBuffer bb(buf);
diy::save(bb, obj);
}

//! unserialize a data structure from string
template <typename T>
void unserializeFromString(const std::string& buf, T& obj)
{
std::string buf1(buf);
diy::StringBuffer bb(buf1);
diy::load(bb, obj);
}

namespace detail
{
Expand Down Expand Up @@ -262,6 +311,29 @@ namespace diy
}
};

// save/load for std::deque
template<class T>
struct Serialization< std::deque<T> > {
typedef std::deque<T> deque;

static void save(BinaryBuffer& bb, const deque& q) {
size_t s = q.size();
diy::save(bb, q);
for (typename std::deque<T>::const_iterator it = q.begin(); it != q.end(); it ++)
diy::save(bb, *it);
}

static void load(BinaryBuffer& bb, deque& q) {
size_t s;
diy::load(bb, s);
for (int i = 0; i < s; i ++) {
T p;
diy::load(bb, p);
q.emplace_back(p);
}
}
};

// save/load for std::pair<X,Y>
template<class X, class Y>
struct Serialization< std::pair<X,Y> >
Expand Down