diff --git a/pomegranate/gmm.py b/pomegranate/gmm.py index 65d7b483..38608994 100644 --- a/pomegranate/gmm.py +++ b/pomegranate/gmm.py @@ -156,11 +156,13 @@ def _initialize(self, X, sample_weight=None): self.priors = _cast_as_parameter(torch.empty(self.k, dtype=self.dtype, device=self.device)) + sample_weight_sum = sample_weight.sum() for i in range(self.k): idx = y_hat == i - self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx]) - self.priors[i] = idx.type(torch.float32).mean() + sample_weight_idx = sample_weight[idx] + self.distributions[i].fit(X[idx], sample_weight=sample_weight_idx) + self.priors[i] = sample_weight_idx.sum() / sample_weight_sum self._initialized = True self._reset_cache()