Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class to enable gather-scatter type operator with gslib #4213

Merged
merged 10 commits into from
May 17, 2024
72 changes: 72 additions & 0 deletions fem/gslib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,78 @@ void OversetFindPointsGSLIB::Interpolate(const Vector &point_pos,
Interpolate(field_in, field_out);
}

GSOPGSLIB::GSOPGSLIB(Array<long long> &ids)
{
gsl_comm = new gslib::comm;
cr = new gslib::crystal;
#ifdef MFEM_USE_MPI
int initialized;
MPI_Initialized(&initialized);
if (!initialized) { MPI_Init(NULL, NULL); }
MPI_Comm comm = MPI_COMM_WORLD;
comm_init(gsl_comm, comm);
#else
comm_init(gsl_comm, 0);
#endif
crystal_init(cr, gsl_comm);
UpdateIdentifiers(ids);
}

#ifdef MFEM_USE_MPI
GSOPGSLIB::GSOPGSLIB(MPI_Comm comm_, Array<long long> &ids)
: cr(NULL), gsl_comm(NULL)
{
gsl_comm = new gslib::comm;
cr = new gslib::crystal;
comm_init(gsl_comm, comm_);
crystal_init(cr, gsl_comm);
UpdateIdentifiers(ids);
}
#endif

GSOPGSLIB::~GSOPGSLIB()
{
crystal_free(cr);
gslib_gs_free(gsl_data);
comm_free(gsl_comm);
delete gsl_comm;
delete cr;
}

void GSOPGSLIB::UpdateIdentifiers(Array<long long> &ids)
{
if (gsl_data != NULL) { gslib_gs_free(gsl_data); }
num_ids = ids.Size();
gsl_data = gslib_gs_setup(ids.GetData(),
ids.Size(),
gsl_comm, 0,
gslib::gs_crystal_router, 0);
}

void GSOPGSLIB::GS(Vector &senddata, GSOp op)
{
MFEM_VERIFY(senddata.Size() == num_ids,"Incompatible setup and GOP operation.");
if (op == GSOp::ADD)
{
gslib_gs(senddata.GetData(), gslib::gs_double, gslib::gs_add, 0, gsl_data, 0);
}
else if (op == GSOp::MUL)
{
gslib_gs(senddata.GetData(), gslib::gs_double, gslib::gs_mul, 0, gsl_data, 0);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
}
else if (op == GSOp::MAX)
{
gslib_gs(senddata.GetData(), gslib::gs_double, gslib::gs_max, 0, gsl_data, 0);
}
else if (op == GSOp::MIN)
{
gslib_gs(senddata.GetData(), gslib::gs_double, gslib::gs_min, 0, gsl_data, 0);
}
else
{
MFEM_ABORT("Invalid GSOp operation.");
}
}

} // namespace mfem

Expand Down
59 changes: 58 additions & 1 deletion fem/gslib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ struct comm;
struct findpts_data_2;
struct findpts_data_3;
struct crystal;
struct gs_data;
}

namespace mfem
{

/** \brief FindPointsGSLIB can robustly evaluate a GridFunction on an arbitrary
* collection of points. There are three key functions in FindPointsGSLIB:
* collection of points.
*
* There are three key functions in FindPointsGSLIB:
*
* 1. Setup - constructs the internal data structures of gslib.
*
Expand Down Expand Up @@ -290,6 +293,60 @@ class OversetFindPointsGSLIB : public FindPointsGSLIB
using FindPointsGSLIB::Interpolate;
};

/** \brief Class for gather-scatter (gs) operations on Vectors based on
corresponding global identifiers.
This functionality is useful for gs-ops on
DOF values across processor boundary, where the global identifier would be
the corresponding true DOF index. Operations currently supported are
min, max, sum, and multiplication. Note: identifier 0 does not participate
in the gather-scatter operation.
For example, consider a vector, v:
- v = [0.3, 0.4, 0.25] on rank1,
- v = [0.6, 0.1] on rank 2,
- v = [-0.2, 0.3, 0.7, 0.] on rank 3.
Consider a corresponding Array<int>, a:
- a = [1, 2, 3] on rank 1,
- a = [3, 2] on rank 2,
- a = [1, 2, 0, 3] on rank 3.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe mention that a can contain repeated indices on a given rank.

A gather-scatter "minimum" operation, done as follows:
GSOPGSLIB gs = GSOPGSLIB(MPI_COMM_WORLD, a);
gs.GS(v, GSOp::MIN);
would return into v:
- v = [-0.2, 0.1, 0.] on rank 1,
- v = [0., 0.1] on rank 2,
- v = [-0.2, 0.1, 0.7, 0] on rank 3,
where the values have been compared across all processors based on the
integer identifier. */
class GSOPGSLIB
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
{
protected:
struct gslib::crystal *cr; // gslib's internal data
struct gslib::comm *gsl_comm; // gslib's internal data
struct gslib::gs_data *gsl_data = NULL;
int num_ids;

public:
GSOPGSLIB(Array<long long> &ids);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved

#ifdef MFEM_USE_MPI
GSOPGSLIB(MPI_Comm comm_, Array<long long> &ids);
#endif

virtual ~GSOPGSLIB();

/// Supported operation types. See class description.
enum GSOp {ADD, MUL, MIN, MAX};

/// Update the identifiers used for the gather-scatter operator.
/// Same @a ids get grouped together and id == 0 does not participate.
/// See class description.
virtual void UpdateIdentifiers(Array<long long> &ids);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
vladotomov marked this conversation as resolved.
Show resolved Hide resolved

/// Gather-Scatter operation on senddata. Must match length of unique
/// identifiers used in the constructor. See class description.
virtual void GS(Vector &senddata, GSOp op);
};

} // namespace mfem

#endif // MFEM_USE_GSLIB
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/fem/test_gslib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,71 @@ TEST_CASE("GSLIBInterpolateL2ElementBoundary",
delete c_fec;
}

#ifdef MFEM_USE_MPI
TEST_CASE("GSLIBGSOP", "[GSLIBGSOP][Parallel][GSLIB]")
{
int myid;
MPI_Comm_rank(MPI_COMM_WORLD, &myid);

int nlen = 10;
Array<long long> ids(nlen);
Vector vals(nlen);
vladotomov marked this conversation as resolved.
Show resolved Hide resolved
vals.Randomize(myid+1);

// Force minimum values based on the identifier for deterministic behavior
// on rank 0 and randomize the identifier on other ranks.
if (myid == 0)
{
for (int i = 0; i < nlen; i++)
{
ids[i] = i+1;
vals(i) = -ids[i];
}
}
else
{
for (int i = 0; i < nlen; i++)
{
int num = rand() % nlen + 1;
ids[i] = num;
}
}

// Test GSOp::MIN
GSOPGSLIB gs = GSOPGSLIB(MPI_COMM_WORLD, ids);
gs.GS(vals, GSOPGSLIB::GSOp::MIN);

// Check for minimum value
for (int i = 0; i < nlen; i++)
{
int id = ids[i];
REQUIRE(vals(i) == -1.0*id);
}

// Test GSOp::ADD
// Set all values to 0 except on rank 0, and then add them.
if (myid != 0) { vals = 0.0; }
gs.GS(vals, GSOPGSLIB::GSOp::ADD);

// Check for added value to match what was originally set on rank 0.
for (int i = 0; i < nlen; i++)
{
int id = ids[i];
REQUIRE(vals(i) == -1.0*id);
}

// Test GSOp::MUL
// Randomize values on all ranks except rank 0 such that they are positive.
if (myid != 0) { vals.Randomize(); }
gs.GS(vals, GSOPGSLIB::GSOp::MUL);

// Check for multipled values to be negative
for (int i = 0; i < nlen; i++)
{
REQUIRE(vals(i) < 0);
}
}
#endif

} //namespace_gslib
#endif