
Merge pull request #519 from jc-healy/master
dbscan cluster extraction
jc-healy committed Jan 12, 2022
2 parents 51050de + 249dddc commit 2179c24
Showing 4 changed files with 185 additions and 13 deletions.
126 changes: 126 additions & 0 deletions docs/dbscan_from_hdbscan.rst
@@ -0,0 +1,126 @@

Extracting DBSCAN* clustering from HDBSCAN*
===========================================

There are a number of reasons that one might prefer `DBSCAN <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html>`__'s
clustering over that of HDBSCAN*. The biggest difficulty many folks have with
DBSCAN is that the epsilon distance parameter can be hard to determine and often
requires a great deal of trial and error to tune. If your data lives in a more
interpretable space and you have a good notion of distance in that space, this
problem is certainly mitigated, and a user might want to set a very specific epsilon
distance for their use case. Another viable use case is that a user is interested in
constant density clustering.
HDBSCAN* does variable density clustering by default, looking for the clusters that persist
over a wide range of epsilon distance parameters to find a 'natural' clustering. This might
not be the right result for your application. A DBSCAN clustering at a particular
epsilon value might work better for your particular task.

HDBSCAN returns a very natural clustering of your data, which is often very useful when exploring
a new data set. That doesn't necessarily make it the right clustering algorithm for every
task.

HDBSCAN* can best be thought of as a DBSCAN* implementation which varies across
all epsilon values and extracts the clusters that persist over the widest range
of these parameter choices. It is therefore able to ignore the epsilon parameter and
only needs the minimum cluster size as a single input parameter.
The 'eom' (Excess of Mass) cluster selection method then returns clusters with the
best stability over epsilon.
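
For example, a minimal sketch of a default ('eom') run, assuming ``X`` is your data
array (the ``min_cluster_size`` value here is purely illustrative):

.. code:: python

    import hdbscan

    clusterer = hdbscan.HDBSCAN(min_cluster_size=15).fit(X)
    labels = clusterer.labels_  # -1 marks noise points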

There are a number of alternative ways of extracting a flat clustering from
the HDBSCAN* hierarchical tree. If one is interested in finer resolution
clusters while still maintaining variable density, one could set
``cluster_selection_method='leaf'`` to extract the leaves of the condensed
tree instead of the most persistent clusters. For more details on these
cluster selection methods see :ref:`leaf_clustering_label`.
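
For instance, a minimal sketch of leaf extraction (again with an illustrative
``min_cluster_size``):

.. code:: python

    import hdbscan

    leaf_clusterer = hdbscan.HDBSCAN(min_cluster_size=15,
                                     cluster_selection_method='leaf').fit(X)
    leaf_labels = leaf_clusterer.labels_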

If one isn't interested in the variable density clustering that is the hallmark of
HDBSCAN*, it is relatively easy to extract any DBSCAN* clustering from a
single run of HDBSCAN*. This has the advantage of allowing you to perform
a single computationally efficient HDBSCAN* run and then quickly search over
the DBSCAN* parameter space by extracting clustering results from our
pre-constructed tree. This can save significant computational time when
searching across multiple cluster parameter settings on large amounts of data.
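
As a minimal sketch of such a search, using the ``dbscan_clustering`` method
introduced below (the epsilon grid and parameter values are illustrative):

.. code:: python

    import hdbscan

    clusterer = hdbscan.HDBSCAN(min_samples=5).fit(X)
    # One HDBSCAN* fit, then many cheap DBSCAN* extractions from the same tree
    sweep = {
        eps: clusterer.dbscan_clustering(cut_distance=eps, min_cluster_size=5)
        for eps in (0.1, 0.2, 0.5, 1.0)
    }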

Alternatively, one could make use of ``cluster_selection_epsilon`` as a
post-processing step with any ``cluster_selection_method`` in order to
return a hybrid clustering of DBSCAN* and HDBSCAN*. For more details on
this see :doc:`how_to_use_epsilon`.
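
A minimal sketch of that hybrid approach, with an illustrative epsilon value:

.. code:: python

    import hdbscan

    hybrid_clusterer = hdbscan.HDBSCAN(min_cluster_size=15,
                                       cluster_selection_epsilon=0.2).fit(X)
    hybrid_labels = hybrid_clusterer.labels_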

In order to extract a DBSCAN* clustering from an HDBSCAN* run we must first train
an HDBSCAN* model on our data.

.. code:: python

    import hdbscan

    h_cluster = hdbscan.HDBSCAN(min_samples=5, match_reference_implementation=True).fit(X)

The ``min_cluster_size`` parameter is unimportant in this case in that it is
only used in the creation of our condensed tree which we won't be using here.
Now we choose a ``cut_distance`` which is just another name for the epsilon
threshold in DBSCAN and will be passed to our
:py:meth:`~hdbscan.hdbscan_.dbscan_clustering` method.

.. code:: python

    import seaborn as sns

    eps = 0.2
    labels = h_cluster.dbscan_clustering(cut_distance=eps, min_cluster_size=5)
    sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels.astype(str));

.. image:: images/dbscan_from_hdbscan_clustering.png
:align: center

It should be noted that a DBSCAN* clustering extracted from our HDBSCAN* tree will
not precisely match the clustering results from sklearn's DBSCAN implementation.
Our clustering results should better match DBSCAN* (which can be thought of as
DBSCAN without the border points). As such, when comparing the two results, one
should expect them to differ mostly in the points that DBSCAN considers border
points. We'll deal with this by comparing the two clusterings only on the points
identified by DBSCAN as core points. We can see below that the differences between
these two clusterings mostly occur at the boundaries of the clusters. This matches
our intuition of stability within the core points.

.. image:: images/dbscan_from_hdbscan_comparision.png
:align: center

For a slightly more empirical comparison we make use of the `adjusted rand score <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html>`__
to compare the clustering of the core points between a DBSCAN clustering from sklearn and
a DBSCAN* clustering extracted from our HDBSCAN* object.
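
A minimal sketch of such a comparison, restricted to sklearn's core points (the
epsilon value is illustrative and ``h_cluster`` is the model fit above):

.. code:: python

    from sklearn.cluster import DBSCAN
    from sklearn.metrics import adjusted_rand_score

    eps = 0.2
    sk_model = DBSCAN(eps=eps, min_samples=5).fit(X)
    hdb_labels = h_cluster.dbscan_clustering(cut_distance=eps, min_cluster_size=5)

    # Compare only on the points sklearn identifies as core points, where the
    # two algorithms should agree up to implementation details.
    core = sk_model.core_sample_indices_
    score = adjusted_rand_score(sk_model.labels_[core], hdb_labels[core])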

.. image:: images/dbscan_from_hdbscan_percentage_core.png
:align: center

.. image:: images/dbscan_from_hdbscan_number_of_clusters.png
:align: center

We see that for very small epsilon values the number of clusters found by the two
algorithms differs considerably, largely because a large number of the points are
considered boundary points instead of core points. As the epsilon value increases,
more and more points are considered core points and the numbers of clusters generated
by the two algorithms converge.

Additionally, the adjusted rand score between the core points of both algorithms
stays consistently high (mostly 1.0) for our entire range of epsilon. There may
be some minor discrepancies between core point results, largely due to implementation
details and optimizations within the code base.

Why might one extract the DBSCAN* clustering results from a single HDBSCAN* run
instead of making use of sklearn's DBSCAN code? The short answer is efficiency.
If you aren't sure what epsilon parameter to select for DBSCAN then you may have to
run the algorithm many times on your data set. While those runs can be inexpensive for
very small epsilon values, they can get quite expensive for large parameter values.
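
As a rough sketch of how one might measure this trade-off on their own data (the
epsilon grid and timing code here are illustrative, not the benchmark reported below):

.. code:: python

    import time
    import hdbscan
    from sklearn.cluster import DBSCAN

    eps_grid = (0.1, 0.2, 0.5, 1.0)

    # One HDBSCAN* fit followed by several cheap DBSCAN* extractions
    start = time.perf_counter()
    model = hdbscan.HDBSCAN(min_samples=5).fit(X)
    for eps in eps_grid:
        model.dbscan_clustering(cut_distance=eps, min_cluster_size=5)
    hdbscan_sweep_time = time.perf_counter() - start

    # A full sklearn DBSCAN run per epsilon value
    start = time.perf_counter()
    for eps in eps_grid:
        DBSCAN(eps=eps, min_samples=5).fit(X)
    dbscan_sweep_time = time.perf_counter() - start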

In this small benchmark case of 50,000 two-dimensional data points we break even
after trying only two epsilon parameters with DBSCAN, or a single
run with a large epsilon value. This trend is only exacerbated for larger
data sets in higher dimensional spaces. For more detailed scaling experiments see
`Accelerated Hierarchical Density Clustering <https://arxiv.org/abs/1705.07321>`__
by McInnes and Healy.

.. image:: images/dbscan_from_hdbscan_timing.png
:align: center






1 change: 1 addition & 0 deletions docs/index.rst
@@ -25,6 +25,7 @@ User Guide / Tutorial
prediction_tutorial
soft_clustering
how_to_use_epsilon
dbscan_from_hdbscan
faq

Background on Clustering with HDBSCAN
66 changes: 53 additions & 13 deletions hdbscan/hdbscan_.py
@@ -242,9 +242,9 @@ def _hdbscan_prims_kdtree(
dist_metric = DistanceMetric.get_metric(metric, **kwargs)

# Get distance to kth nearest neighbour
core_distances = tree.query(X, k=min_samples + 1, dualtree=True, breadth_first=True)[0][
:, -1
].copy(order="C")
core_distances = tree.query(
X, k=min_samples + 1, dualtree=True, breadth_first=True
)[0][:, -1].copy(order="C")

# Mutual reachability distance is implicit in mst_linkage_core_vector
min_spanning_tree = mst_linkage_core_vector(X, core_distances, dist_metric, alpha)
@@ -283,9 +283,9 @@ def _hdbscan_prims_balltree(
dist_metric = DistanceMetric.get_metric(metric, **kwargs)

# Get distance to kth nearest neighbour
core_distances = tree.query(X, k=min_samples + 1, dualtree=True, breadth_first=True)[0][
:, -1
].copy(order="C")
core_distances = tree.query(
X, k=min_samples + 1, dualtree=True, breadth_first=True
)[0][:, -1].copy(order="C")

# Mutual reachability distance is implicit in mst_linkage_core_vector
min_spanning_tree = mst_linkage_core_vector(X, core_distances, dist_metric, alpha)
@@ -299,6 +299,7 @@ def _hdbscan_prims_balltree(
else:
return single_linkage_tree, None


def _hdbscan_boruvka_kdtree(
X,
min_samples=5,
@@ -423,12 +424,15 @@ def remap_condensed_tree(tree, internal_to_raw, outliers):
for outlier in outliers:
outlier_list.append((root, outlier, 0, 1))

outlier_tree = np.array(outlier_list, dtype=[
("parent", np.intp),
("child", np.intp),
("lambda_val", float),
("child_size", np.intp),
],)
outlier_tree = np.array(
outlier_list,
dtype=[
("parent", np.intp),
("child", np.intp),
("lambda_val", float),
("child_size", np.intp),
],
)
tree = np.append(outlier_tree, tree)
return tree

@@ -471,13 +475,15 @@ def remap_single_linkage_tree(tree, internal_to_raw, outliers):
tree = np.vstack([tree, outlier_tree])
return tree


def is_finite(matrix):
"""Returns true only if all the values of a ndarray or sparse matrix are finite"""
if issparse(matrix):
return np.alltrue(np.isfinite(matrix.tocoo().data))
else:
return np.alltrue(np.isfinite(matrix))


def get_finite_row_indices(matrix):
"""Returns the indices of the purely finite rows of a sparse matrix or dense ndarray"""
if issparse(matrix):
@@ -1146,7 +1152,7 @@ def fit(self, X, y=None):
self._raw_data = X

self._all_finite = is_finite(X)
if(~self._all_finite):
if ~self._all_finite:
# Pass only the purely finite indices into hdbscan
# We will later assign all non-finite points to the background -1 cluster
finite_index = get_finite_row_indices(X)
@@ -1322,6 +1328,40 @@ def weighted_cluster_medoid(self, cluster_id):
medoid_index = np.argmin(dist_mat.sum(axis=1))
return cluster_data[medoid_index]

def dbscan_clustering(self, cut_distance, min_cluster_size=5):
"""Return clustering that would be equivalent to running DBSCAN* for a particular cut_distance (or epsilon)
DBSCAN* can be thought of as DBSCAN without the border points. As such these results may differ slightly
from sklearns implementation of dbscan in the non-core points.
This can also be thought of as a flat clustering derived from constant height cut through the single
linkage tree.
This represents the result of selecting a cut value for robust single linkage
clustering. The `min_cluster_size` allows the flat clustering to declare noise
points (and cluster smaller than `min_cluster_size`).
Parameters
----------
cut_distance : float
The mutual reachability distance cut value to use to generate a flat clustering.
min_cluster_size : int, optional
Clusters smaller than this value with be called 'noise' and remain unclustered
in the resulting flat clustering.
Returns
-------
labels : array [n_samples]
An array of cluster labels, one per datapoint. Unclustered points are assigned
the label -1.
"""
return self.single_linkage_tree_.get_clusters(
cut_distance=cut_distance,
min_cluster_size=min_cluster_size,
)

@property
def prediction_data_(self):
if self._prediction_data is None:
5 changes: 5 additions & 0 deletions hdbscan/tests/test_hdbscan.py
@@ -251,6 +251,11 @@ def test_hdbscan_generic():
n_clusters_2 = len(set(labels)) - int(-1 in labels)
assert n_clusters_2 == n_clusters

def test_hdbscan_dbscan_clustering():
clusterer = HDBSCAN().fit(X)
labels = clusterer.dbscan_clustering(0.3)
n_clusters_1 = len(set(labels)) - int(-1 in labels)
assert(n_clusters == n_clusters_1)

def test_hdbscan_high_dimensional():
H, y = make_blobs(n_samples=50, random_state=0, n_features=64)

