Skip to content

Commit

Permalink
Merge pull request #4016 from mfem/periodic-mesh-kdtree
Browse files Browse the repository at this point in the history
Faster periodic meshes with k-d trees
  • Loading branch information
v-dobrev committed Dec 27, 2023
2 parents 79b6098 + 7c9abd5 commit 148047f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 100 deletions.
120 changes: 69 additions & 51 deletions general/kdtree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace KDTreeNorms
template <typename Tfloat, int ndim>
struct Norm_l1
{
Tfloat operator() (const Tfloat* xx)
Tfloat operator()(const Tfloat* xx) const
{
Tfloat tm=abs(xx[0]);
for (int i=1; i<ndim; i++)
Expand All @@ -45,7 +45,7 @@ struct Norm_l1
template<typename Tfloat,int ndim>
struct Norm_l2
{
Tfloat operator() (const Tfloat* xx)
Tfloat operator()(const Tfloat* xx) const
{
Tfloat tm;
tm=xx[0]*xx[0];
Expand All @@ -61,7 +61,7 @@ struct Norm_l2
template<typename Tfloat,int ndim>
struct Norm_li
{
Tfloat operator() (const Tfloat* xx)
Tfloat operator()(const Tfloat* xx) const
{
Tfloat tm;
if (xx[0]<Tfloat(0.0)) { tm=-xx[0];}
Expand All @@ -83,6 +83,23 @@ struct Norm_li

}

/// @brief Abstract base class for KDTree. Can be used when the dimension of the
/// space is known dynamically.
template <typename Tindex, typename Tfloat>
class KDTreeBase
{
public:
/// Adds a point to the tree. See KDTree::AddPoint().
virtual void AddPoint(const Tfloat *xx, Tindex ii) = 0;
/// @brief Sorts the tree. Should be performed after adding points and before
/// performing queries. See KDTree::Sort().
virtual void Sort() = 0;
/// Returns the index of the closest point to @a xx.
virtual Tindex FindClosestPoint(const Tfloat *xx) const = 0;
/// Virtual destructor.
virtual ~KDTreeBase() { }
};

/// Template class for build KDTree with template parameters Tindex
/// specifying the type utilized for indexing the points, Tfloat
/// specifying a float type for representing the coordinates of the
Expand All @@ -96,19 +113,20 @@ struct Norm_li
/// computed with n or 1 rank(s).
template <typename Tindex, typename Tfloat, size_t ndim=3,
typename Tnorm=KDTreeNorms::Norm_l2<Tfloat,ndim> >
class KDTree
class KDTree : public KDTreeBase<Tindex, Tfloat>
{
public:

/// Structure defining a geometric point in the ndim-dimensional
/// space. The coordinate type (Tfloat) can be any floating or
/// integer type. It can be even a character if necessary. For
/// such types users should redefine the norms.
/// Structure defining a geometric point in the ndim-dimensional space. The
/// coordinate type (Tfloat) can be any floating or integer type. It can be
/// even a character if necessary. For such types users should redefine the
/// norms.
struct PointND
{
/// Geometric point constructor
PointND() { std::fill(xx,xx+ndim,Tfloat(0.0));}

/// Default constructor: fill with zeros
PointND() { std::fill(xx,xx+ndim,Tfloat(0.0)); }
/// Copy coordinates from pointer/array @a xx_
PointND(const Tfloat *xx_) { std::copy(xx_,xx_+ndim,xx); }
/// Coordinates of the point
Tfloat xx[ndim];
};
Expand All @@ -118,16 +136,19 @@ class KDTree
{
/// Defines a point in the ndim-dimensional space
PointND pt;

/// Defines the attached index
Tindex ind;
Tindex ind = 0;
/// Default constructor: fill with zeros
NodeND() = default;
/// Create from given point and index
NodeND(PointND pt_, Tindex ind_ = 0) : pt(pt_), ind(ind_) { }
};

/// Default constructor
KDTree() = default;

/// Returns the spatial dimension of the points
int SpaceDimension()
int SpaceDimension() const
{
return ndim;
}
Expand All @@ -148,7 +169,7 @@ class KDTree
}

/// Returns the size of the point cloud
size_t size()
size_t size() const
{
return data.size();
}
Expand All @@ -161,34 +182,25 @@ class KDTree

/// Builds the KDTree. If the point cloud is modified the tree
/// needs to be rebuild by a new call to Sort().
void Sort()
void Sort() override
{
SortInPlace(data.begin(),data.end(),0);
}

/// Adds a new node to the point cloud
void AddPoint(PointND& pt, Tindex ii)
void AddPoint(const PointND &pt, Tindex ii)
{
NodeND nd;
nd.pt=pt;
nd.ind=ii;
data.push_back(nd);
data.emplace_back(pt, ii);
}

/// Adds a new node by coordinates and an associated index
void AddPoint(Tfloat* xx,Tindex ii)
void AddPoint(const Tfloat *xx,Tindex ii) override
{
NodeND nd;
for (size_t i=0; i<ndim; i++)
{
nd.pt.xx[i]=xx[i];
}
nd.ind=ii;
data.push_back(nd);
data.emplace_back(xx, ii);
}

/// Finds the nearest neighbour index
Tindex FindClosestPoint(PointND& pt)
Tindex FindClosestPoint(const PointND &pt) const
{
PointS best_candidate;
best_candidate.sp=pt;
Expand All @@ -200,8 +212,13 @@ class KDTree
return data[best_candidate.pos].ind;
}

/// Finds the nearest neighbour index and return the clossest poitn in clp
Tindex FindClosestPoint(PointND& pt, PointND& clp)
Tindex FindClosestPoint(const Tfloat *xx) const override
{
return FindClosestPoint(PointND(xx));
}

/// Finds the nearest neighbour index and return the clossest point in clp
Tindex FindClosestPoint(const PointND &pt, const PointND &clp) const
{
PointS best_candidate;
best_candidate.sp=pt;
Expand All @@ -216,15 +233,15 @@ class KDTree
}

/// Returns the closest point and the distance to the input point pt.
void FindClosestPoint(PointND& pt, Tindex& ind, Tfloat& dist)
void FindClosestPoint(const PointND &pt, Tindex &ind, Tfloat &dist) const
{
PointND clp;
FindClosestPoint(pt,ind,dist,clp);

}

/// Returns the closest point and the distance to the input point pt.
void FindClosestPoint(PointND& pt, Tindex& ind, Tfloat& dist, PointND& clp)
void FindClosestPoint(const PointND &pt, Tindex &ind, Tfloat &dist,
PointND &clp) const
{
PointS best_candidate;
best_candidate.sp=pt;
Expand All @@ -241,7 +258,7 @@ class KDTree


/// Brute force search - please, use it only for debuging purposes
void FindClosestPointSlow(PointND& pt, Tindex& ind, Tfloat& dist)
void FindClosestPointSlow(const PointND &pt, Tindex &ind, Tfloat &dist) const
{
PointS best_candidate;
best_candidate.sp=pt;
Expand All @@ -265,22 +282,23 @@ class KDTree

/// Finds all points within a distance R from point pt. The indices are
/// returned in the vector res and the correponding distances in vector dist.
void FindNeighborPoints(PointND& pt,Tfloat R, std::vector<Tindex> & res,
void FindNeighborPoints(const PointND &pt,Tfloat R, std::vector<Tindex> & res,
std::vector<Tfloat> & dist)
{
FindNeighborPoints(pt,R,data.begin(),data.end(),0,res,dist);
}

/// Finds all points within a distance R from point pt. The indices are
/// returned in the vector res and the correponding distances in vector dist.
void FindNeighborPoints(PointND& pt,Tfloat R, std::vector<Tindex> & res)
void FindNeighborPoints(const PointND &pt,Tfloat R, std::vector<Tindex> & res)
{
FindNeighborPoints(pt,R,data.begin(),data.end(),0,res);
}

/// Brute force search - please, use it only for debuging purposes
void FindNeighborPointsSlow(PointND& pt,Tfloat R, std::vector<Tindex> & res,
std::vector<Tfloat> & dist)
void FindNeighborPointsSlow(const PointND &pt,Tfloat R,
std::vector<Tindex> &res,
std::vector<Tfloat> &dist)
{
Tfloat dd;
for (auto iti=data.begin(); iti!=data.end(); iti++)
Expand All @@ -295,7 +313,8 @@ class KDTree
}

/// Brute force search - please, use it only for debuging purposes
void FindNeighborPointsSlow(PointND& pt,Tfloat R, std::vector<Tindex> & res)
void FindNeighborPointsSlow(const PointND &pt,Tfloat R,
std::vector<Tindex> &res)
{
Tfloat dd;
for (auto iti=data.begin(); iti!=data.end(); iti++)
Expand Down Expand Up @@ -333,12 +352,11 @@ class KDTree
}
};

/// Point for storing tmp data
PointND tp;
mutable PointND tp; ///< Point for storing tmp data
Tnorm fnorm;

/// Computes the distance between two nodes
Tfloat Dist(const PointND& pt1,const PointND& pt2)
Tfloat Dist(const PointND &pt1, const PointND &pt2) const
{
for (size_t i=0; i<ndim; i++)
{
Expand Down Expand Up @@ -388,13 +406,13 @@ class KDTree

/// Finds the closest point to bc.sp in the point cloud
/// bounded between [itb,ite).
void PSearch(typename std::vector<NodeND>::iterator itb,
typename std::vector<NodeND>::iterator ite,
size_t level, PointS& bc)
void PSearch(typename std::vector<NodeND>::const_iterator itb,
typename std::vector<NodeND>::const_iterator ite,
size_t level, PointS& bc) const
{
std::uint8_t dim=(std::uint8_t) (level%ndim);
size_t siz=ite-itb;
typename std::vector<NodeND>::iterator mtb=itb+siz/2;
typename std::vector<NodeND>::const_iterator mtb=itb+siz/2;
if (siz>2)
{
// median is at itb+siz/2
Expand Down Expand Up @@ -469,7 +487,7 @@ class KDTree
typename std::vector<NodeND>::iterator itb,
typename std::vector<NodeND>::iterator ite,
size_t level,
std::vector< std::tuple<Tfloat,Tindex> > & res)
std::vector< std::tuple<Tfloat,Tindex> > & res) const
{
std::uint8_t dim=(std::uint8_t) (level%ndim);
size_t siz=ite-itb;
Expand Down Expand Up @@ -521,7 +539,7 @@ class KDTree
typename std::vector<NodeND>::iterator itb,
typename std::vector<NodeND>::iterator ite,
size_t level,
std::vector<Tindex> & res)
std::vector<Tindex> & res) const
{
std::uint8_t dim=(std::uint8_t) (level%ndim);
size_t siz=ite-itb;
Expand Down Expand Up @@ -570,7 +588,7 @@ class KDTree
typename std::vector<NodeND>::iterator itb,
typename std::vector<NodeND>::iterator ite,
size_t level,
std::vector<Tindex> & res, std::vector<Tfloat> & dist)
std::vector<Tindex> & res, std::vector<Tfloat> & dist) const
{
std::uint8_t dim=(std::uint8_t) (level%ndim);
size_t siz=ite-itb;
Expand Down

0 comments on commit 148047f

Please sign in to comment.