Test problems with JAX #768

Open
matteobachetti opened this issue Oct 5, 2023 · 3 comments

@matteobachetti
Member

matteobachetti commented Oct 5, 2023

@dhuppenkothen @Gaurav17Joshi

  1. We are getting this new deprecation warning. Better to fix it ASAP, so that we have a stable enough API:

    DeprecationWarning: jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.

  2. New test problem:

ERROR ../../.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/stingray/modeling/tests/test_gpmodeling.py::TestGPResult::test_sample - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[5] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was sample_U at /home/runner/work/stingray/stingray/.tox/py311-test-alldeps-cov/lib/python3.11/site-packages/jaxns/model.py:47 traced for jit.
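For code that touches `linear_util` directly, one common way to follow the deprecation message while still supporting older JAX releases is to try the new location first and fall back to the old one. This is a minimal sketch, not Stingray's or JAX's code: `load_linear_util` is a hypothetical helper, and it only helps where the import is under our control. If the warning comes from a dependency, that dependency would need the equivalent change.

```python
import importlib


def load_linear_util():
    """Hypothetical helper: return JAX's linear_util module, preferring the
    new location named in the deprecation message.

    Tries jax.extend.linear_util first and only falls back to the
    deprecated jax.linear_util on older JAX releases that predate
    jax.extend. Returns None when neither can be imported (for example,
    when JAX is not installed at all).
    """
    for module_name in ("jax.extend.linear_util", "jax.linear_util"):
        try:
            return importlib.import_module(module_name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so a
            # missing submodule (or a missing JAX install) lands here too.
            continue
    return None


# On JAX versions that ship jax.extend, this resolves to the new module,
# so attribute lookups such as lu.transformation avoid the deprecated path.
lu = load_linear_util()
```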

@matteobachetti
Member Author

OK, it's from a dependency. I'm fixing it in #763 by adding an ignore line to setup.cfg.
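The exact line added to setup.cfg in #763 is not shown here. pytest's `filterwarnings` ini option takes entries of the form `action:message:category` and applies them through Python's `warnings` filters, so a minimal stdlib sketch of what such an ignore entry does might look like this. `call_deprecated_api` and `run` are hypothetical names; the first stands in for the dependency code that emits the warning.

```python
import warnings

DEPRECATION_MSG = (
    "jax.linear_util.transformation is deprecated. "
    "Use jax.extend.linear_util.transformation instead."
)


def call_deprecated_api():
    """Hypothetical stand-in for the dependency code that emits the warning."""
    warnings.warn(DEPRECATION_MSG, DeprecationWarning, stacklevel=2)
    return "ok"


def run(ignore_deprecation):
    """Call the deprecated API and return the warnings that get through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning by default
        if ignore_deprecation:
            # Roughly what a (hypothetical) setup.cfg entry such as
            #   [tool:pytest]
            #   filterwarnings =
            #       ignore:jax.linear_util.transformation is deprecated:DeprecationWarning
            # does. Here `message` is a regex matched against the start of the
            # warning text, and filterwarnings() puts this rule ahead of "always".
            warnings.filterwarnings(
                "ignore",
                message=r"jax\.linear_util\.transformation is deprecated",
                category=DeprecationWarning,
            )
        call_deprecated_api()
    return caught


print(len(run(ignore_deprecation=False)))  # 1
print(len(run(ignore_deprecation=True)))   # 0
```

Ignoring the warning only hides it; presumably the dependency still has to switch to `jax.extend.linear_util` before the old path is removed from JAX.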

matteobachetti changed the title from "DeprecationWarning in JAX" to "Test problems with JAX" on Oct 5, 2023

@matteobachetti
Member Author

I'm silencing the second one by marking the gpmodeling tests as xfail. This still needs to be addressed, maybe in #767.

@Gaurav17Joshi
Contributor

I think the second error happens in the nested sampling step (lines 286 to 295) of test_gpmodeling, since the error message says the leak occurs in the function sample_U, which JAXNS uses while sampling.
