Jax bug in linear_modes #27

Open
maho3 opened this issue Feb 22, 2024 · 2 comments

@maho3

maho3 commented Feb 22, 2024

On the master branch, this line causes the following error:

  File "/home/mattho/git/ltu-cmass/cmass/nbody/pmwd.py", line 74, in run_density
    ic = linear_modes(wn, pmcosmo, pmconf)
ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (4,)

It seems @eelregit reported this as a JAX bug, but it hasn't been resolved yet.

Could a temporary workaround be pushed to master? This is currently breaking the master branch.

@eelregit
Owner

Hi Matt,

Thanks for pointing out that JAX issue, and sorry for the delay.
Have you tried the positional-argument workaround suggested there?
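The `ValueError` comes from a range check in `jax.checkpoint`: it compares `static_argnums` against the number of positional arguments actually passed, so a static parameter left at its default falls outside `len(args)`. Below is a minimal sketch that reproduces this with a hypothetical toy function `scale_modes` (not part of pmwd) and shows the positional-argument workaround. That `linear_modes` is wrapped in `jax.checkpoint` with a defaulted static fifth parameter is an assumption inferred from the `(4,)` in the traceback. For the call in cmass, the workaround would mean passing the fourth and fifth arguments explicitly instead of relying on their defaults; check the pmwd signature for their actual values. Behavior may also differ across JAX versions.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a function like `linear_modes`: the last
# parameter is a Python bool with a default, marked static because it
# drives Python control flow inside the function.
@partial(jax.checkpoint, static_argnums=(2,))
def scale_modes(x, factor, real=False):
    y = x * factor
    if real:  # Python branch: only legal because `real` is static
        return jnp.real(y)
    return y


x = jnp.arange(3.0) + 1j

# Workaround: pass the defaulted static argument positionally, so that
# index 2 is within len(args) when jax.checkpoint checks static_argnums.
out = scale_modes(x, 2.0, False)

# Relying on the default leaves only 2 positional args; the JAX versions
# affected by this bug reject static_argnums=(2,) with a ValueError.
try:
    scale_modes(x, 2.0)
except ValueError as err:
    print("jax.checkpoint rejected the call:", err)
```

Since the workaround only touches the call site, it can be reverted once the check in `jax.checkpoint` accounts for default arguments.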

@maho3
Author

maho3 commented Feb 26, 2024

I ended up implementing the same hack used in your sto branch, and it works for me.

If you'd like, I can open a small PR to get this onto master.
