From 3d9ef53ed35d051a62efdd375b98165f96e45ebd Mon Sep 17 00:00:00 2001
From: gerwang
Date: Sat, 26 Aug 2023 05:19:40 +0800
Subject: [PATCH] Fix gmm initialization from weighted samples (#1038)

* gmm initialization from weighted samples

* avoid indexing twice
---
 pomegranate/gmm.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pomegranate/gmm.py b/pomegranate/gmm.py
index 65d7b483..38608994 100644
--- a/pomegranate/gmm.py
+++ b/pomegranate/gmm.py
@@ -156,11 +156,13 @@ def _initialize(self, X, sample_weight=None):
 		self.priors = _cast_as_parameter(torch.empty(self.k,
 			dtype=self.dtype, device=self.device))
 
+		sample_weight_sum = sample_weight.sum()
 		for i in range(self.k):
 			idx = y_hat == i
 
-			self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
-			self.priors[i] = idx.type(torch.float32).mean()
+			sample_weight_idx = sample_weight[idx]
+			self.distributions[i].fit(X[idx], sample_weight=sample_weight_idx)
+			self.priors[i] = sample_weight_idx.sum() / sample_weight_sum
 
 		self._initialized = True
 		self._reset_cache()
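
Note on the behavioral change: before this patch, each component's prior was the fraction of samples hard-assigned to it, which ignores sample_weight; after the patch, the prior is that component's share of the total sample weight. The two agree only when all weights are equal. Below is a minimal standalone sketch (toy data and variable names are illustrative, not pomegranate's code) contrasting the two computations:

	import torch

	# Toy data: 6 samples hard-assigned to k=2 clusters, with
	# non-uniform per-sample weights (values chosen for illustration).
	y_hat = torch.tensor([0, 0, 0, 0, 1, 1])
	sample_weight = torch.tensor([1.0, 1.0, 1.0, 1.0, 4.0, 4.0])

	k = 2
	sample_weight_sum = sample_weight.sum()

	for i in range(k):
		idx = y_hat == i

		# Old behaviour: prior = fraction of samples in the cluster,
		# ignoring the weights entirely.
		prior_unweighted = idx.type(torch.float32).mean()

		# Patched behaviour: prior = cluster's share of the total sample weight.
		prior_weighted = sample_weight[idx].sum() / sample_weight_sum

		print(i, prior_unweighted.item(), prior_weighted.item())

	# Cluster 0: 0.667 (unweighted) vs 0.333 (weighted);
	# cluster 1: 0.333 vs 0.667.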