
[QUESTION] I would like to fix the transition matrix upon running. Can I have some help pointing to what update steps need to be commented? #1092

Open
gajarajv opened this issue Apr 9, 2024 · 13 comments

Comments

@gajarajv

gajarajv commented Apr 9, 2024


@gajarajv
Author

gajarajv commented Apr 9, 2024

Specifically, when running Baum-Welch in the fit method, I'd like to iterate over the samples without updating the parameters of the transition edges.

@jmschrei
Owner

If you pass in inertia=1.0 to your HMM you should be able to freeze the transition matrix while keeping the emissions unfrozen.
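As a rough illustration of why inertia=1.0 would freeze a parameter, here is a minimal sketch that assumes inertia acts as a blending weight between the current value and the newly estimated one. The `blend` helper and the example rows are hypothetical and are not pomegranate's API.

```python
def blend(old, new, inertia):
    """Mix each old value with its new estimate; inertia is the weight kept on the old value."""
    return [inertia * o + (1.0 - inertia) * n for o, n in zip(old, new)]


old_row = [0.7, 0.2, 0.1]  # current transition probabilities out of one state
new_row = [0.1, 0.1, 0.8]  # hypothetical estimate from one Baum-Welch step

frozen = blend(old_row, new_row, inertia=1.0)  # old values are kept unchanged
free = blend(old_row, new_row, inertia=0.0)    # new estimate replaces the old values
half = blend(old_row, new_row, inertia=0.5)    # halfway between the two
```

Under that assumption, any inertia strictly between 0 and 1 would slow the update down rather than stop it.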

@gajarajv
Author

gajarajv commented Apr 10, 2024

Thanks. I am trying to populate the SparseHMM object from a sparse (CSR) matrix and am unsure whether I am using the newer functionality correctly. Right now I input each of the values individually in several steps, as follows (where E is a sparse emission matrix with shape (20000, 1000) and T is a sparse transition matrix with shape (20000, 20000)):

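import numpy as np
from pomegranate.distributions import Categorical
from pomegranate.hmm import SparseHMM

# One Categorical emission distribution per state (one row of E each)
dists = [Categorical([row]) for row in E.toarray()]
ends = E.max(axis=1).data
ends /= np.sum(ends)
P0 /= np.sum(P0)

# Nonzero entries of T define the edges
rows, cols = T.nonzero()
values = T.data

# Get corresponding distributions with transition prob
index_value_pairs = [(dists[row], dists[col], T[row, col]) for row, col in zip(rows, cols)]

model = SparseHMM(dists, edges=index_value_pairs, starts=P0, ends=ends, max_iter=1, verbose=True, inertia=1.0)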

I wasn't able to find anything in the docs detailing this scenario, but does this follow your design?

@jmschrei
Owner

I think that ends should be derived from T, not E, right? I'm not sure where the first P0 comes from but, assuming it's right, I think what you have here is right.
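One plausible reading of "derive ends from T", offered only as an assumption and not something the thread confirms, is that each state's end probability is the mass left over after its outgoing transitions. The `ends_from_transitions` helper and the dict-per-row layout below are hypothetical.

```python
def ends_from_transitions(T_rows):
    """T_rows holds one dict {next_state: prob} per state, i.e. one sparse row of T each."""
    ends = []
    for row in T_rows:
        leftover = 1.0 - sum(row.values())
        ends.append(max(leftover, 0.0))  # guard against tiny negative rounding error
    return ends


# Three states: the first two leave some mass unassigned, the third leaves none.
T_rows = [{0: 0.5, 1: 0.4}, {1: 0.75}, {0: 1.0}]
ends = ends_from_transitions(T_rows)
```

If every row of T already sums to one, this gives all-zero end probabilities, which would correspond to a model with no explicit end state.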

@gajarajv
Author

gajarajv commented Apr 10, 2024

Err, right, the P0 is uniform. Now that I'm looking at it more carefully, I am setting it to the stationary distribution of the transition matrix, T.
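For reference, a stationary distribution like the one mentioned here can be approximated by power iteration. This is a minimal sketch, assuming T is row-stochastic and the chain has a unique stationary distribution; the `stationary_distribution` helper and the 2-state matrix are hypothetical.

```python
def stationary_distribution(T, iters=1000, tol=1e-12):
    """Approximate pi with pi = pi @ T by repeated multiplication, starting from uniform."""
    n = len(T)
    pi = [1.0 / n] * n
    for _ in range(iters):
        nxt = [sum(pi[i] * T[i][j] for i in range(n)) for j in range(n)]
        total = sum(nxt)
        nxt = [v / total for v in nxt]  # renormalize to absorb rounding drift
        if max(abs(a - b) for a, b in zip(nxt, pi)) < tol:
            return nxt
        pi = nxt
    return pi


T = [[0.9, 0.1],
     [0.5, 0.5]]
P0 = stationary_distribution(T)  # converges towards [5/6, 1/6] for this matrix
```

For a 20000-state sparse T the same iteration would normally be written with sparse matrix-vector products rather than nested Python loops.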

However, I noticed that in some cases where I do not initialize the ends variable, the EM algorithm breaks down and produces NaN solutions. Any clue as to why setting the ends variable is so important?

I haven't seen BW implementations that rely on the ending states, so I'm curious about the inspiration for this design.

@jmschrei
Owner

I don't think BW here is dependent on the end states but having it can add constraints that make optimization easier. It's hard to diagnose why one might observe the behavior you describe without knowing more about your data or model.

@gajarajv
Author

gajarajv commented Apr 10, 2024

From your understanding of your optimizer, can you point to some potential sources of NaNs in the code? I am trying to solve for an overcomplete HMM's emission matrix. Investigating, it looks like some of the errors might come from the way we initialize the problem. When the numbers of states and possible observations are more reasonable (10 states, 5 observations), the solver seems to work well. However, if we bring the number of states to 1000, with 20 possible observations (under generically created conditions), the solver output looks like this:
image

image
Where the prob values in the solver are all NaN.

For more context, here are the relative sizes of each of the objects:
image

I figure I am dividing by zero somewhere in the code, where one of the forward-backward computations runs into an inf. Is there a way to change the scaling, or something along those lines, to reduce the likelihood of errors like this?

Also, it's not a number-of-samples issue, because I get NaN with both 5 and 10000 unique samples. It's not an init issue for the KMeans algorithm or a tolerance issue either, because I tried all the available combinations for those as well.
Thank you so much for the help!

@gajarajv
Author

Upon closer inspection with a debugger, it appears the problem is the result of integer underflow during the forward backward algorithm when the emission probabilities are also sparse.

@jmschrei
Owner

Can you elaborate a little bit more? Where are you getting integer underflow? The probabilities should all be floats, right?

@gajarajv
Author

gajarajv commented Apr 11, 2024

Err, not integer underflow, my mistake; just general arithmetic underflow or computing an undefined log expression. I'll try to get you a script that reliably reproduces this problem sometime this week, with inputs as pickled files.

But yes, the emission probabilities (whose rows initialize the Categorical distribution class) contain multiple zero values. I suspect this is what causes FB to fail. For example, the t, f and b variables defined between lines 534-542 of sparse_hmm.py (in the function forward_backward) all come back as matrices of NaN values. For context, this is for a model with 1000 states and 35 observations, using randomly generated sparse initial transition and emission probabilities. Looking at the code naively, without having debugged it properly yet, potential causes include an edge case in the torch float32 data type when we have small numbers (line 383 of sparse_hmm.py), or possibly that line 542 of sparse_hmm.py, t = torch.exp(torch.logsumexp(t, dim=1).T - logp).T, produces -inf values when we have zeros across dim 1. However, I've yet to spend the time to verify this and will get back to you. I suspect this is a 'my data, my problem' kind of issue.
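One way zero probabilities can turn into NaN in log space, shown here as a sketch and not as a verified diagnosis of sparse_hmm.py, is a hand-written max-shift log-sum-exp applied to values that are all -inf: the shift computes -inf - (-inf), which is NaN. Both helper functions below are hypothetical.

```python
import math

NEG_INF = float("-inf")


def naive_logsumexp(xs):
    """Max-shifted log-sum-exp with no special case for an all -inf input."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def guarded_logsumexp(xs):
    """Same computation, but log(0) stays -inf when every input is -inf."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


all_zero = [NEG_INF, NEG_INF]        # log of two zero probabilities
mixed = [NEG_INF, math.log(0.25)]    # one zero, one nonzero probability

bad = naive_logsumexp(all_zero)      # NaN, because -inf - (-inf) is undefined
good = guarded_logsumexp(all_zero)   # -inf
same = naive_logsumexp(mixed)        # log(0.25); a single finite entry is enough
```

Once one NaN appears it propagates through every later sum, which would be consistent with entire matrices of NaN values.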

@gajarajv
Author

Totally forgot to update this ticket, but I solved this issue by making the following modifications to the forward and backward methods:

  1. Initialization of f:
    Use torch.finfo(torch.float64).min to initialize f to the minimum finite value representable by float64 rather than -inf to avoid issues in subsequent calculations.

  2. Log-Sum-Exp Calculation:
    Subtract alpha before the exponential operation (torch.exp(p - alpha)) and clip the result of the exponential operation to avoid zeros.

  3. Zero Clamping in z:
    Clamp z before taking the logarithm to ensure that torch.log(z) does not attempt to compute the log of zero.

  4. Squeezing alpha:
    Corrected the handling of alpha to remove unnecessary dimensions before adding it back into the forward variable calculation.

These changes should help stabilize the computation by managing extreme values and preventing -inf and NaN issues in the log-space calculations.

import torch
def forward(self, X=None, emissions=None, priors=None):

    emissions = self._check_inputs(X, emissions, priors)
    n, l, _ = emissions.shape

    f = torch.full((l, n, self.n_distributions), torch.finfo(torch.float64).min, dtype=torch.float64, 
                   device=self.device)
    f[0] = self.starts + emissions[:, 0]

    for i in range(1, l):
        p = f[i-1, :, self._edge_idx_starts]
        p += self._edge_log_probs.expand(n, -1)

        alpha = torch.max(p, dim=1, keepdims=True).values
        p = p - alpha  # Shift by the max so the largest term becomes exp(0) = 1
        p = torch.exp(p).clamp(min=1e-10)  # Floor terms that underflow to exact 0

        z = torch.zeros_like(f[i])
        z.scatter_add_(1, self._edge_idx_ends.expand(n, -1), p)

        z = z.clamp(min=1e-10)  # Prevents log(0)
        f[i] = alpha.squeeze(1) + torch.log(z) + emissions[:, i]  # Corrected alpha handling

    f = f.permute(1, 0, 2)
    return f
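The same recursion step can be traced on a toy model without torch. This is a sketch of the idea only, for a single sequence, with a hypothetical `forward_step` helper and made-up numbers; it mirrors the max-shift, the floor on each exponentiated term, and the floor on z before the log.

```python
import math

FLOOR = 1e-10  # same floor value as in the snippet above


def forward_step(f_prev, edges, emission, floor=FLOOR):
    """One log-space forward step over sparse edges given as (src, dst, log_prob)."""
    p = [f_prev[src] + log_prob for src, _, log_prob in edges]
    alpha = max(p)                      # shift so the largest term is exp(0) = 1
    z = [0.0] * len(f_prev)
    for (_, dst, _), value in zip(edges, p):
        z[dst] += max(math.exp(value - alpha), floor)  # scatter-add with a floor
    # Floor z again so states with no incoming edge do not hit log(0).
    return [alpha + math.log(max(zj, floor)) + e for zj, e in zip(z, emission)]


# Three states; state 2 has no incoming edges and a very negative (finite) start.
f_prev = [math.log(0.6), math.log(0.4), -1e300]
edges = [(0, 0, math.log(0.5)), (0, 1, math.log(0.5)), (1, 1, math.log(1.0))]
emission = [math.log(0.5)] * 3

f_next = forward_step(f_prev, edges, emission)
# exp(f_next[0]) is close to 0.6 * 0.5 * 0.5 = 0.15
# exp(f_next[1]) is close to (0.6 * 0.5 + 0.4 * 1.0) * 0.5 = 0.35
# f_next[2] is finite instead of -inf
```

The trade-off of flooring is that an unreachable state ends up with a tiny nonzero probability rather than exactly zero, which keeps later steps finite at the cost of a small bias.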

@gajarajv
Author

def backward(self, X=None, emissions=None, priors=None):
    emissions = self._check_inputs(X, emissions, priors)
    n, l, _ = emissions.shape

    b = torch.full((l, n, self.n_distributions), torch.finfo(torch.float64).min, dtype=torch.float64,
                   device=self.device)
    b[-1] = self.ends + emissions[:, -1]

    for i in range(l-2, -1, -1):
        p = b[i+1, :, self._edge_idx_ends]
        p += emissions[:, i+1]
        p += self._edge_log_probs.expand(n, -1)

        alpha = torch.max(p, dim=1, keepdims=True).values
        p = p - alpha  # Shift by the max so the largest term becomes exp(0) = 1
        p = torch.exp(p).clamp(min=1e-10)  # Floor terms that underflow to exact 0

        z = torch.zeros_like(b[i])
        z.scatter_add_(1, self._edge_idx_starts.expand(n, -1), p)

        z = z.clamp(min=1e-10)  # Prevents log(0)
        b[i] = alpha.squeeze(1) + torch.log(z)  # Corrected alpha handling

    b = b.permute(1, 0, 2)
    return b

@gajarajv
Author

If you think this is worth a PR, I can set that up.
