
[BUG] GMM (and normal distribution) fitting doesn't respect frozen parameters #1054

NicholasClark opened this issue Jul 27, 2023 · 2 comments

@NicholasClark

I am trying to fit a mixture model of two normal distributions where I freeze the means at 4 and 1.5 and only fit the variances.
When I use GeneralMixtureModel, it changes the means anyway when fitting to the data (the fitted means come out as 3.98 and 1.47).
I see the same issue if I fit just a single Normal distribution with a frozen mean.
I may be doing something wrong, but I've tried it a number of different ways at this point.

Any help would be highly appreciated!

Here is code to reproduce the issue:

import seaborn
import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import *
import numpy as np
import matplotlib.pyplot as plt

### Generate data for mixture model
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])
XX = np.array(X).reshape(-1,1)
XX = torch.tensor(XX).float()
### Fit mixture model and freeze the mean of each distribution
m1 = torch.tensor([4]) ### mean = 4
m2 = torch.tensor([1.5]) ### mean = 1.5
m1.frozen=True
m2.frozen=True
d1 = Normal(means=m1)
d2 = Normal(means=m2)
model = GeneralMixtureModel([d1, d2], verbose=False).fit(XX)
### plot results
x = np.arange(np.min(X), np.max(X), 0.1)
y1 = model.distributions[0].probability(x.reshape(-1, 1))
y2 = model.distributions[1].probability(x.reshape(-1, 1))
y3 = model.probability(x.reshape(-1, 1))
plt.figure(figsize=(6, 3))
plt.hist(X, density=True, bins=30)
plt.plot(x, y1, color = "green", label="Normal1")
plt.plot(x, y2, color = "red", label="Normal2")
plt.plot(x, y3, color = "purple", label="Mixture")
plt.legend(loc=(1.05, 0.4))
plt.tight_layout()
print("mean of Normal1: " + str(round(model.distributions[0].means.item(), 2)))
print("mean of Normal2: " + str(round(model.distributions[1].means.item(), 2)))

[Attached plot: histogram_mixture_means — histogram of the data with the fitted Normal1, Normal2, and mixture densities overlaid]

@ShaolinXU

I think the finest level of control is through the distribution that you define.

I managed to freeze the means by modifying Normal.py from the source code as follows:

remove the _update_parameter call from def from_summaries(self)

Please point it out if I did this wrong.
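
If you would rather not edit the installed package, something along these lines should have a similar effect (untested sketch, and names like from_summaries_keep_mean are mine): wrap from_summaries on the instance so the mean is put back after every update, and only the variance actually changes.

import torch
from pomegranate.distributions import Normal

d = Normal(means=torch.tensor([4.0]), covs=torch.tensor([1.0]), covariance_type='diag')
fixed_mean = d.means.clone()

_orig_from_summaries = d.from_summaries
def from_summaries_keep_mean():
    _orig_from_summaries()              ### run the usual update for means and covs
    with torch.no_grad():
        d.means.copy_(fixed_mean)       ### then restore the frozen mean
d.from_summaries = from_summaries_keep_mean
### d can then be passed to GeneralMixtureModel([d, ...]).fit(XX) as before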

@jmschrei
Owner

Hi @NicholasClark.

Sorry for the late reply. You are correct that you can freeze individual parameters but you have to do it in a specific way to get it to stick.

First, you added the frozen attribute to the underlying tensor, but when that tensor is taken into the Normal object it gets wrapped into a torch.nn.Parameter, so the frozen attribute is still attached to your original tensor but not to d.means (which is a new parameter object). You can solve this by adding frozen to the parameters you want frozen after creating the distribution object.
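
To see why the attribute gets dropped, here is a small standalone illustration (plain PyTorch, nothing pomegranate-specific): wrapping a tensor in torch.nn.Parameter creates a new tensor object, so ad-hoc Python attributes set on the original tensor do not carry over to the wrapped parameter.

import torch

t = torch.tensor([4.0])
t.frozen = True                     ### custom attribute set on the raw tensor
p = torch.nn.Parameter(t)           ### wrapping creates a new tensor object

print(hasattr(t, "frozen"))         ### True
print(hasattr(p, "frozen"))         ### False -- the attribute did not come along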

Second, pomegranate does not allow you to incompletely specify distributions as starting points (this should probably raise a warning when it happens). Because only the means were given, the distribution did not register as being initialized and so was overwritten in the first step of fitting the GMM. You can get around this by also putting some value into covs to completely specify it.

This code works for me and keeps the means frozen. I took out the plotting stuff just because it wasn't relevant for me.

import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import *
import numpy as np

### Generate data for mixture model
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])
XX = np.array(X).reshape(-1,1)
XX = torch.tensor(XX).float()
### Fit mixture model and freeze the mean of each distribution
m1 = torch.tensor([4]) ### mean = 4
m2 = torch.tensor([1.5]) ### mean = 1.5

d1 = Normal(means=m1, covs=[1], covariance_type='diag')
d2 = Normal(means=m2, covs=[1], covariance_type='diag')

d1.means.frozen = True   ### set frozen on the wrapped Parameter, after the Normal is created
d2.means.frozen = True

model = GeneralMixtureModel([d1, d2], verbose=True).fit(XX)

print(model.distributions[0].means.frozen)
print("mean of Normal1: " + str(round(model.distributions[0].means.item(), 2)))
print("mean of Normal2: " + str(round(model.distributions[1].means.item(), 2)))

When I run this it gives me:

[1] Improvement: 35.6845703125, Time: 0.0005453s
[2] Improvement: 1.496337890625, Time: 0.0005322s
[3] Improvement: 0.07568359375, Time: 0.0005288s
True
mean of Normal1: 4
mean of Normal2: 1.5
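
As an extra sanity check (not part of the run above), you can also print the covariances to confirm the variances really were re-estimated from the data rather than staying at the placeholder value of 1; with this seed they should land somewhere near the true variance of 0.25.

### optional check: variances should have moved away from the initial value of 1
print("covs of Normal1: " + str(round(model.distributions[0].covs.item(), 2)))
print("covs of Normal2: " + str(round(model.distributions[1].covs.item(), 2)))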
