Skip to content

Commit

Permalink
Rewrite BBS::alltoall with std::vector (#2829)
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino committed May 8, 2024
1 parent 2b040b7 commit 4a1e15a
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 58 deletions.
12 changes: 12 additions & 0 deletions src/ivoc/ivocvect.h
Expand Up @@ -31,6 +31,10 @@ class IvocVect {
return vec_;
}

inline double const* data() const {
return vec_.data();
}

inline double* data() {
return vec_.data();
}
Expand All @@ -57,6 +61,14 @@ class IvocVect {
return vec_.at(index);
}

inline auto begin() const -> std::vector<double>::const_iterator {
return vec_.begin();
}

inline auto end() const -> std::vector<double>::const_iterator {
return vec_.end();
}

inline auto begin() -> std::vector<double>::iterator {
return vec_.begin();
}
Expand Down
17 changes: 11 additions & 6 deletions src/nrniv/bbsavestate.cpp
Expand Up @@ -219,22 +219,27 @@ extern int nrn_gid_exists(int gid);

#if NRNMPI
extern void nrnmpi_barrier();
extern void nrnmpi_int_alltoallv(int*, int*, int*, int*, int*, int*);
extern void nrnmpi_dbl_alltoallv(double*, int*, int*, double*, int*, int*);
extern void nrnmpi_int_alltoallv(const int*, const int*, const int*, int*, int*, int*);
extern void nrnmpi_dbl_alltoallv(const double*, const int*, const int*, double*, int*, int*);
extern int nrnmpi_int_allmax(int);
extern void nrnmpi_int_allgather(int* s, int* r, int n);
extern void nrnmpi_int_allgatherv(int* s, int* r, int* n, int* dspl);
extern void nrnmpi_dbl_allgatherv(double* s, double* r, int* n, int* dspl);
#else
static void nrnmpi_barrier() {}
static void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl) {
static void nrnmpi_int_alltoallv(const int* s,
const int* scnt,
const int* sdispl,
int* r,
int* rcnt,
int* rdispl) {
for (int i = 0; i < scnt[0]; ++i) {
r[i] = s[i];
}
}
static void nrnmpi_dbl_alltoallv(double* s,
int* scnt,
int* sdispl,
static void nrnmpi_dbl_alltoallv(const double* s,
const int* scnt,
const int* sdispl,
double* r,
int* rcnt,
int* rdispl) {
Expand Down
2 changes: 2 additions & 0 deletions src/nrnmpi/mkdynam.sh
Expand Up @@ -12,6 +12,8 @@ sed -n '
/extern [^v]/s/extern \([a-z*]*\) \(nrnmpi_[a-zA-Z0-9_]*\)\(.*\);/\1 \2\3 {@ return (*p_\2)\3;@}/p
' nrnmpidec.h | tr '@' '\n' | sed '
/p_nrnmpi/ {
s/, const [a-zA-Z0-9_:*&]* /, /g
s/)(const [a-zA-Z0-9_:*&]* /)(/
s/, [a-zA-Z0-9_:*&]* /, /g
s/)([a-zA-Z0-9_:*&]* /)(/
s/const& //g
Expand Down
13 changes: 9 additions & 4 deletions src/nrnmpi/mpispike.cpp
Expand Up @@ -404,7 +404,12 @@ extern void nrnmpi_int_alltoall(int* s, int* r, int n) {
MPI_Alltoall(s, n, MPI_INT, r, n, MPI_INT, nrnmpi_comm);
}

extern void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl) {
extern void nrnmpi_int_alltoallv(const int* s,
const int* scnt,
const int* sdispl,
int* r,
int* rcnt,
int* rdispl) {
MPI_Alltoallv(s, scnt, sdispl, MPI_INT, r, rcnt, rdispl, MPI_INT, nrnmpi_comm);
}

Expand All @@ -417,9 +422,9 @@ extern void nrnmpi_long_alltoallv(int64_t* s,
MPI_Alltoallv(s, scnt, sdispl, MPI_INT64_T, r, rcnt, rdispl, MPI_INT64_T, nrnmpi_comm);
}

extern void nrnmpi_dbl_alltoallv(double* s,
int* scnt,
int* sdispl,
extern void nrnmpi_dbl_alltoallv(const double* s,
const int* scnt,
const int* sdispl,
double* r,
int* rcnt,
int* rdispl) {
Expand Down
4 changes: 2 additions & 2 deletions src/nrnmpi/nrnmpidec.h
Expand Up @@ -91,15 +91,15 @@ extern void nrnmpi_int_allgatherv_inplace(int* srcdest, int* n, int* dspl);
extern void nrnmpi_int_allgatherv(int* s, int* r, int* n, int* dspl);
extern void nrnmpi_char_allgatherv(char* s, char* r, int* n, int* dspl);
extern void nrnmpi_int_alltoall(int* s, int* r, int n);
extern void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_int_alltoallv(const int* s, const int* scnt, const int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_int_alltoallv_sparse(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_long_allgatherv(int64_t* s, int64_t* r, int* n, int* dspl);
extern void nrnmpi_long_allgatherv_inplace(long* srcdest, int* n, int* dspl);
extern void nrnmpi_long_alltoallv(int64_t* s, int* scnt, int* sdispl, int64_t* r, int* rcnt, int* rdispl);
extern void nrnmpi_long_alltoallv_sparse(int64_t* s, int* scnt, int* sdispl, int64_t* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_allgatherv(double* s, double* r, int* n, int* dspl);
extern void nrnmpi_dbl_allgatherv_inplace(double* srcdest, int* n, int* dspl);
extern void nrnmpi_dbl_alltoallv(double* s, int* scnt, int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_alltoallv(const double* s, const int* scnt, const int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_alltoallv_sparse(double* s, int* scnt, int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_char_alltoallv(char* s, int* scnt, int* sdispl, char* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_broadcast(double* buf, int cnt, int root);
Expand Down
77 changes: 31 additions & 46 deletions src/parallel/ocbbs.cpp
Expand Up @@ -61,23 +61,18 @@ extern int nrncore_is_enabled();
extern int nrncore_is_file_mode();
extern int nrncore_psolve(double tstop, int file_mode);

class OcBBS: public BBS, public Resource {
class OcBBS final: public BBS, public Resource {
public:
OcBBS(int nhost_request);
virtual ~OcBBS();

public:
double retval_;
int userid_;
int next_local_;
double retval_ = 0.;
int userid_ = 0;
int next_local_ = 0;
};

OcBBS::OcBBS(int n)
: BBS(n) {
next_local_ = 0;
}

OcBBS::~OcBBS() {}
: BBS(n) {}

static bool posting_ = false;
static void pack_help(int, OcBBS*);
Expand Down Expand Up @@ -790,57 +785,47 @@ static double allgather(void*) {
return 0.;
}

// This function takes 3 arguments:
// - vsrc (In)
// - vscnt (In)
// - vdest (Out)
static double alltoall(void*) {
int i, ns, np = nrnmpi_numprocs;
Vect* vsrc = vector_arg(1);
Vect* vscnt = vector_arg(2);
ns = vector_capacity(vsrc);
double* s = vector_vec(vsrc);
if (vector_capacity(vscnt) != np) {
int np = nrnmpi_numprocs;
const Vect* vsrc = vector_arg(1);
const Vect* vscnt = vector_arg(2);
Vect* vdest = vector_arg(3);
std::size_t ns = vsrc->size();
if (vscnt->size() != np) {
hoc_execerror("size of source counts vector is not nhost", nullptr);
}
double* x = vector_vec(vscnt);
int* scnt = new int[np];
int* sdispl = new int[np + 1];
sdispl[0] = 0;
for (i = 0; i < np; ++i) {
scnt[i] = int(x[i]);
const std::vector<int> scnt(vscnt->begin(), vscnt->end()); // cast from double to int
std::vector<int> sdispl(np + 1);
for (int i = 0; i < np; ++i) {
sdispl[i + 1] = sdispl[i] + scnt[i];
}
if (ns != sdispl[np]) {
hoc_execerror("sum of source counts is not the size of the src vector", nullptr);
}
Vect* vdest = vector_arg(3);
if (nrnmpi_numprocs > 1) {
#if NRNMPI
int* rcnt = new int[np];
int* rdispl = new int[np + 1];
int* c = new int[np];
rdispl[0] = 0;
for (i = 0; i < np; ++i) {
c[i] = 1;
rdispl[i + 1] = i + 1;
}
nrnmpi_int_alltoallv(scnt, c, rdispl, rcnt, c, rdispl);
delete[] c;
for (i = 0; i < np; ++i) {
std::vector<int> rcnt(np);
std::vector<int> c(np, 1);
std::vector<int> rdispl(np + 1);
std::iota(rdispl.begin(), rdispl.end(), 0);

nrnmpi_int_alltoallv(
scnt.data(), c.data(), rdispl.data(), rcnt.data(), c.data(), rdispl.data());
for (int i = 0; i < np; ++i) {
rdispl[i + 1] = rdispl[i] + rcnt[i];
}
vector_resize(vdest, rdispl[np]);
double* r = vector_vec(vdest);
nrnmpi_dbl_alltoallv(s, scnt, sdispl, r, rcnt, rdispl);
delete[] rcnt;
delete[] rdispl;
vdest->resize(rdispl[np]);
nrnmpi_dbl_alltoallv(
vsrc->data(), scnt.data(), sdispl.data(), vdest->data(), rcnt.data(), rdispl.data());
#endif
} else {
vector_resize(vdest, ns);
double* r = vector_vec(vdest);
for (i = 0; i < ns; ++i) {
r[i] = s[i];
}
vdest->resize(ns);
std::copy(vsrc->begin(), vsrc->end(), vdest->begin());
}
delete[] scnt;
delete[] sdispl;
return 0.;
}

Expand Down

0 comments on commit 4a1e15a

Please sign in to comment.