Merge branch 'master' of https://github.com/jmschrei/pomegranate
jmschrei committed May 17, 2023
2 parents 1a1118f + f0b966e commit 0b3978a
Showing 6 changed files with 60 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -12,3 +12,6 @@ A clear and concise description of what the bug is, including what you were expe

**To Reproduce**
Please provide a snippet of code that can reproduce this error. It is much easier for us to track down bugs and fix them if we have an example script that reproduces the failure.

**Response time**
Although I will likely respond during weekdays if I am not on vacation, I am not likely to be able to merge PRs or write code until the weekend.
@@ -210,7 +210,7 @@
"id": "66fb5e21",
"metadata": {},
"source": [
"It looks like it successfully identified a CG island in the middle (the long stretch of 0's) and another shorter one at the end. More importantly, the model wasn't tricked into thinking that every CG or even pair of CGs was an island. It required many C's and G's to be part of a longer stretch to identify that region as an island. Naturally, the balance of the transition and emission probabilities will heavily influence what regions are detected.\n",
"It looks like it successfully identified a CG island in the middle (the long stretch of 1's) and another shorter one at the end. More importantly, the model wasn't tricked into thinking that every CG or even pair of CGs was an island. It required many C's and G's to be part of a longer stretch to identify that region as an island. Naturally, the balance of the transition and emission probabilities will heavily influence what regions are detected.\n",
"\n",
"Let's say, though, that we want to get rid of that CG island prediction at the end because we don't believe that real islands can occur at the end of the sequence. We can take care of this by adding in an explicit end state that only the non-island hidden state can get to. We enforce that the model has to end in the end state, and if only the non-island state gets there, the sequence of hidden states must end in the non-island state. Here's how:"
]
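The code cell that follows in the notebook sits outside this diff's context. For readers following along, the construction described above looks roughly like the sketch below. It assumes pomegranate v1's `DenseHMM` API, where `starts` and `ends` give the probabilities of beginning and ending in each state; the emission and transition values are illustrative, not the tutorial's exact numbers.

```python
import torch

from pomegranate.distributions import Categorical
from pomegranate.hmm import DenseHMM

# Two hidden states over {A, C, G, T}: a background state with uniform
# emissions (state 0) and a CG-island state enriched for C and G (state 1).
background = Categorical([[0.25, 0.25, 0.25, 0.25]])
island = Categorical([[0.10, 0.40, 0.40, 0.10]])

# Each row of `edges` plus the matching entry of `ends` sums to 1. Giving
# the island state an end probability of 0 means only the background state
# can reach the explicit end state, so every decoded path must finish in
# the background state, ruling out island predictions at the very end.
model = DenseHMM([background, island],
    edges=[[0.89, 0.10], [0.10, 0.90]],
    starts=[1.0, 0.0],
    ends=[0.01, 0.00])

# Encode a sequence as integers (A=0, C=1, G=2, T=3), shape (batch, len, 1),
# and decode the hidden state at each position.
X = torch.tensor([[0, 1, 2, 1, 1, 2, 2, 1, 3, 0]]).unsqueeze(-1)
print(model.predict(X))
```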
27 changes: 16 additions & 11 deletions pomegranate/_utils.py
@@ -355,7 +355,7 @@ def _initialize_centroids(X, k, algorithm='first-k', random_state=None):
    return selector.fit_transform(X)


-def partition_sequences(X, sample_weight=None, priors=None):
+def partition_sequences(X, sample_weight=None, priors=None, n_dists=None):
    """Partition a set of sequences into blobs of equal length.

    This function will take in a list of sequences, where each sequence is
@@ -391,48 +391,53 @@ def partition_sequences(X, sample_weight=None, priors=None):
        The input sequence priors for the sequences or None. If None, return
        None. Default is None.
+
+   n_dists: int or None
+       The expected last dimension of the priors tensor. Must be provided if
+       `priors` is provided. Default is None.

    Returns
    -------
    X_: list or tensor
        The partitioned and grouped sequences.
    """

+   if priors is not None and n_dists is None:
+       raise RuntimeError("If priors are provided, n_dists must be provided as well.")
+
    # If a 3D tensor has been passed in, return it
    try:
        X = [_check_parameter(_cast_as_tensor(X), "X", ndim=3)]
+   except:
+       pass
+   else:
        if sample_weight is not None:
            sample_weight = [_check_parameter(_cast_as_tensor(sample_weight),
                "sample_weight", min_value=0.0)]

        if priors is not None:
            priors = [_check_parameter(_cast_as_tensor(priors), "priors",
-               ndim=3, shape=X[0].shape)]
+               ndim=3, shape=(*X[0].shape[:-1], n_dists))]

        return X, sample_weight, priors
-   except:
-       pass

    # Otherwise, cast all elements in the list as a tensor
    X = [_cast_as_tensor(x) for x in X]

    # If a list of 3D tensors has been passed in, return it
    try:
        X = [_check_parameter(x, "X", ndim=3) for x in X]
+   except:
+       pass
+   else:
        if sample_weight is not None:
            sample_weight = [_check_parameter(_cast_as_tensor(w_),
                "sample_weight", min_value=0.0) for w_ in sample_weight]

        if priors is not None:
            priors = [_check_parameter(_cast_as_tensor(p), "priors",
-               ndim=3, shape=X[i].shape) for i, p in enumerate(priors)]
+               ndim=3, shape=(*X[i].shape[:-1], n_dists)) for i, p in enumerate(priors)]

        if all([x.ndim == 3 for x in X]):
            return X, sample_weight, priors
-   except:
-       pass

    # Otherwise, group together same-sized examples
    X_dict = collections.defaultdict(list)
@@ -453,7 +458,7 @@ def partition_sequences(X, sample_weight=None, priors=None):

        if priors is not None:
            p = _check_parameter(_cast_as_tensor(priors[i]), "priors",
-               ndim=2, shape=x.shape)
+               ndim=2, shape=(*x.shape[:-1], n_dists))

            priors_dict[n].append(p)
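To make the new contract concrete, here is a usage sketch (a hedged example, assuming `partition_sequences` is imported from the private `pomegranate._utils` module shown above): priors must match each sequence on every axis except the last, whose size is now `n_dists` rather than the data dimensionality, and `n_dists` is required whenever priors are given.

```python
import torch

from pomegranate._utils import partition_sequences

n_dists = 2

# Two sequences of different lengths, each of shape (length, d) with d=3.
X = [torch.randn(5, 3), torch.randn(7, 3)]

# Priors carry one value per distribution, so their trailing dimension is
# n_dists, not d: shapes (5, n_dists) and (7, n_dists).
priors = [torch.ones(5, n_dists) / n_dists,
          torch.ones(7, n_dists) / n_dists]

# Omitting n_dists here would raise the RuntimeError added above; with it,
# the sequences are grouped by length and stacked into 3D tensors.
X_, sample_weight, priors_ = partition_sequences(X, priors=priors,
    n_dists=n_dists)
```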

4 changes: 2 additions & 2 deletions pomegranate/hmm/_base.py
@@ -565,7 +565,7 @@ def fit(self, X, sample_weight=None, priors=None):
            Prior probabilities of assigning each symbol to each node. If not
            provided, do not include in the calculations (conceptually
            equivalent to a uniform probability, but without scaling the
-           probabilities). This can be used to assign labels to observatons
+           probabilities). This can be used to assign labels to observations
            by setting one of the probabilities for an observation to 1.0.
            Note that this can be used to assign hard labels, but does not
            have the same semantics for soft labels, in that it only
@@ -580,7 +580,7 @@
"""

X, sample_weight, priors = partition_sequences(X,
sample_weight=sample_weight, priors=priors)
sample_weight=sample_weight, priors=priors, n_dists=self.k)

# Initialize by concatenating across sequences
if not self._initialized:
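As a small illustration of the priors semantics in that docstring (a sketch, not code from this commit): a uniform priors tensor is conceptually a no-op, while a one-hot row pins an observation to a state. The trailing dimension must equal the number of model distributions, which is exactly what passing `n_dists=self.k` above enforces.

```python
import torch

# Hypothetical sizes: n sequences of equal length, k model states.
n, length, k = 10, 25, 3

# Uniform priors: conceptually equivalent to passing no priors at all.
priors = torch.ones(n, length, k) / k

# Hard label: force observation 4 of sequence 0 into state 2.
priors[0, 4] = torch.tensor([0.0, 0.0, 1.0])

# model.fit(X, priors=priors)  # with X of shape (n, length, d)
```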
19 changes: 19 additions & 0 deletions tests/hmm/test_dense_hmm.py
@@ -1626,3 +1626,22 @@ def test_masked_fit(X, X_masked):
    assert_array_almost_equal(d2.scales, [2.8777, 2.3498, 2.1939], 4)
    assert_array_almost_equal(d2._w_sum, [0., 0., 0.])
    assert_array_almost_equal(d2._xw_sum, [0., 0., 0.])


@pytest.mark.parametrize("n_states", [2, 3, 4])
def test_priors_in_fit_valid(n_states):
funcs = (lambda x: x, tuple, numpy.array,
lambda x: torch.from_numpy(numpy.array(x)))

# Data here is a 3D array of lists
_data = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]]

for func in funcs:
model = DenseHMM(distributions=[Exponential() for _ in range(n_states)])
data = func(_data)

priors = torch.ones((*numpy.array(data).shape[:-1], n_states)) / n_states

model.fit(data, priors=priors)
45 changes: 19 additions & 26 deletions tests/test_utils.py
@@ -641,7 +641,7 @@ def test_partition_3d_Xp(X1, p1):
        lambda x: torch.from_numpy(numpy.array(x)))

    for func in funcs:
-       y, w, p = partition_sequences(func(X1), priors=func(p1))
+       y, w, p = partition_sequences(func(X1), priors=func(p1), n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 1
@@ -667,8 +667,8 @@ def test_partition_3d_Xwp(X1, w1, p1):
        lambda x: torch.from_numpy(numpy.array(x)))

    for func in funcs:
-       y, w, p = partition_sequences(func(X1), sample_weight=func(w1),
-           priors=func(p1))
+       y, w, p = partition_sequences(func(X1), sample_weight=func(w1),
+           priors=func(p1), n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 1
@@ -753,25 +753,6 @@ def test_partition_3ds_X(X2):
            assert_array_almost_equal(y_, X2[i])


-def test_partition_3ds_X(X2):
-   funcs = (lambda x: x, tuple, numpy.array,
-       lambda x: torch.from_numpy(numpy.array(x)))
-
-   for func in funcs:
-       X2_ = [func(x) for x in X2]
-
-       y, _, _ = partition_sequences(X2_)
-
-       assert isinstance(y, list)
-       assert len(y) == 2
-
-       for i, y_ in enumerate(y):
-           assert isinstance(y_, torch.Tensor)
-           assert y_.ndim == 3
-           assert y_.shape == (2, i+2, 2)
-           assert_array_almost_equal(y_, X2[i])


def test_partition_3ds_Xw(X2, w2):
    funcs = (lambda x: x, tuple, numpy.array,
        lambda x: torch.from_numpy(numpy.array(x)))
@@ -811,7 +792,7 @@ def test_partition_3ds_Xp(X2, p2):
        X2_ = [func(x) for x in X2]
        p2_ = [func(p) for p in p2]

-       y, w, p = partition_sequences(X2_, priors=p2_)
+       y, w, p = partition_sequences(X2_, priors=p2_, n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 2
@@ -843,7 +824,8 @@ def test_partition_3ds_Xwp(X2, w2, p2):
        w2_ = [func(w) for w in w2]
        p2_ = [func(p) for p in p2]

-       y, w, p = partition_sequences(X2_, sample_weight=w2_, priors=p2_)
+       y, w, p = partition_sequences(X2_,
+           sample_weight=w2_, priors=p2_, n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 2
@@ -1025,7 +1007,7 @@ def test_partition_2ds_Xp(X3, p3):
        X3_ = [func(x) for x in X3]
        p3_ = [func(p) for p in p3]

-       y, w, p = partition_sequences(X3_, priors=p3_)
+       y, w, p = partition_sequences(X3_, priors=p3_, n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 3
@@ -1056,7 +1038,7 @@ def test_partition_2ds_Xwp(X3, w3, p3):
        w3_ = [func(w) for w in w3]
        p3_ = [func(p) for p in p3]

-       y, w, p = partition_sequences(X3_, sample_weight=w3_, priors=p3_)
+       y, w, p = partition_sequences(X3_, sample_weight=w3_, priors=p3_, n_dists=2)

        assert isinstance(y, list)
        assert len(y) == 3
@@ -1081,3 +1063,14 @@ def test_partition_2ds_Xwp(X3, w3, p3):
            assert isinstance(p_, torch.Tensor)
            assert p_.ndim == 3
            assert p_.shape == ([1, 3, 2][i], i+1, 2)


@pytest.mark.parametrize("X", [torch.ones((1, 10, 2)), [torch.ones((1, 10, 2)), torch.ones((2, 10, 2))]])
@pytest.mark.parametrize("invalid", ["sample_weight", "priors"])
def test_dont_hide_errors_for_priors_and_sample_weight(X, invalid):
"""Test that we get the correct error message when we don't pass data in case 3."""

with pytest.raises(ValueError) as excinfo:
partition_sequences(X, **{invalid: numpy.zeros((1, 1)) - 1}, n_dists=10)

assert invalid in str(excinfo.value)
