Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite BBS::alltoall with std::vector #2829

Merged
merged 8 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ivoc/ivocvect.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
return vec_;
}

inline double const* data() const {
return vec_.data();

Check warning on line 35 in src/ivoc/ivocvect.h

View check run for this annotation

Codecov / codecov/patch

src/ivoc/ivocvect.h#L34-L35

Added lines #L34 - L35 were not covered by tests
}

inline double* data() {
return vec_.data();
}
Expand All @@ -57,6 +61,14 @@
return vec_.at(index);
}

inline auto begin() const -> std::vector<double>::const_iterator {
return vec_.begin();

Check warning on line 65 in src/ivoc/ivocvect.h

View check run for this annotation

Codecov / codecov/patch

src/ivoc/ivocvect.h#L64-L65

Added lines #L64 - L65 were not covered by tests
}

inline auto end() const -> std::vector<double>::const_iterator {
return vec_.end();

Check warning on line 69 in src/ivoc/ivocvect.h

View check run for this annotation

Codecov / codecov/patch

src/ivoc/ivocvect.h#L68-L69

Added lines #L68 - L69 were not covered by tests
}

inline auto begin() -> std::vector<double>::iterator {
return vec_.begin();
}
Expand Down
17 changes: 11 additions & 6 deletions src/nrniv/bbsavestate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,27 @@ extern int nrn_gid_exists(int gid);

#if NRNMPI
extern void nrnmpi_barrier();
extern void nrnmpi_int_alltoallv(int*, int*, int*, int*, int*, int*);
extern void nrnmpi_dbl_alltoallv(double*, int*, int*, double*, int*, int*);
extern void nrnmpi_int_alltoallv(const int*, const int*, const int*, int*, int*, int*);
extern void nrnmpi_dbl_alltoallv(const double*, const int*, const int*, double*, int*, int*);
extern int nrnmpi_int_allmax(int);
extern void nrnmpi_int_allgather(int* s, int* r, int n);
extern void nrnmpi_int_allgatherv(int* s, int* r, int* n, int* dspl);
extern void nrnmpi_dbl_allgatherv(double* s, double* r, int* n, int* dspl);
#else
static void nrnmpi_barrier() {}
static void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl) {
static void nrnmpi_int_alltoallv(const int* s,
const int* scnt,
const int* sdispl,
int* r,
int* rcnt,
int* rdispl) {
for (int i = 0; i < scnt[0]; ++i) {
r[i] = s[i];
}
}
static void nrnmpi_dbl_alltoallv(double* s,
int* scnt,
int* sdispl,
static void nrnmpi_dbl_alltoallv(const double* s,
const int* scnt,
const int* sdispl,
double* r,
int* rcnt,
int* rdispl) {
Expand Down
2 changes: 2 additions & 0 deletions src/nrnmpi/mkdynam.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ sed -n '
/extern [^v]/s/extern \([a-z*]*\) \(nrnmpi_[a-zA-Z0-9_]*\)\(.*\);/\1 \2\3 {@ return (*p_\2)\3;@}/p
' nrnmpidec.h | tr '@' '\n' | sed '
/p_nrnmpi/ {
s/, const [a-zA-Z0-9_:*&]* /, /g
s/)(const [a-zA-Z0-9_:*&]* /)(/
s/, [a-zA-Z0-9_:*&]* /, /g
s/)([a-zA-Z0-9_:*&]* /)(/
s/const& //g
Expand Down
13 changes: 9 additions & 4 deletions src/nrnmpi/mpispike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,12 @@ extern void nrnmpi_int_alltoall(int* s, int* r, int n) {
MPI_Alltoall(s, n, MPI_INT, r, n, MPI_INT, nrnmpi_comm);
}

extern void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl) {
extern void nrnmpi_int_alltoallv(const int* s,
const int* scnt,
const int* sdispl,
int* r,
int* rcnt,
int* rdispl) {
MPI_Alltoallv(s, scnt, sdispl, MPI_INT, r, rcnt, rdispl, MPI_INT, nrnmpi_comm);
}

Expand All @@ -417,9 +422,9 @@ extern void nrnmpi_long_alltoallv(int64_t* s,
MPI_Alltoallv(s, scnt, sdispl, MPI_INT64_T, r, rcnt, rdispl, MPI_INT64_T, nrnmpi_comm);
}

extern void nrnmpi_dbl_alltoallv(double* s,
int* scnt,
int* sdispl,
extern void nrnmpi_dbl_alltoallv(const double* s,
const int* scnt,
const int* sdispl,
double* r,
int* rcnt,
int* rdispl) {
Expand Down
4 changes: 2 additions & 2 deletions src/nrnmpi/nrnmpidec.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ extern void nrnmpi_int_allgatherv_inplace(int* srcdest, int* n, int* dspl);
extern void nrnmpi_int_allgatherv(int* s, int* r, int* n, int* dspl);
extern void nrnmpi_char_allgatherv(char* s, char* r, int* n, int* dspl);
extern void nrnmpi_int_alltoall(int* s, int* r, int n);
extern void nrnmpi_int_alltoallv(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_int_alltoallv(const int* s, const int* scnt, const int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_int_alltoallv_sparse(int* s, int* scnt, int* sdispl, int* r, int* rcnt, int* rdispl);
extern void nrnmpi_long_allgatherv(int64_t* s, int64_t* r, int* n, int* dspl);
extern void nrnmpi_long_allgatherv_inplace(long* srcdest, int* n, int* dspl);
extern void nrnmpi_long_alltoallv(int64_t* s, int* scnt, int* sdispl, int64_t* r, int* rcnt, int* rdispl);
extern void nrnmpi_long_alltoallv_sparse(int64_t* s, int* scnt, int* sdispl, int64_t* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_allgatherv(double* s, double* r, int* n, int* dspl);
extern void nrnmpi_dbl_allgatherv_inplace(double* srcdest, int* n, int* dspl);
extern void nrnmpi_dbl_alltoallv(double* s, int* scnt, int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_alltoallv(const double* s, const int* scnt, const int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_alltoallv_sparse(double* s, int* scnt, int* sdispl, double* r, int* rcnt, int* rdispl);
extern void nrnmpi_char_alltoallv(char* s, int* scnt, int* sdispl, char* r, int* rcnt, int* rdispl);
extern void nrnmpi_dbl_broadcast(double* buf, int cnt, int root);
Expand Down
77 changes: 31 additions & 46 deletions src/parallel/ocbbs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,18 @@
extern int nrncore_is_file_mode();
extern int nrncore_psolve(double tstop, int file_mode);

class OcBBS: public BBS, public Resource {
class OcBBS final: public BBS, public Resource {
public:
OcBBS(int nhost_request);
virtual ~OcBBS();

public:
double retval_;
int userid_;
int next_local_;
double retval_ = 0.;
int userid_ = 0;
int next_local_ = 0;
};

OcBBS::OcBBS(int n)
: BBS(n) {
next_local_ = 0;
}

OcBBS::~OcBBS() {}
: BBS(n) {}

static bool posting_ = false;
static void pack_help(int, OcBBS*);
Expand Down Expand Up @@ -790,57 +785,47 @@
return 0.;
}

// This function takes 3 arguments:
// - vsrc (In)
// - vscnt (In)
// - vdest (Out)
static double alltoall(void*) {
int i, ns, np = nrnmpi_numprocs;
Vect* vsrc = vector_arg(1);
Vect* vscnt = vector_arg(2);
ns = vector_capacity(vsrc);
double* s = vector_vec(vsrc);
if (vector_capacity(vscnt) != np) {
int np = nrnmpi_numprocs;
const Vect* vsrc = vector_arg(1);
const Vect* vscnt = vector_arg(2);
Vect* vdest = vector_arg(3);
std::size_t ns = vsrc->size();
if (vscnt->size() != np) {

Check warning on line 798 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L793-L798

Added lines #L793 - L798 were not covered by tests
hoc_execerror("size of source counts vector is not nhost", nullptr);
}
double* x = vector_vec(vscnt);
int* scnt = new int[np];
int* sdispl = new int[np + 1];
sdispl[0] = 0;
for (i = 0; i < np; ++i) {
scnt[i] = int(x[i]);
const std::vector<int> scnt(vscnt->begin(), vscnt->end()); // cast from double to int
std::vector<int> sdispl(np + 1);
for (int i = 0; i < np; ++i) {

Check warning on line 803 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L801-L803

Added lines #L801 - L803 were not covered by tests
sdispl[i + 1] = sdispl[i] + scnt[i];
}
if (ns != sdispl[np]) {
hoc_execerror("sum of source counts is not the size of the src vector", nullptr);
}
Vect* vdest = vector_arg(3);
if (nrnmpi_numprocs > 1) {
#if NRNMPI
int* rcnt = new int[np];
int* rdispl = new int[np + 1];
int* c = new int[np];
rdispl[0] = 0;
for (i = 0; i < np; ++i) {
c[i] = 1;
rdispl[i + 1] = i + 1;
}
nrnmpi_int_alltoallv(scnt, c, rdispl, rcnt, c, rdispl);
delete[] c;
for (i = 0; i < np; ++i) {
std::vector<int> rcnt(np);
std::vector<int> c(np, 1);
std::vector<int> rdispl(np + 1);
std::iota(rdispl.begin(), rdispl.end(), 0);

Check warning on line 814 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L811-L814

Added lines #L811 - L814 were not covered by tests

nrnmpi_int_alltoallv(
scnt.data(), c.data(), rdispl.data(), rcnt.data(), c.data(), rdispl.data());
for (int i = 0; i < np; ++i) {

Check warning on line 818 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L816-L818

Added lines #L816 - L818 were not covered by tests
rdispl[i + 1] = rdispl[i] + rcnt[i];
}
vector_resize(vdest, rdispl[np]);
double* r = vector_vec(vdest);
nrnmpi_dbl_alltoallv(s, scnt, sdispl, r, rcnt, rdispl);
delete[] rcnt;
delete[] rdispl;
vdest->resize(rdispl[np]);
nrnmpi_dbl_alltoallv(
vsrc->data(), scnt.data(), sdispl.data(), vdest->data(), rcnt.data(), rdispl.data());

Check warning on line 823 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L821-L823

Added lines #L821 - L823 were not covered by tests
#endif
} else {
vector_resize(vdest, ns);
double* r = vector_vec(vdest);
for (i = 0; i < ns; ++i) {
r[i] = s[i];
}
vdest->resize(ns);
std::copy(vsrc->begin(), vsrc->end(), vdest->begin());

Check warning on line 827 in src/parallel/ocbbs.cpp

View check run for this annotation

Codecov / codecov/patch

src/parallel/ocbbs.cpp#L826-L827

Added lines #L826 - L827 were not covered by tests
}
delete[] scnt;
delete[] sdispl;
return 0.;
}

Expand Down