Adds failing test case for batched base distribution #18
base: master
Conversation
def test_batched_base_distribution(affine_ar_bijector):
    base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
This causes base_dist.batch_shape == torch.Size([1]), so you get a [100, 1, 1]-shaped sample from base_dist.
@feynmanliang hmm, I get a different result here: a [100, 1]-shaped sample from the transformed distribution. Could you share some code to reproduce?

On the other hand, for base_dist = dist.Normal(0, 1) I get an error, since base_dist.batch_shape == torch.Size([]) and the torch.Size([100])-shaped sample from the base distribution is then fed into the MLP.
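To make the shape disagreement above concrete, here is a minimal sketch (assuming only a standard PyTorch install; it does not touch this repo's flow code) showing how the two base-distribution constructions differ in batch_shape and therefore in sample shape:

```python
import torch
import torch.distributions as dist

# Explicit size-1 parameters give a non-empty batch_shape, so every
# sample carries an extra trailing dimension of size 1.
batched = dist.Normal(torch.zeros(1), torch.ones(1))
print(batched.batch_shape)           # torch.Size([1])
print(batched.sample((100,)).shape)  # torch.Size([100, 1])

# Scalar parameters give an empty batch_shape, so samples are 1-D --
# the torch.Size([100]) input described above.
scalar = dist.Normal(0.0, 1.0)
print(scalar.batch_shape)            # torch.Size([])
print(scalar.sample((100,)).shape)   # torch.Size([100])
```

Whether the flow then produces a [100, 1]- or [100, 1, 1]-shaped sample depends on how the transform handles the base distribution's batch dimension, which is exactly what this test probes.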
After thinking about this a while, I think we have a solution. Just to summarize what we discussed, there will be a
There's been some upstream work in PyTorch/Pyro building on our discussion of the Transform/Bijector interface (thanks @fritzo, @fehiepsi, et al.): pytorch/pytorch#50581. We should integrate these changes before we proceed with shape tests.
In pyro-ppl/pyro#2753 I am updating @stefanwebb's original flow implementations in Pyro; hopefully you will be able to adapt some of those changes here.
Closes #28