When I run the example on flowtorch.ai with a 1D distribution:
```python
import torch
import torch.distributions as dist
import pandas as pd
import seaborn as sns

import flowtorch
import flowtorch.bijectors as bijectors

# Lazily instantiated flow plus base and target distributions
flow = bijectors.AffineAutoregressive(
    flowtorch.params.DenseAutoregressive()
)
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
target_dist = dist.Normal(torch.zeros(1) + 5, torch.ones(1))

# Instantiate transformed distribution and parameters
new_dist, params = flow(base_dist)

# Training loop
opt = torch.optim.Adam(params.parameters(), lr=1e-3)
for idx in range(501):
    opt.zero_grad()
    # Minimize KL(p || q)
    y = target_dist.sample((1000,))
    loss = -new_dist.log_prob(y).mean()
    if idx % 100 == 0:
        print('epoch', idx, 'loss', loss.item())
    loss.backward()
    opt.step()

# Samples are 1-D, so plot a histogram rather than a 2-D scatter
sns.displot(pd.DataFrame(new_dist.sample((100,)).detach().numpy()), x=0)
```
The loss goes to NaN unless the learning rate is set extremely low (`1e-15` gives sensible results).
Removing the call to `self._init_weights` in `DenseAutoregressive` resolves the issue and allows a more reasonable `1e-3` learning rate.
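For context, the failure mode can be reproduced in plain PyTorch without flowtorch. This is only a sketch of the mechanism, not flowtorch's actual initializer: `fit_gaussian` and `log_scale_init` are hypothetical names, and the extreme `log_scale_init` stands in for whatever scale `_init_weights` produces. An initialization at an extreme scale makes the negative log-likelihood overflow on the first step, after which Adam propagates NaNs into the parameters:

```python
import math
import torch

def fit_gaussian(log_scale_init, lr=1e-3, steps=200):
    # Fit Normal(loc, exp(log_scale)) to data ~ N(5, 1) by maximizing
    # log-likelihood, mirroring the KL-fitting loop in the report.
    torch.manual_seed(0)
    y = torch.randn(1000, 1) + 5.0
    loc = torch.zeros(1, requires_grad=True)
    log_scale = torch.full((1,), log_scale_init, requires_grad=True)
    opt = torch.optim.Adam([loc, log_scale], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        # Normal log-density written out by hand (no argument validation)
        var = (2 * log_scale).exp()
        loss = 0.5 * ((y - loc) ** 2 / var + 2 * log_scale
                      + math.log(2 * math.pi)).mean()
        loss.backward()
        opt.step()
    return loss.item()

print(fit_gaussian(0.0))    # sensible init: loss stays finite at lr=1e-3
print(fit_gaussian(-45.0))  # extreme init: loss overflows to inf, then NaN
```

With `log_scale_init=-45.0`, `exp(-90)` underflows to (near) zero in float32, so the squared-error term overflows to `inf`; the resulting infinite gradients turn the parameters to NaN after one Adam step, regardless of the learning rate. Lowering the learning rate doesn't fix an initializer that starts the parameters in such a region, which matches the behavior reported above.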
feynmanliang changed the title from "Extremely low learning rates required for 1D target distributions" to "DenseAutoregresive.init_weights causes unstable learning for 1D target distributions" on Feb 18, 2021.