From 1faa346c9d0fbbb9ca9268a44f33ea50a2214e0a Mon Sep 17 00:00:00 2001 From: Jacob Schreiber Date: Sat, 26 Aug 2023 16:23:06 -0700 Subject: [PATCH] FIX more categorical --- pomegranate/distributions/categorical.py | 4 ++-- tests/distributions/test_categorical.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pomegranate/distributions/categorical.py b/pomegranate/distributions/categorical.py index 93914d44..e5d264e1 100644 --- a/pomegranate/distributions/categorical.py +++ b/pomegranate/distributions/categorical.py @@ -175,8 +175,8 @@ def log_probability(self, X): for i in range(self.d): if isinstance(X, torch.masked.MaskedTensor): logp_ = self._log_probs[i][X[:, i]._masked_data] - logp_[logp_ == float("-inf")] = 0 - _inplace_add(logps, logp_) + logp_[~X[:, i]._masked_mask] = 0 + logps += logp_ else: logps += self._log_probs[i][X[:, i]] diff --git a/tests/distributions/test_categorical.py b/tests/distributions/test_categorical.py index 2a3894ca..c9b6fa66 100644 --- a/tests/distributions/test_categorical.py +++ b/tests/distributions/test_categorical.py @@ -828,7 +828,8 @@ def test_masked_probability(probs, X, X_masked): assert_array_almost_equal(y, d.probability(X_)) - y = [0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ] + y = [2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01, 1.0000e-03, + 1.4000e-01] assert_array_almost_equal(y, d.probability(X_masked)) @@ -842,7 +843,8 @@ def test_masked_log_probability(probs, X, X_masked): assert_array_almost_equal(y, d.log_probability(X_)) - y = numpy.log([0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ]) + y = numpy.log([2.1000e-01, 1.5000e-01, 1.0000e+00, 3.0000e-03, 1.0000e-01, + 1.0000e-03, 1.4000e-01]) assert_array_almost_equal(y, d.log_probability(X_masked)) def test_masked_summarize(X, X_masked, w):