arviz.from_numpyro seems to ignore log_likelihood argument #2196

Open
lumip opened this issue Jan 21, 2023 · 2 comments

@lumip

lumip commented Jan 21, 2023

Describe the bug
The arviz.from_numpyro function accepts a log_likelihood argument, which seems to let the user provide a dictionary(?) of log-likelihood values for the posterior samples. However, the argument appears to be ignored entirely: the log-likelihoods are computed from the provided model instead. (I may be using it incorrectly, since the argument is undocumented; if so, please point that out to me.)

To Reproduce
Consider the following code snippet:

import numpyro.distributions as dists
import numpyro
import jax
import jax.numpy as jnp
import numpy as np
import arviz as az

def model(xs, ys):
    mu = numpyro.sample("mu", dists.Normal(0., 1.))
    with numpyro.plate("xs_obs", len(xs) if xs is not None else 1):
        xs_dist = dists.Normal(mu, 1.)
        xs = numpyro.sample("xs", xs_dist, obs=xs)
    with numpyro.plate("ys_obs", len(ys) if ys is not None else 1):
        ys_dist = dists.Normal(mu, 4.)
        ys = numpyro.sample("ys", ys_dist, obs=ys)

    ll = jnp.concatenate([xs_dist.log_prob(xs), ys_dist.log_prob(ys)])
    numpyro.deterministic("ll", ll)

xs = np.random.randn(1000) + 5.
ys = np.random.randn(1000) * 4. + 5.

nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_chains=4, num_warmup=100, num_samples=2000)
mcmc.run(jax.random.PRNGKey(782385), xs, ys)

samples = mcmc.get_samples(group_by_chain=True)
ll = { 'xs_ys': samples['ll'] }

idata = az.from_numpyro(mcmc, log_likelihood=ll)
print(idata['log_likelihood'])
assert list(idata['log_likelihood'].data_vars) == ["xs_ys"]

This prints

<xarray.Dataset>
Dimensions:   (chain: 4, draw: 2000, xs_dim_0: 1000, ys_dim_0: 1000)
Coordinates:
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 ... 1993 1994 1995 1996 1997 1998 1999
  * xs_dim_0  (xs_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * ys_dim_0  (ys_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    xs        (chain, draw, xs_dim_0) float32 -2.074 -1.25 ... -1.283 -1.435
    ys        (chain, draw, ys_dim_0) float32 -4.216 -2.825 ... -5.277 -2.749
Attributes:
    created_at:                 2023-01-21T14:29:16.129290
    arviz_version:              0.14.0
    inference_library:          numpyro
    inference_library_version:  0.10.1

Expected behavior
idata.log_likelihood should have a single data variable xs_ys with the contents of samples['ll'], i.e., arviz.from_numpyro should use the values provided via the log_likelihood argument instead of recomputing the log-likelihoods from the model.

Additional context

arviz==0.14.0
numpyro==0.10.1
jax==0.4.1
jaxlib==0.4.1
@OriolAbril
Member

The log_likelihood parameter can only be a boolean (True or False, or None if you want the value in rcParams to be used). If you want custom values in the log_likelihood group that differ from what the model defines, you should either modify the group after calling from_numpyro, or pass log_likelihood=False and then use InferenceData.add_groups to create the group later on (as done for example in https://python.arviz.org/en/stable/user_guide/numpyro_refitting_xr_lik.html). Hope it helps.
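
A minimal sketch of the second option, for illustration only (it is not taken from the linked guide). To avoid depending on a sampler run, az.from_dict with synthetic draws stands in for az.from_numpyro(mcmc, log_likelihood=False); the helper normal_logpdf, the array sizes, and the variable name xs_ys are assumptions made for this example, and the calls follow the ArviZ 0.x InferenceData API used in this issue.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 4, 50, 20

# Stand-in for: idata = az.from_numpyro(mcmc, log_likelihood=False)
# Synthetic posterior draws of mu with shape (chain, draw).
mu = rng.normal(5.0, 0.1, size=(n_chains, n_draws))
idata = az.from_dict(posterior={"mu": mu})

# Synthetic observations, mirroring the model in the report.
xs = rng.normal(5.0, 1.0, size=n_obs)
ys = rng.normal(5.0, 4.0, size=n_obs)


def normal_logpdf(x, loc, scale):
    """Pointwise log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


# Broadcast mu against the observations: (chain, draw, n_obs) each.
ll_xs = normal_logpdf(xs, mu[..., None], 1.0)
ll_ys = normal_logpdf(ys, mu[..., None], 4.0)

# One combined pointwise log-likelihood, shape (chain, draw, 2 * n_obs).
ll = np.concatenate([ll_xs, ll_ys], axis=-1)

# Create the group after the fact; arrays are read as (chain, draw, *shape).
idata.add_groups({"log_likelihood": {"xs_ys": ll}})

print(list(idata.log_likelihood.data_vars))
```

With a real run, only the from_dict line changes; the custom values could equally come from a numpyro.deterministic site, as samples['ll'] does in the report above.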

The docstring needs updating though to include the log_likelihood parameter, its type info and description. Do you want to create a PR for this?

@lumip
Author

lumip commented Jan 23, 2023

Alright, thanks for pointing that out. I will work around it like you suggested with the add_groups after creating the InferenceData object for now. I'm not sure I have the time currently to work on PRs, but I can see if I get around to it some time soon.
