Fix gmm initialization from weighted samples (#1038)
* gmm initialization from weighted samples

* avoid indexing twice
gerwang committed Aug 25, 2023
1 parent ee3c177 commit 3d9ef53
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pomegranate/gmm.py
@@ -156,11 +156,13 @@ def _initialize(self, X, sample_weight=None):
 		self.priors = _cast_as_parameter(torch.empty(self.k,
 			dtype=self.dtype, device=self.device))
 
+		sample_weight_sum = sample_weight.sum()
 		for i in range(self.k):
 			idx = y_hat == i
 
-			self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
-			self.priors[i] = idx.type(torch.float32).mean()
+			sample_weight_idx = sample_weight[idx]
+			self.distributions[i].fit(X[idx], sample_weight=sample_weight_idx)
+			self.priors[i] = sample_weight_idx.sum() / sample_weight_sum
 
 		self._initialized = True
 		self._reset_cache()
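For context, a minimal standalone sketch of the prior computation the added lines perform: each component's prior becomes the share of total sample weight assigned to it, rather than the unweighted fraction of samples. Variable names mirror the diff; the toy assignments and weights below are hypothetical and 1-dimensional for simplicity.

import torch

k = 2
y_hat = torch.tensor([0, 0, 0, 1])                  # hard cluster assignments from initialization
sample_weight = torch.tensor([0.1, 0.1, 0.1, 2.7])  # per-sample weights

priors = torch.empty(k)
sample_weight_sum = sample_weight.sum()
for i in range(k):
    idx = y_hat == i
    sample_weight_idx = sample_weight[idx]
    # Old code: priors[i] = idx.type(torch.float32).mean()   -> tensor([0.75, 0.25])
    # New code: weight each prior by the mass in its cluster -> tensor([0.10, 0.90])
    priors[i] = sample_weight_idx.sum() / sample_weight_sum

print(priors)  # tensor([0.1000, 0.9000])

Reusing sample_weight_idx also avoids indexing sample_weight twice per component, which is the second change noted in the commit message.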