Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class to enable gather-scatter type operator with gslib #4213

Merged
merged 10 commits into from
May 17, 2024
2 changes: 2 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ Miscellaneous

- Various other simplifications, extensions, and bugfixes in the code.

- Added GSLIB-based gather-scatter operator.


Version 4.6, released on September 27, 2023
===========================================
Expand Down
77 changes: 77 additions & 0 deletions fem/gslib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,83 @@ void OversetFindPointsGSLIB::Interpolate(const Vector &point_pos,
Interpolate(field_in, field_out);
}

GSOPGSLIB::GSOPGSLIB(Array<long long> &ids)
{
gsl_comm = new gslib::comm;
cr = new gslib::crystal;
#ifdef MFEM_USE_MPI
int initialized;
MPI_Initialized(&initialized);
if (!initialized) { MPI_Init(NULL, NULL); }
MPI_Comm comm = MPI_COMM_WORLD;
comm_init(gsl_comm, comm);
#else
comm_init(gsl_comm, 0);
#endif
crystal_init(cr, gsl_comm);
UpdateIdentifiers(ids);
}

#ifdef MFEM_USE_MPI
GSOPGSLIB::GSOPGSLIB(MPI_Comm comm_, Array<long long> &ids)
: cr(NULL), gsl_comm(NULL)
{
gsl_comm = new gslib::comm;
cr = new gslib::crystal;
comm_init(gsl_comm, comm_);
crystal_init(cr, gsl_comm);
UpdateIdentifiers(ids);
}
#endif

GSOPGSLIB::~GSOPGSLIB()
{
crystal_free(cr);
gslib_gs_free(gsl_data);
comm_free(gsl_comm);
delete gsl_comm;
delete cr;
}

void GSOPGSLIB::UpdateIdentifiers(const Array<long long> &ids)
{
long long minval = ids.Min();
MPI_Allreduce(MPI_IN_PLACE, &minval, 1, MPI_LONG_LONG_INT,
MPI_MIN, gsl_comm->c);
MFEM_VERIFY(minval >= 0, "Unique identifier cannot be negative.");
if (gsl_data != NULL) { gslib_gs_free(gsl_data); }
num_ids = ids.Size();
gsl_data = gslib_gs_setup(ids.GetData(),
ids.Size(),
gsl_comm, 0,
gslib::gs_crystal_router, 0);
}

void GSOPGSLIB::GS(Vector &senddata, GSOp op)
{
MFEM_VERIFY(senddata.Size() == num_ids,
"Incompatible setup and GOP operation.");
if (op == GSOp::ADD)
{
gslib_gs(senddata.GetData(),gslib::gs_double,gslib::gs_add,0,gsl_data,0);
}
else if (op == GSOp::MUL)
{
gslib_gs(senddata.GetData(),gslib::gs_double,gslib::gs_mul,0,gsl_data,0);
}
else if (op == GSOp::MAX)
{
gslib_gs(senddata.GetData(),gslib::gs_double,gslib::gs_max,0,gsl_data,0);
}
else if (op == GSOp::MIN)
{
gslib_gs(senddata.GetData(),gslib::gs_double,gslib::gs_min,0,gsl_data,0);
}
else
{
MFEM_ABORT("Invalid GSOp operation.");
}
}

} // namespace mfem

Expand Down
63 changes: 62 additions & 1 deletion fem/gslib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ struct comm;
struct findpts_data_2;
struct findpts_data_3;
struct crystal;
struct gs_data;
}

namespace mfem
{

/** \brief FindPointsGSLIB can robustly evaluate a GridFunction on an arbitrary
* collection of points. There are three key functions in FindPointsGSLIB:
* collection of points.
*
* There are three key functions in FindPointsGSLIB:
*
* 1. Setup - constructs the internal data structures of gslib.
*
Expand Down Expand Up @@ -226,6 +229,7 @@ class FindPointsGSLIB

/** \brief OversetFindPointsGSLIB enables use of findpts for arbitrary number of
overlapping grids.

The parameters in this class are the same as FindPointsGSLIB with the
difference of additional inputs required to account for more than 1 mesh. */
class OversetFindPointsGSLIB : public FindPointsGSLIB
Expand Down Expand Up @@ -290,6 +294,63 @@ class OversetFindPointsGSLIB : public FindPointsGSLIB
using FindPointsGSLIB::Interpolate;
};

/** \brief Class for gather-scatter (gs) operations on Vectors based on
corresponding global identifiers.

This functionality is useful for gs-ops on DOF values across processor
boundary, where the global identifier would be the corresponding true DOF
index. Operations currently supported are min, max, sum, and multiplication.
Note: identifier 0 does not participate in the gather-scatter operation and
a given identifier can be included multiple times on a given rank.
For example, consider a vector, v:
- v = [0.3, 0.4, 0.25, 0.7] on rank1,
- v = [0.6, 0.1] on rank 2,
- v = [-0.2, 0.3, 0.7, 0.] on rank 3.

Consider a corresponding Array<int>, a:
- a = [1, 2, 3, 1] on rank 1,
- a = [3, 2] on rank 2,
- a = [1, 2, 0, 3] on rank 3.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe mention that a can contain repeated indices on a given rank.


A gather-scatter "minimum" operation, done as follows:
GSOPGSLIB gs = GSOPGSLIB(MPI_COMM_WORLD, a);
gs.GS(v, GSOp::MIN);
would return into v:
- v = [-0.2, 0.1, 0., -0.2] on rank 1,
- v = [0., 0.1] on rank 2,
- v = [-0.2, 0.1, 0.7, 0.] on rank 3,
where the values have been compared across all processors based on the
integer identifier. */
class GSOPGSLIB
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
{
protected:
struct gslib::crystal *cr; // gslib's internal data
struct gslib::comm *gsl_comm; // gslib's internal data
struct gslib::gs_data *gsl_data = NULL;
int num_ids;

public:
GSOPGSLIB(Array<long long> &ids);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved

#ifdef MFEM_USE_MPI
GSOPGSLIB(MPI_Comm comm_, Array<long long> &ids);
#endif

virtual ~GSOPGSLIB();

/// Supported operation types. See class description.
enum GSOp {ADD, MUL, MIN, MAX};

/// Update the identifiers used for the gather-scatter operator.
/// Same @a ids get grouped together and id == 0 does not participate.
/// See class description.
void UpdateIdentifiers(const Array<long long> &ids);

/// Gather-Scatter operation on senddata. Must match length of unique
/// identifiers used in the constructor. See class description.
void GS(Vector &senddata, GSOp op);
};

} // namespace mfem

#endif // MFEM_USE_GSLIB
Expand Down
68 changes: 68 additions & 0 deletions tests/unit/fem/test_gslib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,73 @@ TEST_CASE("GSLIBInterpolateL2ElementBoundary",
delete c_fec;
}

#ifdef MFEM_USE_MPI
TEST_CASE("GSLIBGSOP", "[GSLIBGSOP][Parallel][GSLIB]")
{
int myid;
MPI_Comm_rank(MPI_COMM_WORLD, &myid);

int nlen = 5 + rand() % 1000;
MPI_Allreduce(MPI_IN_PLACE, &nlen, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

Array<long long> ids(nlen);
Vector vals(nlen);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
vals.Randomize(myid+1);

// Force minimum values based on the identifier for deterministic behavior
// on rank 0 and randomize the identifier on other ranks.
if (myid == 0)
{
for (int i = 0; i < nlen; i++)
{
ids[i] = i+1;
vals(i) = -ids[i];
}
}
else
{
for (int i = 0; i < nlen; i++)
{
int num = rand() % nlen + 1;
ids[i] = num;
}
}

// Test GSOp::MIN
GSOPGSLIB gs = GSOPGSLIB(MPI_COMM_WORLD, ids);
gs.GS(vals, GSOPGSLIB::GSOp::MIN);

// Check for minimum value
for (int i = 0; i < nlen; i++)
{
int id = ids[i];
REQUIRE(vals(i) == -1.0*id);
}

// Test GSOp::ADD
// Set all values to 0 except on rank 0, and then add them.
if (myid != 0) { vals = 0.0; }
gs.GS(vals, GSOPGSLIB::GSOp::ADD);

// Check for added value to match what was originally set on rank 0.
for (int i = 0; i < nlen; i++)
{
int id = ids[i];
REQUIRE(vals(i) == -1.0*id);
}

// Test GSOp::MUL
// Randomize values on all ranks except rank 0 such that they are positive.
if (myid != 0) { vals.Randomize(); }
gs.GS(vals, GSOPGSLIB::GSOp::MUL);

// Check for multipled values to be negative
for (int i = 0; i < nlen; i++)
{
REQUIRE(vals(i) < 0);
}
}
#endif

} //namespace_gslib
#endif