TransformedDistribution.log_prob return shape incompatible with batch dimension semantic #11

Open
feynmanliang opened this issue Dec 17, 2020 · 1 comment

Comments

@feynmanliang
Collaborator

feynmanliang commented Dec 17, 2020

When I run

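import torch
import torch.distributions as dist  # assumption: `dist` refers to torch.distributions
import simplex.bijectors
import simplex.params

d, params = simplex.bijectors.AffineAutoregressive(
    simplex.params.DenseAutoregressive())(dist.Normal(0, 1))
d.log_prob(torch.zeros((10, 1))).shape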

Expected behavior: for a sample with sample_shape=[10] and a distribution with batch_shape=[] and event_shape=[1], log_prob should return a tensor of shape [10].

Actual behavior: the result has shape [10, 10].

Root cause: TransformedDistribution.log_prob sums a row vector with a column vector, so the addition broadcasts to [10, 10] instead of staying at [10].
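
To illustrate the broadcasting behavior described in the root cause, here is a minimal sketch in plain PyTorch. It does not use simplex; the tensor names and which term carries the trailing event dimension are assumptions made for illustration, not a trace of the actual implementation.

```python
import torch

# Hypothetical stand-ins for the two terms that log_prob adds together.
per_sample = torch.zeros(10)      # shape [10]: already reduced over the event dimension
per_event = torch.zeros(10, 1)    # shape [10, 1]: still carries the event dimension

# Adding them directly broadcasts [10] against [10, 1] and yields [10, 10].
broadcast_sum = per_sample + per_event

# Summing out the event dimension first keeps the result at [10].
reduced_sum = per_sample + per_event.sum(-1)

print(broadcast_sum.shape)  # torch.Size([10, 10])
print(reduced_sum.shape)    # torch.Size([10])
```

Under this assumption, reducing the term that still has the event dimension before the addition would give the expected [10] result.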

@feynmanliang feynmanliang changed the title Bijector.log_prob return shape incompatible with batch dimension semantic TransformedDistribution.log_prob return shape incompatible with batch dimension semantic Dec 17, 2020
@stefanwebb
Owner

Thanks for catching this! I think unit tests for correct output shapes are first on the table :)
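
A shape test along these lines might look like the sketch below. It uses only `torch.distributions`, since the simplex API is not spelled out in this thread; the helper name `check_log_prob_shape` is hypothetical, and the affine transform stands in for whichever bijector is under test.

```python
import torch
import torch.distributions as dist
from torch.distributions.transforms import AffineTransform


def check_log_prob_shape(d, sample_shape):
    """Assert that d.log_prob returns a tensor of shape sample_shape + batch_shape."""
    value = torch.zeros(sample_shape + d.batch_shape + d.event_shape)
    expected = sample_shape + d.batch_shape
    actual = d.log_prob(value).shape
    assert actual == expected, f"expected {tuple(expected)}, got {tuple(actual)}"
    return actual


# Base distribution with batch_shape=[] and event_shape=[1], as in the report.
base = dist.Independent(dist.Normal(torch.zeros(1), torch.ones(1)), 1)
d = dist.TransformedDistribution(base, [AffineTransform(loc=0.0, scale=2.0)])

shape = check_log_prob_shape(d, torch.Size([10]))
print(shape)  # torch.Size([10])
```

The same helper could be pointed at a simplex distribution in place of `d`, assuming it exposes `batch_shape`, `event_shape`, and `log_prob` in the same way.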
