diff --git a/pomegranate/distributions/categorical.py b/pomegranate/distributions/categorical.py
index 93914d44..e5d264e1 100644
--- a/pomegranate/distributions/categorical.py
+++ b/pomegranate/distributions/categorical.py
@@ -175,8 +175,8 @@ def log_probability(self, X):
 		for i in range(self.d):
 			if isinstance(X, torch.masked.MaskedTensor):
 				logp_ = self._log_probs[i][X[:, i]._masked_data]
-				logp_[logp_ == float("-inf")] = 0
-				_inplace_add(logps, logp_)
+				logp_[~X[:, i]._masked_mask] = 0
+				logps += logp_
 			else:
 				logps += self._log_probs[i][X[:, i]]

diff --git a/tests/distributions/test_categorical.py b/tests/distributions/test_categorical.py
index 2a3894ca..c9b6fa66 100644
--- a/tests/distributions/test_categorical.py
+++ b/tests/distributions/test_categorical.py
@@ -828,7 +828,8 @@ def test_masked_probability(probs, X, X_masked):

 	assert_array_almost_equal(y, d.probability(X_))

-	y = [0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ]
+	y = [2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01, 1.0000e-03,
+		1.4000e-01]
 	assert_array_almost_equal(y, d.probability(X_masked))


@@ -842,7 +843,8 @@ def test_masked_log_probability(probs, X, X_masked):

 	assert_array_almost_equal(y, d.log_probability(X_))

-	y = numpy.log([0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ])
+	y = numpy.log([2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01,
+		1.0000e-03, 1.4000e-01])
 	assert_array_almost_equal(y, d.log_probability(X_masked))


 def test_masked_summarize(X, X_masked, w):
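
To make the behaviour change concrete: the old branch zeroed out any looked-up log-probability that happened to equal -inf (using that as a proxy for a missing value), whereas the new branch zeroes out exactly the positions the mask marks as missing. Each masked feature then contributes log(1) = 0 to the running sum, so a fully masked row comes out with probability 1, which is what the updated expected values in the tests reflect. The following sketch reproduces that logic with plain torch tensors; the table values and the helper name masked_log_probability are illustrative assumptions, not pomegranate API.

import torch

# Per-feature log-probability tables, analogous to Categorical._log_probs:
# log_probs[i][k] = log P(feature i takes category k). Values are made up.
log_probs = [
	torch.log(torch.tensor([0.7, 0.3])),   # feature 0
	torch.log(torch.tensor([0.1, 0.9])),   # feature 1
]

def masked_log_probability(data, mask):
	# Sum per-feature log-probabilities, letting masked (missing) entries
	# contribute log(1) = 0, mirroring the patched branch above.
	n, d = data.shape
	logps = torch.zeros(n)
	for i in range(d):
		logp_ = log_probs[i][data[:, i]]
		logp_[~mask[:, i]] = 0          # missing feature -> no contribution
		logps += logp_
	return logps

data = torch.tensor([[0, 1], [1, 0]])
mask = torch.tensor([[True, False], [True, True]])
print(masked_log_probability(data, mask).exp())  # tensor([0.7000, 0.0300])

Zeroing by mask rather than by -inf also keeps genuinely zero-probability observations at -inf instead of silently discarding them, which is the correctness point behind the first hunk.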