Adds batch/event shape test and fix failure #29
base: master
Conversation
```diff
@@ -11,7 +11,7 @@
 class AffineAutoregressive(flowtorch.Bijector):
-    event_dim = 1
+    event_dim = 0
```
The `batch_dim` erroneously gets reinterpreted into an `event_dim` when this is 1.
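To make the failure mode concrete, here is a minimal sketch of the usual batch/event shape bookkeeping; `split_shape` is a hypothetical helper, not flowtorch's actual API:

```python
def split_shape(shape, event_dim):
    # Split a value's shape into (batch_shape, event_shape):
    # the rightmost `event_dim` axes are treated as the event.
    if event_dim == 0:
        return tuple(shape), ()
    return tuple(shape[:-event_dim]), tuple(shape[-event_dim:])

# With event_dim=1 the last axis is absorbed into the event, so a
# trailing batch axis of size 32 is misread as a 32-dimensional event:
print(split_shape((10, 32), 1))  # ((10,), (32,))
# With event_dim=0 every axis remains a batch axis:
print(split_shape((10, 32), 0))  # ((10, 32), ())
```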
Ah, shouldn't this be by design? `AffineAutoregressive` makes the final dimension of a random variable dependent, so it will introduce correlations if the base distribution's final dimension is a batch dim.
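A toy illustration of that coupling (not flowtorch's actual `AffineAutoregressive`, just a hypothetical scalar recurrence): each output element depends on the ones before it along the final axis, so distinct elements of that axis become correlated.

```python
def affine_autoregressive(x, weight=0.5):
    # Toy autoregressive map: y[i] depends on y[j] for all j < i.
    # If the final axis were a batch axis, this would couple
    # samples that are supposed to be independent.
    y = []
    prev = 0.0
    for xi in x:
        yi = xi + weight * prev
        y.append(yi)
        prev = yi
    return y

print(affine_autoregressive([1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```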
Should `AffineAutoregressive` (or any `Bijector`) operate across the batch dimension? Across batch dimensions, samples are independent (https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Distribution#shapes_2), so I'm inclined to say a completely separate bijector (which cannot use autoregressive predictors from the previous batch instance) should be instantiated, for a total of `prod(batch_shape)` bijectors. This is similar to `sample_dim`, except there we re-use the same `Bijector` because the samples are IID, whereas here samples across `batch_dim` are not identically distributed.
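A sketch of that proposal, assuming a hypothetical `make_bijector` factory (not part of this PR): one independently parameterized bijector per batch element.

```python
import math

def make_batch_bijectors(batch_shape, make_bijector):
    # One separate bijector per batch element, since samples across
    # batch_dim are not identically distributed (unlike sample_dim,
    # where a single bijector is shared over IID draws).
    return [make_bijector() for _ in range(math.prod(batch_shape))]

# E.g. batch_shape (4, 3) would require prod((4, 3)) = 12 bijectors:
bijectors = make_batch_bijectors((4, 3), make_bijector=object)
print(len(bijectors))  # 12
```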
NOTE: it may be incorrect to remove `event_dim=1` from the bijector; @stefanwebb to confirm.