Skip to content

Commit

Permalink
Fix hmm fit shape check (#1019)
Browse files Browse the repository at this point in the history
* Changed the control flow so that errors caused by sample_weight or priors are raised instead of being swallowed and causing a fallback to the next input-shape case

* Added a potential fix for #1018

* Updated all tests impacted by the change

* Added new test to test hmm-fit fix

* renamed priors_last_dim to n_dists
  • Loading branch information
AKuederle committed Apr 19, 2023
1 parent d03584f commit 514bf94
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 39 deletions.
27 changes: 16 additions & 11 deletions pomegranate/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _initialize_centroids(X, k, algorithm='first-k', random_state=None):
return selector.fit_transform(X)


def partition_sequences(X, sample_weight=None, priors=None):
def partition_sequences(X, sample_weight=None, priors=None, n_dists=None):
"""Partition a set of sequences into blobs of equal length.
This function will take in a list of sequences, where each sequence is
Expand Down Expand Up @@ -391,48 +391,53 @@ def partition_sequences(X, sample_weight=None, priors=None):
The input sequence priors for the sequences or None. If None, return
None. Default is None.
n_dists: int or None
The expected last dimension of the priors tensor. Must be provided if
`priors` is provided. Default is None.
Returns
-------
X_: list or tensor
The partitioned and grouped sequences.
"""

if priors is not None and n_dists is None:
raise RuntimeError("If priors are provided, n_dists must be provided as well.")
# If a 3D tensor has been passed in, return it
try:
X = [_check_parameter(_cast_as_tensor(X), "X", ndim=3)]

except:
pass
else:
if sample_weight is not None:
sample_weight = [_check_parameter(_cast_as_tensor(sample_weight),
"sample_weight", min_value=0.0)]

if priors is not None:
priors = [_check_parameter(_cast_as_tensor(priors), "priors",
ndim=3, shape=X[0].shape)]
ndim=3, shape=(*X[0].shape[:-1], n_dists))]

return X, sample_weight, priors
except:
pass

# Otherwise, cast all elements in the list as a tensor
X = [_cast_as_tensor(x) for x in X]

# If a list of 3D tensors has been passed in, return it
try:
X = [_check_parameter(x, "X", ndim=3) for x in X]

except:
pass
else:
if sample_weight is not None:
sample_weight = [_check_parameter(_cast_as_tensor(w_),
"sample_weight", min_value=0.0) for w_ in sample_weight]

if priors is not None:
priors = [_check_parameter(_cast_as_tensor(p), "priors",
ndim=3, shape=X[i].shape) for i, p in enumerate(priors)]
ndim=3, shape=(*X[i].shape[:-1], n_dists)) for i, p in enumerate(priors)]

if all([x.ndim == 3 for x in X]):
return X, sample_weight, priors
except:
pass

# Otherwise, group together same-sized examples
X_dict = collections.defaultdict(list)
Expand All @@ -453,7 +458,7 @@ def partition_sequences(X, sample_weight=None, priors=None):

if priors is not None:
p = _check_parameter(_cast_as_tensor(priors[i]), "priors",
ndim=2, shape=x.shape)
ndim=2, shape=(*x.shape[:-1], n_dists))

priors_dict[n].append(p)

Expand Down
4 changes: 2 additions & 2 deletions pomegranate/hmm/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def fit(self, X, sample_weight=None, priors=None):
Prior probabilities of assigning each symbol to each node. If not
provided, do not include in the calculations (conceptually
equivalent to a uniform probability, but without scaling the
probabilities). This can be used to assign labels to observatons
probabilities). This can be used to assign labels to observations
by setting one of the probabilities for an observation to 1.0.
Note that this can be used to assign hard labels, but does not
have the same semantics for soft labels, in that it only
Expand All @@ -580,7 +580,7 @@ def fit(self, X, sample_weight=None, priors=None):
"""

X, sample_weight, priors = partition_sequences(X,
sample_weight=sample_weight, priors=priors)
sample_weight=sample_weight, priors=priors, n_dists=self.k)

# Initialize by concatenating across sequences
if not self._initialized:
Expand Down
19 changes: 19 additions & 0 deletions tests/hmm/test_dense_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,3 +1626,22 @@ def test_masked_fit(X, X_masked):
assert_array_almost_equal(d2.scales, [2.8777, 2.3498, 2.1939], 4)
assert_array_almost_equal(d2._w_sum, [0., 0., 0.])
assert_array_almost_equal(d2._xw_sum, [0., 0., 0.])


@pytest.mark.parametrize("n_states", [2, 3, 4])
def test_priors_in_fit_valid(n_states):
    """Check that DenseHMM.fit runs with priors whose last dimension is n_states.

    The observations carry 3 values per step, while the priors carry one
    value per state, so the two only share their leading dimensions. The fit
    is repeated for every input container listed below and the test passes
    if no call raises.
    """

    def as_is(x):
        return x

    def as_tensor(x):
        return torch.from_numpy(numpy.array(x))

    # Nested lists forming a 3 x 3 x 3 block of observations.
    nested = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
              [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
              [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]

    for convert in (as_is, tuple, numpy.array, as_tensor):
        hmm = DenseHMM(distributions=[Exponential() for _ in range(n_states)])
        observations = convert(nested)

        # Priors match the observations on every axis except the last,
        # which has one (uniform) entry per state.
        leading_shape = numpy.array(observations).shape[:-1]
        uniform_priors = torch.ones((*leading_shape, n_states)) / n_states

        hmm.fit(observations, priors=uniform_priors)
45 changes: 19 additions & 26 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def test_partition_3d_Xp(X1, p1):
lambda x: torch.from_numpy(numpy.array(x)))

for func in funcs:
y, w, p = partition_sequences(func(X1), priors=func(p1))
y, w, p = partition_sequences(func(X1), priors=func(p1), n_dists=2)

assert isinstance(y, list)
assert len(y) == 1
Expand All @@ -667,8 +667,8 @@ def test_partition_3d_Xwp(X1, w1, p1):
lambda x: torch.from_numpy(numpy.array(x)))

for func in funcs:
y, w, p = partition_sequences(func(X1), sample_weight=func(w1),
priors=func(p1))
y, w, p = partition_sequences(func(X1), sample_weight=func(w1),
priors=func(p1), n_dists=2)

assert isinstance(y, list)
assert len(y) == 1
Expand Down Expand Up @@ -753,25 +753,6 @@ def test_partition_3ds_X(X2):
assert_array_almost_equal(y_, X2[i])


def test_partition_3ds_X(X2):
    """Partition a list of 3D inputs and check each comes back as a 3D tensor.

    Each element of the X2 fixture is converted with every container type
    listed below before being handed to partition_sequences.
    """

    def as_is(x):
        return x

    def as_tensor(x):
        return torch.from_numpy(numpy.array(x))

    for convert in (as_is, tuple, numpy.array, as_tensor):
        converted = [convert(x) for x in X2]

        partitions, _, _ = partition_sequences(converted)

        assert isinstance(partitions, list)
        assert len(partitions) == 2

        for idx, part in enumerate(partitions):
            assert isinstance(part, torch.Tensor)
            assert part.ndim == 3
            assert part.shape == (2, idx + 2, 2)
            # Values must be unchanged by the partitioning.
            assert_array_almost_equal(part, X2[idx])


def test_partition_3ds_Xw(X2, w2):
funcs = (lambda x: x, tuple, numpy.array,
lambda x: torch.from_numpy(numpy.array(x)))
Expand Down Expand Up @@ -811,7 +792,7 @@ def test_partition_3ds_Xp(X2, p2):
X2_ = [func(x) for x in X2]
p2_ = [func(p) for p in p2]

y, w, p = partition_sequences(X2_, priors=p2_)
y, w, p = partition_sequences(X2_, priors=p2_, n_dists=2)

assert isinstance(y, list)
assert len(y) == 2
Expand Down Expand Up @@ -843,7 +824,8 @@ def test_partition_3ds_Xwp(X2, w2, p2):
w2_ = [func(w) for w in w2]
p2_ = [func(p) for p in p2]

y, w, p = partition_sequences(X2_, sample_weight=w2_, priors=p2_)
y, w, p = partition_sequences(X2_,
sample_weight=w2_, priors=p2_, n_dists=2)

assert isinstance(y, list)
assert len(y) == 2
Expand Down Expand Up @@ -1025,7 +1007,7 @@ def test_partition_2ds_Xp(X3, p3):
X3_ = [func(x) for x in X3]
p3_ = [func(p) for p in p3]

y, w, p = partition_sequences(X3_, priors=p3_)
y, w, p = partition_sequences(X3_, priors=p3_, n_dists=2)

assert isinstance(y, list)
assert len(y) == 3
Expand Down Expand Up @@ -1056,7 +1038,7 @@ def test_partition_2ds_Xwp(X3, w3, p3):
w3_ = [func(w) for w in w3]
p3_ = [func(p) for p in p3]

y, w, p = partition_sequences(X3_, sample_weight=w3_, priors=p3_)
y, w, p = partition_sequences(X3_, sample_weight=w3_, priors=p3_, n_dists=2)

assert isinstance(y, list)
assert len(y) == 3
Expand All @@ -1081,3 +1063,14 @@ def test_partition_2ds_Xwp(X3, w3, p3):
assert isinstance(p_, torch.Tensor)
assert p_.ndim == 3
assert p_.shape == ([1, 3, 2][i], i+1, 2)


@pytest.mark.parametrize(
    "X",
    [
        torch.ones((1, 10, 2)),
        [torch.ones((1, 10, 2)), torch.ones((2, 10, 2))],
    ],
)
@pytest.mark.parametrize("invalid", ["sample_weight", "priors"])
def test_dont_hide_errors_for_priors_and_sample_weight(X, invalid):
    """Check that a bad sample_weight or priors argument raises a ValueError.

    X is either a single 3D tensor or a list of 3D tensors. The argument
    named by `invalid` is set to a (1, 1) array of -1, and the resulting
    error message must mention that argument's name.
    """
    bad_value = numpy.zeros((1, 1)) - 1
    bad_kwargs = {invalid: bad_value}

    with pytest.raises(ValueError) as excinfo:
        partition_sequences(X, n_dists=10, **bad_kwargs)

    message = str(excinfo.value)
    assert invalid in message

0 comments on commit 514bf94

Please sign in to comment.