diff --git a/pomegranate/_utils.py b/pomegranate/_utils.py index 06c58505..ac007877 100644 --- a/pomegranate/_utils.py +++ b/pomegranate/_utils.py @@ -74,7 +74,7 @@ def _update_parameter(value, new_value, inertia=0.0, frozen=None): return if inertia == 0.0: - value[:] = _cast_as_parameter(new_value) + value[...] = _cast_as_parameter(new_value) elif inertia < 1.0: value_ = inertia*value + (1-inertia)*new_value diff --git a/pomegranate/distributions/normal.py b/pomegranate/distributions/normal.py index 0a4e08b0..93ed4590 100644 --- a/pomegranate/distributions/normal.py +++ b/pomegranate/distributions/normal.py @@ -165,8 +165,8 @@ def _reset_cache(self): self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi) self.register_buffer("_inv_two_sigma", _inv_two_sigma) - - if any(self.covs < 0): + + if torch.any(self.covs < 0): raise ValueError("Variances must be positive.") def sample(self, n): @@ -289,9 +289,11 @@ def from_summaries(self): v = self._xw_sum.unsqueeze(0) * self._xw_sum.unsqueeze(1) covs = self._xxw_sum / self._w_sum - v / self._w_sum ** 2.0 - elif self.covariance_type == 'diag': + elif self.covariance_type in ['diag', 'sphere']: covs = self._xxw_sum / self._w_sum - \ self._xw_sum ** 2.0 / self._w_sum ** 2.0 + if self.covariance_type == 'sphere': + covs = covs.mean(dim=-1) _update_parameter(self.means, means, self.inertia) _update_parameter(self.covs, covs, self.inertia)