
Commit

ADD fixes for categorical and masked
jmschrei committed Aug 26, 2023
1 parent 24c89e2 commit 1ca070f
Showing 4 changed files with 235 additions and 99 deletions.
13 changes: 12 additions & 1 deletion pomegranate/_utils.py
@@ -35,6 +35,14 @@ def dtype(self):
 		return self.buffers[0].dtype
 
 
+def _inplace_add(X, Y):
+	"""Do an in-place addition on X accounting for if Y is a masked tensor."""
+
+	if isinstance(Y, torch.masked.MaskedTensor):
+		X += Y._masked_data
+	else:
+		X += Y
+
 def _cast_as_tensor(value, dtype=None):
 	"""Set the parameter."""
@@ -401,8 +409,10 @@ def partition_sequences(X, sample_weight=None, priors=None, n_dists=None):
 	X_: list or tensor
 		The partitioned and grouped sequences.
 	"""
+
 	if priors is not None and n_dists is None:
 		raise RuntimeError("If priors are provided, n_dists must be provided as well.")
+
 	# If a 3D tensor has been passed in, return it
 	try:
 		X = [_check_parameter(_cast_as_tensor(X), "X", ndim=3)]
@@ -434,7 +444,8 @@ def partition_sequences(X, sample_weight=None, priors=None, n_dists=None):

 	if priors is not None:
 		priors = [_check_parameter(_cast_as_tensor(p), "priors",
-			ndim=3, shape=(*X[i].shape[:-1], n_dists)) for i, p in enumerate(priors)]
+			ndim=3, shape=(*X[i].shape[:-1], n_dists))
+			for i, p in enumerate(priors)]
 
 	if all([x.ndim == 3 for x in X]):
 		return X, sample_weight, priors
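An illustrative call for the `partition_sequences` guard shown above (hypothetical shapes, chosen to satisfy the documented `(*X[i].shape[:-1], n_dists)` requirement):

	import torch
	from pomegranate._utils import partition_sequences

	X = [torch.randn(5, 10, 1), torch.randn(3, 8, 1)]
	priors = [torch.ones(5, 10, 4) / 4, torch.ones(3, 8, 4) / 4]

	# Omitting n_dists while passing priors raises the RuntimeError above;
	# with both given, each prior is validated against its batch of sequences.
	X, sample_weight, priors = partition_sequences(X, priors=priors, n_dists=4)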
21 changes: 17 additions & 4 deletions pomegranate/distributions/categorical.py
@@ -3,6 +3,7 @@

 import torch
 
+from .._utils import _inplace_add
 from .._utils import _cast_as_tensor
 from .._utils import _cast_as_parameter
 from .._utils import _update_parameter
@@ -172,7 +173,12 @@ def log_probability(self, X):

 		logps = torch.zeros(X.shape[0], dtype=self.probs.dtype)
 		for i in range(self.d):
-			logps += self._log_probs[i][X[:, i]]
+			if isinstance(X, torch.masked.MaskedTensor):
+				logp_ = self._log_probs[i][X[:, i]._masked_data]
+				logp_[logp_ == float("-inf")] = 0
+				_inplace_add(logps, logp_)
+			else:
+				logps += self._log_probs[i][X[:, i]]
 
 		return logps
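
A sketch of the masked branch in use (illustrative, not from the commit; assumes the prototype MaskedTensor accepts integer payloads):

	import torch
	from pomegranate.distributions import Categorical

	d = Categorical([[0.25, 0.75], [0.5, 0.5]])

	X = torch.tensor([[1, 0], [1, 1]])
	mask = torch.tensor([[True, True], [True, False]])
	X_masked = torch.masked.MaskedTensor(X, mask)

	# One log probability per row; lookups from the raw payload that hit
	# zero-probability keys become 0 rather than propagating -inf.
	logp = d.log_probability(X_masked)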

@@ -201,16 +207,23 @@ def summarize(self, X, sample_weight=None):

 		X = _cast_as_tensor(X)
 		if not self._initialized:
-			n_keys = self.n_keys if self.n_keys is not None else int(X.max())+1
+			if self.n_keys is not None:
+				n_keys = self.n_keys
+			elif isinstance(X, torch.masked.MaskedTensor):
+				n_keys = int(torch.max(X._masked_data)) + 1
+			else:
+				n_keys = int(torch.max(X)) + 1
+
 			self._initialize(X.shape[1], n_keys)
 
 		X = _check_parameter(X, "X", min_value=0, max_value=self.n_keys-1,
 			ndim=2, shape=(-1, self.d), check_parameter=self.check_data)
 		sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight))
 
-		self._w_sum += torch.sum(sample_weight, dim=0)
+		_inplace_add(self._w_sum, torch.sum(sample_weight, dim=0))
 		for i in range(self.n_keys):
-			self._xw_sum[:, i] += torch.sum((X == i) * sample_weight, dim=0)
+			_inplace_add(self._xw_sum[:, i], torch.sum((X == i) * sample_weight,
+				dim=0))
 
 	def from_summaries(self):
 		"""Update the model parameters given the extracted statistics.
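A fitting sketch under the same assumptions; when the distribution is uninitialized, n_keys is inferred from the raw payload of the masked input:

	import torch
	from pomegranate.distributions import Categorical

	X = torch.tensor([[0, 1], [1, 1], [1, 0]])
	mask = torch.tensor([[True, True], [False, True], [True, True]])

	d = Categorical()
	# n_keys is inferred as int(max payload) + 1 = 2, and the running
	# statistics are accumulated through _inplace_add.
	d.summarize(torch.masked.MaskedTensor(X, mask))
	d.from_summaries()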
1 change: 0 additions & 1 deletion pomegranate/hmm/_base.py
@@ -291,7 +291,6 @@ def _emission_matrix(self, X, priors=None):
 			e[:, i] = logp.reshape(n, k).T
 
 		e = e.permute(2, 0, 1)
-
 		if priors is not None:
 			e += torch.log(priors)

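For context, a toy rendering of the prior application above (hypothetical shapes, following the `(n, length, n_dists)` priors layout documented for `partition_sequences`):

	import torch

	n, length, n_dists = 2, 5, 3
	e = torch.randn(n, length, n_dists)              # log emission densities
	priors = torch.full((n, length, n_dists), 0.5)   # soft per-state priors

	# Scaling emission probabilities by priors is addition in log space.
	e = e + torch.log(priors)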
