diff --git a/pomegranate/_utils.py b/pomegranate/_utils.py index ac007877..197866aa 100644 --- a/pomegranate/_utils.py +++ b/pomegranate/_utils.py @@ -35,6 +35,14 @@ def dtype(self): return self.buffers[0].dtype +def _inplace_add(X, Y): + """Do an in-place addition on X accounting for if Y is a masked tensor.""" + + if isinstance(Y, torch.masked.MaskedTensor): + X += Y._masked_data + else: + X += Y + def _cast_as_tensor(value, dtype=None): """Set the parameter.""" @@ -401,8 +409,10 @@ def partition_sequences(X, sample_weight=None, priors=None, n_dists=None): X_: list or tensor The partitioned and grouped sequences. """ + if priors is not None and n_dists is None: raise RuntimeError("If priors are provided, n_dists must be provided as well.") + # If a 3D tensor has been passed in, return it try: X = [_check_parameter(_cast_as_tensor(X), "X", ndim=3)] @@ -434,7 +444,8 @@ def partition_sequences(X, sample_weight=None, priors=None, n_dists=None): if priors is not None: priors = [_check_parameter(_cast_as_tensor(p), "priors", - ndim=3, shape=(*X[i].shape[:-1], n_dists)) for i, p in enumerate(priors)] + ndim=3, shape=(*X[i].shape[:-1], n_dists)) + for i, p in enumerate(priors)] if all([x.ndim == 3 for x in X]): return X, sample_weight, priors diff --git a/pomegranate/distributions/categorical.py b/pomegranate/distributions/categorical.py index 683604db..93914d44 100644 --- a/pomegranate/distributions/categorical.py +++ b/pomegranate/distributions/categorical.py @@ -3,6 +3,7 @@ import torch +from .._utils import _inplace_add from .._utils import _cast_as_tensor from .._utils import _cast_as_parameter from .._utils import _update_parameter @@ -172,7 +173,12 @@ def log_probability(self, X): logps = torch.zeros(X.shape[0], dtype=self.probs.dtype) for i in range(self.d): - logps += self._log_probs[i][X[:, i]] + if isinstance(X, torch.masked.MaskedTensor): + logp_ = self._log_probs[i][X[:, i]._masked_data] + logp_[logp_ == float("-inf")] = 0 + _inplace_add(logps, logp_) + else: + logps += self._log_probs[i][X[:, i]] return logps @@ -201,16 +207,23 @@ def summarize(self, X, sample_weight=None): X = _cast_as_tensor(X) if not self._initialized: - n_keys = self.n_keys if self.n_keys is not None else int(X.max())+1 + if self.n_keys is not None: + n_keys = self.n_keys + elif isinstance(X, torch.masked.MaskedTensor): + n_keys = int(torch.max(X._masked_data)) + 1 + else: + n_keys = int(torch.max(X)) + 1 + self._initialize(X.shape[1], n_keys) X = _check_parameter(X, "X", min_value=0, max_value=self.n_keys-1, ndim=2, shape=(-1, self.d), check_parameter=self.check_data) sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight)) - self._w_sum += torch.sum(sample_weight, dim=0) + _inplace_add(self._w_sum, torch.sum(sample_weight, dim=0)) for i in range(self.n_keys): - self._xw_sum[:, i] += torch.sum((X == i) * sample_weight, dim=0) + _inplace_add(self._xw_sum[:, i], torch.sum((X == i) * sample_weight, + dim=0)) def from_summaries(self): """Update the model parameters given the extracted statistics. diff --git a/pomegranate/hmm/_base.py b/pomegranate/hmm/_base.py index e02fe111..000e18d4 100644 --- a/pomegranate/hmm/_base.py +++ b/pomegranate/hmm/_base.py @@ -291,7 +291,6 @@ def _emission_matrix(self, X, priors=None): e[:, i] = logp.reshape(n, k).T e = e.permute(2, 0, 1) - if priors is not None: e += torch.log(priors) diff --git a/tests/distributions/test_categorical.py b/tests/distributions/test_categorical.py index e246c1a3..2a3894ca 100644 --- a/tests/distributions/test_categorical.py +++ b/tests/distributions/test_categorical.py @@ -21,15 +21,37 @@ VALID_VALUE = 1 +def _test_efd_from_summaries(d, name1, name2, values): + assert_array_almost_equal(getattr(d, name1), values) + assert_array_almost_equal(getattr(d, name2), numpy.log(values)) + assert_array_almost_equal(d._w_sum, numpy.zeros(d.d)) + assert_array_almost_equal(d._xw_sum, numpy.zeros((d.d, d.n_keys))) + + @pytest.fixture def X(): return [[1, 2, 0], - [3, 0, 1], - [1, 1, 2], - [2, 2, 2], - [0, 1, 0], - [1, 1, 2], - [2, 1, 0]] + [3, 0, 1], + [1, 1, 2], + [2, 2, 2], + [0, 1, 0], + [1, 1, 2], + [2, 1, 0]] + + +@pytest.fixture +def X_masked(X): + mask = torch.tensor(numpy.array([ + [False, True, True ], + [True, True, False], + [False, False, False], + [True, True, True ], + [False, True, False], + [True, True, True ], + [True, False, True ]])) + + X = torch.tensor(numpy.array(X)) + return torch.masked.MaskedTensor(X, mask=mask) @pytest.fixture @@ -109,15 +131,15 @@ def test_reset_cache(X): assert_array_almost_equal(d._w_sum, [7.0, 7.0, 7.0]) assert_array_almost_equal(d._xw_sum, [[1., 3., 2., 1.], - [1., 4., 2., 0.], - [3., 1., 3., 0.]]) + [1., 4., 2., 0.], + [3., 1., 3., 0.]]) d._reset_cache() assert_array_almost_equal(d._w_sum, [0.0, 0.0, 0.0]) assert_array_almost_equal(d._xw_sum, [[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) d = Categorical() assert_raises(AttributeError, getattr, d, "_w_sum") @@ -210,10 +232,10 @@ def test_sample(probs): X = Categorical(probs).sample(5) assert_array_almost_equal(X, [[3, 2, 0], - [3, 0, 0], - [3, 2, 0], - [1, 0, 0], - [2, 1, 0]]) + [3, 0, 0], + [3, 2, 0], + [1, 0, 0], + [2, 1, 0]]) ### @@ -299,7 +321,7 @@ def test_log_probability(X, probs): x_torch = torch.tensor(numpy.array(X)) y = [-3.170086, -4.892852, -6.907755, -5.809143, -4.961845, -6.907755, - -4.268698] + -4.268698] d1 = Categorical(probs) d2 = Categorical(numpy.array(probs, dtype=numpy.float64)) @@ -336,23 +358,23 @@ def test_summarize(X, probs): assert_array_almost_equal(d._w_sum, [4.0, 4.0, 4.0]) assert_array_almost_equal(d._xw_sum, [[0., 2., 1., 1.], - [1., 1., 2., 0.], - [1., 1., 2., 0.]]) + [1., 1., 2., 0.], + [1., 1., 2., 0.]]) d.summarize(X[4:]) assert_array_almost_equal(d._w_sum, [7.0, 7.0, 7.0]) assert_array_almost_equal(d._xw_sum, [[1., 3., 2., 1.], - [1., 4., 2., 0.], - [3., 1., 3., 0.]]) + [1., 4., 2., 0.], + [3., 1., 3., 0.]]) d = Categorical(param) d.summarize(X) assert_array_almost_equal(d._w_sum, [7.0, 7.0, 7.0]) assert_array_almost_equal(d._xw_sum, [[1., 3., 2., 1.], - [1., 4., 2., 0.], - [3., 1., 3., 0.]]) + [1., 4., 2., 0.], + [3., 1., 3., 0.]]) def test_summarize_weighted(X, w, probs): @@ -362,23 +384,23 @@ def test_summarize_weighted(X, w, probs): assert_array_almost_equal(d._w_sum, [3., 3., 3.]) assert_array_almost_equal(d._xw_sum, [[0., 1., 0., 2.], - [2., 0., 1., 0.], - [1., 2., 0., 0.]]) + [2., 0., 1., 0.], + [1., 2., 0., 0.]]) d.summarize(X[4:], sample_weight=w[4:]) assert_array_almost_equal(d._w_sum, [11.0, 11.0, 11.0]) assert_array_almost_equal(d._xw_sum, [[5., 2., 2., 2.], - [2., 8., 1., 0.], - [8., 2., 1., 0.]],) + [2., 8., 1., 0.], + [8., 2., 1., 0.]],) d = Categorical(param) d.summarize(X, sample_weight=w) assert_array_almost_equal(d._w_sum, [11.0, 11.0, 11.0]) assert_array_almost_equal(d._xw_sum, [[5., 2., 2., 2.], - [2., 8., 1., 0.], - [8., 2., 1., 0.]]) + [2., 8., 1., 0.], + [8., 2., 1., 0.]]) def test_summarize_weighted_flat(X, w, probs): @@ -390,23 +412,23 @@ def test_summarize_weighted_flat(X, w, probs): assert_array_almost_equal(d._w_sum, [3., 3., 3.]) assert_array_almost_equal(d._xw_sum, [[0., 1., 0., 2.], - [2., 0., 1., 0.], - [1., 2., 0., 0.]]) + [2., 0., 1., 0.], + [1., 2., 0., 0.]]) d.summarize(X[4:], sample_weight=w[4:]) assert_array_almost_equal(d._w_sum, [11.0, 11.0, 11.0]) assert_array_almost_equal(d._xw_sum, [[5., 2., 2., 2.], - [2., 8., 1., 0.], - [8., 2., 1., 0.]],) + [2., 8., 1., 0.], + [8., 2., 1., 0.]],) d = Categorical(param) d.summarize(X, sample_weight=w) assert_array_almost_equal(d._w_sum, [11.0, 11.0, 11.0]) assert_array_almost_equal(d._xw_sum, [[5., 2., 2., 2.], - [2., 8., 1., 0.], - [8., 2., 1., 0.]]) + [2., 8., 1., 0.], + [8., 2., 1., 0.]]) def test_summarize_weighted_2d(X): @@ -415,23 +437,23 @@ def test_summarize_weighted_2d(X): assert_array_almost_equal(d._w_sum, [7., 5., 5.]) assert_array_almost_equal(d._xw_sum, [[0., 2., 2., 3.], - [0., 1., 4., 0.], - [0., 1., 4., 0.]]) + [0., 1., 4., 0.], + [0., 1., 4., 0.]]) d.summarize(X[4:], sample_weight=X[4:]) assert_array_almost_equal(d._w_sum, [10., 8., 7.]) assert_array_almost_equal(d._xw_sum, [[0., 3., 4., 3.], - [0., 4., 4., 0.], - [0., 1., 6., 0.]]) + [0., 4., 4., 0.], + [0., 1., 6., 0.]]) d = Categorical() d.summarize(X, sample_weight=X) assert_array_almost_equal(d._w_sum, [10., 8., 7.]) assert_array_almost_equal(d._xw_sum, [[0., 3., 4., 3.], - [0., 4., 4., 0.], - [0., 1., 6., 0.]]) + [0., 4., 4., 0.], + [0., 1., 6., 0.]]) def test_summarize_dtypes(X, probs): @@ -479,15 +501,15 @@ def test_from_summaries(X, probs): d.from_summaries() _test_fit_params(d, [[0. , 0.5 , 0.25, 0.25], - [0.25, 0.25, 0.5 , 0. ], - [0.25, 0.25, 0.5 , 0. ]]) + [0.25, 0.25, 0.5 , 0. ], + [0.25, 0.25, 0.5 , 0. ]]) d.summarize(X[4:]) d.from_summaries() _test_fit_params(d, [[0.333333, 0.333333, 0.333333, 0. ], - [0. , 1. , 0. , 0. ], - [0.666667, 0. , 0.333333, 0. ]]) + [0. , 1. , 0. , 0. ], + [0.666667, 0. , 0.333333, 0. ]]) d = Categorical(param) d.summarize(X[:4]) @@ -495,16 +517,16 @@ def test_from_summaries(X, probs): d.from_summaries() _test_fit_params(d, [[0.142857, 0.428571, 0.285714, 0.142857], - [0.142857, 0.571429, 0.285714, 0. ], - [0.428571, 0.142857, 0.428571, 0. ]]) + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) d = Categorical(param) d.summarize(X) d.from_summaries() _test_fit_params(d, [[0.142857, 0.428571, 0.285714, 0.142857], - [0.142857, 0.571429, 0.285714, 0. ], - [0.428571, 0.142857, 0.428571, 0. ]]) + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) def test_from_summaries_weighted(X, w, probs): @@ -514,23 +536,23 @@ def test_from_summaries_weighted(X, w, probs): d.from_summaries() _test_fit_params(d, [[0. , 0.333333, 0. , 0.666667], - [0.666667, 0. , 0.333333, 0. ], - [0.333333, 0.666667, 0. , 0. ]]) + [0.666667, 0. , 0.333333, 0. ], + [0.333333, 0.666667, 0. , 0. ]]) d.summarize(X[4:], sample_weight=w[4:]) d.from_summaries() _test_fit_params(d, [[0.625, 0.125, 0.25 , 0. ], - [0. , 1. , 0. , 0. ], - [0.875, 0. , 0.125, 0. ]]) + [0. , 1. , 0. , 0. ], + [0.875, 0. , 0.125, 0. ]]) d = Categorical(probs) d.summarize(X, sample_weight=w) d.from_summaries() _test_fit_params(d, [[0.454545, 0.181818, 0.181818, 0.181818], - [0.181818, 0.727273, 0.090909, 0. ], - [0.727273, 0.181818, 0.090909, 0. ]]) + [0.181818, 0.727273, 0.090909, 0. ], + [0.727273, 0.181818, 0.090909, 0. ]]) def test_from_summaries_null(X, probs): @@ -556,8 +578,8 @@ def test_from_summaries_null(X, probs): d.from_summaries() _test_fit_params(d, [[0.1 , 0.2 , 0.2 , 0.5 ], - [0.3 , 0.1 , 0.3 , 0.3 ], - [0.7 , 0.05, 0.05, 0.2 ]]) + [0.3 , 0.1 , 0.3 , 0.3 ], + [0.7 , 0.05, 0.05, 0.2 ]]) def test_from_summaries_inertia(X, w, probs): @@ -566,23 +588,23 @@ def test_from_summaries_inertia(X, w, probs): d.from_summaries() _test_fit_params(d, [[0.03 , 0.41 , 0.235, 0.325], - [0.265, 0.205, 0.44 , 0.09 ], - [0.385, 0.19 , 0.365, 0.06 ]]) + [0.265, 0.205, 0.44 , 0.09 ], + [0.385, 0.19 , 0.365, 0.06 ]]) d.summarize(X[4:]) d.from_summaries() _test_fit_params(d, [[0.242333, 0.356333, 0.303833, 0.0975 ], - [0.0795 , 0.7615 , 0.132 , 0.027 ], - [0.582167, 0.057 , 0.342833, 0.018 ]]) + [0.0795 , 0.7615 , 0.132 , 0.027 ], + [0.582167, 0.057 , 0.342833, 0.018 ]]) d = Categorical(probs, inertia=0.3) d.summarize(X) d.from_summaries() _test_fit_params(d, [[0.13 , 0.36 , 0.26 , 0.25 ], - [0.19 , 0.43 , 0.29 , 0.09 ], - [0.51 , 0.115, 0.315, 0.06 ]]) + [0.19 , 0.43 , 0.29 , 0.09 ], + [0.51 , 0.115, 0.315, 0.06 ]]) def test_from_summaries_weighted_inertia(X, w, probs): @@ -591,8 +613,8 @@ def test_from_summaries_weighted_inertia(X, w, probs): d.from_summaries() _test_fit_params(d, [[0.277273, 0.190909, 0.190909, 0.340909], - [0.240909, 0.413636, 0.195455, 0.15 ], - [0.713636, 0.115909, 0.070455, 0.1 ]]) + [0.240909, 0.413636, 0.195455, 0.15 ], + [0.713636, 0.115909, 0.070455, 0.1 ]]) d = Categorical(probs, inertia=1.0) d.summarize(X[:4]) @@ -615,8 +637,8 @@ def test_from_summaries_frozen(X, w, probs): assert_array_almost_equal(d._w_sum, [0.0, 0.0, 0.0]) assert_array_almost_equal(d._xw_sum, [[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) d.from_summaries() _test_fit_params(d, probs) @@ -625,8 +647,8 @@ def test_from_summaries_frozen(X, w, probs): assert_array_almost_equal(d._w_sum, [0.0, 0.0, 0.0]) assert_array_almost_equal(d._xw_sum, [[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) d.from_summaries() _test_fit_params(d, probs) @@ -636,8 +658,8 @@ def test_from_summaries_frozen(X, w, probs): assert_array_almost_equal(d._w_sum, [0.0, 0.0, 0.0]) assert_array_almost_equal(d._xw_sum, [[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) d.from_summaries() _test_fit_params(d, probs) @@ -647,8 +669,8 @@ def test_from_summaries_frozen(X, w, probs): assert_array_almost_equal(d._w_sum, [0.0, 0.0, 0.0]) assert_array_almost_equal(d._xw_sum, [[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) + [0., 0., 0., 0.], + [0., 0., 0., 0.]]) d.from_summaries() _test_fit_params(d, probs) @@ -687,21 +709,21 @@ def test_fit(X, probs): d.fit(X[:4]) _test_fit_params(d, [[0. , 0.5 , 0.25, 0.25], - [0.25, 0.25, 0.5 , 0. ], - [0.25, 0.25, 0.5 , 0. ]]) + [0.25, 0.25, 0.5 , 0. ], + [0.25, 0.25, 0.5 , 0. ]]) d.fit(X[4:]) _test_fit_params(d, [[0.333333, 0.333333, 0.333333, 0. ], - [0. , 1. , 0. , 0. ], - [0.666667, 0. , 0.333333, 0. ]]) + [0. , 1. , 0. , 0. ], + [0.666667, 0. , 0.333333, 0. ]]) d = Categorical(param) d.fit(X) _test_fit_params(d, [[0.142857, 0.428571, 0.285714, 0.142857], - [0.142857, 0.571429, 0.285714, 0. ], - [0.428571, 0.142857, 0.428571, 0. ]]) + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) def test_fit_weighted(X, w, probs): @@ -710,41 +732,41 @@ def test_fit_weighted(X, w, probs): d.fit(X[:4], sample_weight=w[:4]) _test_fit_params(d, [[0. , 0.333333, 0. , 0.666667], - [0.666667, 0. , 0.333333, 0. ], - [0.333333, 0.666667, 0. , 0. ]]) + [0.666667, 0. , 0.333333, 0. ], + [0.333333, 0.666667, 0. , 0. ]]) d.fit(X[4:], sample_weight=w[4:]) _test_fit_params(d, [[0.625, 0.125, 0.25 , 0. ], - [0. , 1. , 0. , 0. ], - [0.875, 0. , 0.125, 0. ]]) + [0. , 1. , 0. , 0. ], + [0.875, 0. , 0.125, 0. ]]) d = Categorical(probs) d.fit(X, sample_weight=w) _test_fit_params(d, [[0.454545, 0.181818, 0.181818, 0.181818], - [0.181818, 0.727273, 0.090909, 0. ], - [0.727273, 0.181818, 0.090909, 0. ]]) + [0.181818, 0.727273, 0.090909, 0. ], + [0.727273, 0.181818, 0.090909, 0. ]]) def test_fit_chain(X): d = Categorical().fit(X[:4]) _test_fit_params(d, [[0. , 0.5 , 0.25, 0.25], - [0.25, 0.25, 0.5 , 0. ], - [0.25, 0.25, 0.5 , 0. ]]) + [0.25, 0.25, 0.5 , 0. ], + [0.25, 0.25, 0.5 , 0. ]]) d.fit(X[4:]) _test_fit_params(d, [[0.333333, 0.333333, 0.333333, 0. ], - [0. , 1. , 0. , 0. ], - [0.666667, 0. , 0.333333, 0. ]]) + [0. , 1. , 0. , 0. ], + [0.666667, 0. , 0.333333, 0. ]]) d = Categorical().fit(X) _test_fit_params(d, [[0.142857, 0.428571, 0.285714, 0.142857], - [0.142857, 0.571429, 0.285714, 0. ], - [0.428571, 0.142857, 0.428571, 0. ]]) + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) def test_fit_dtypes(X, probs): @@ -775,8 +797,8 @@ def test_serialization(X): d.summarize(X[4:]) p = [[0. , 0.5 , 0.25, 0.25], - [0.25, 0.25, 0.5 , 0. ], - [0.25, 0.25, 0.5 , 0. ]] + [0.25, 0.25, 0.5 , 0. ], + [0.25, 0.25, 0.5 , 0. ]] assert_array_almost_equal(d.probs, p) assert_array_almost_equal(d._log_probs, numpy.log(p)) @@ -791,6 +813,97 @@ def test_serialization(X): assert_array_almost_equal(d2._w_sum, [3., 3., 3.]) assert_array_almost_equal(d2._xw_sum, [[1., 1., 1., 0.], - [0., 3., 0., 0.], - [2., 0., 1., 0.]]) + [0., 3., 0., 0.], + [2., 0., 1., 0.]]) assert_array_almost_equal(d.log_probability(X), d2.log_probability(X)) + + +def test_masked_probability(probs, X, X_masked): + X = torch.tensor(numpy.array(X)) + y = [0.0420, 0.0075, 0.0010, 0.0030, 0.0070, 0.0010, 0.0140] + + d = Categorical(probs) + mask = torch.ones_like(X).type(torch.bool) + X_ = torch.masked.MaskedTensor(X, mask=mask) + + assert_array_almost_equal(y, d.probability(X_)) + + y = [0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ] + assert_array_almost_equal(y, d.probability(X_masked)) + + +def test_masked_log_probability(probs, X, X_masked): + X = torch.tensor(numpy.array(X)) + y = numpy.log([0.0420, 0.0075, 0.0010, 0.0030, 0.0070, 0.0010, 0.0140]) + + d = Categorical(probs) + mask = torch.ones_like(X).type(torch.bool) + X_ = torch.masked.MaskedTensor(X, mask=mask) + + assert_array_almost_equal(y, d.log_probability(X_)) + + y = numpy.log([0.042 , 0.0075, 0.001 , 0.003 , 0.007 , 0.001 , 0.014 ]) + assert_array_almost_equal(y, d.log_probability(X_masked)) + +def test_masked_summarize(X, X_masked, w): + X = torch.tensor(numpy.array(X)) + mask = torch.ones_like(X).type(torch.bool) + X_ = torch.masked.MaskedTensor(X, mask=mask) + + d = Categorical() + d.summarize(X, sample_weight=w) + assert_array_almost_equal(d._w_sum, [11.0, 11.0, 11.0]) + assert_array_almost_equal(d._xw_sum, [ + [5., 2., 2., 2.], + [2., 8., 1., 0.], + [8., 2., 1., 0.]]) + + d = Categorical() + d.summarize(X_masked) + assert_array_almost_equal(d._w_sum, [4.0, 5.0, 4.0]) + assert_array_almost_equal(d._xw_sum, [ + [0., 1., 2., 1.], + [1., 2., 2., 0.], + [2., 0., 2., 0.]]) + + +def test_masked_from_summaries(X, X_masked): + X = torch.tensor(numpy.array(X)) + mask = torch.ones_like(X).type(torch.bool) + X_ = torch.masked.MaskedTensor(X, mask=mask) + + d = Categorical() + d.summarize(X_) + d.from_summaries() + _test_efd_from_summaries(d, "probs", "_log_probs", [ + [0.142857, 0.428571, 0.285714, 0.142857], + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) + + d = Categorical() + d.summarize(X_masked) + d.from_summaries() + _test_efd_from_summaries(d, "probs", "_log_probs", [ + [0. , 0.25, 0.5 , 0.25], + [0.2 , 0.4 , 0.4 , 0. ], + [0.5 , 0. , 0.5 , 0. ]]) + + +def test_masked_fit(X, X_masked): + X = torch.tensor(numpy.array(X)) + mask = torch.ones_like(X).type(torch.bool) + X_ = torch.masked.MaskedTensor(X, mask=mask) + + d = Categorical() + d.fit(X_) + _test_efd_from_summaries(d, "probs", "_log_probs", [ + [0.142857, 0.428571, 0.285714, 0.142857], + [0.142857, 0.571429, 0.285714, 0. ], + [0.428571, 0.142857, 0.428571, 0. ]]) + + d = Categorical() + d.fit(X_masked) + _test_efd_from_summaries(d, "probs", "_log_probs", [ + [0. , 0.25, 0.5 , 0.25], + [0.2 , 0.4 , 0.4 , 0. ], + [0.5 , 0. , 0.5 , 0. ]])