Missing JAX implementation of Eigvalsh #651
@ckrapu a PR to implement the JAX dispatch would be very helpful. The only thing about your implementation is that it should fail with a The second argument to the dispatch function is
Also, it may be useful to add a helper function with the same signature as the SciPy one that creates the Op with the `lower` flag and converts a Python
Is this issue still open? I would like to work on this.
I think so, but the best way to answer that question is to create a minimal PyTensor graph that uses `Eigvalsh` and confirm that it fails in JAX mode with a `NotImplementedError`.
@ricardoV94 pytensor/pytensor/compile/function/pfunc.py Lines 660 to 662 in 4d4bd16
Those are two different things. There are many constants in PyTensor graphs; it's just that when compiling a function you can't declare a constant as an input. Constants are always implicit inputs.
Describe the issue:
TL;DR: using CAR with the JAX backend in PyMC fails because an Op has no JAX conversion. Fixing it requires a small tweak to the CAR logp and a new JAX dispatch in PyTensor.
If you try to use `pm.CAR` with `nuts_sampler='numpyro'` (Colab example here), you'll run into an error:

`NotImplementedError: No JAX conversion for the given Op: Eigvalsh{lower=True}`
This error arises because:
- The `jax.numpy` version of `eigvalsh` doesn't have an entry for the linker to use.
- The only place we use `eigvalsh` is in the CAR logp. With an `eigvalsh` op taking only a single argument (matching the `jax.numpy` signature), the logp call mentioned earlier will complain because it supplies 2 arguments instead of 1.

Note that the `jax` version of `eigvalsh` only takes a single array `a`, while PyTensor and SciPy allow two array arguments, `a` and `b`. The logic to toggle between the two signatures already exists when `b` is a None tensor constant, so to get the whole thing to work we just need to tweak the logp call to use the None tensor instead of an identity matrix, i.e. `lam = pt.linalg.eigvalsh(DWD, NoneConst)` instead of `lam = pt.linalg.eigvalsh(DWD, np.eye(DWD.shape[0]))`, and also add the corresponding JAX dispatch in PyTensor. As far as I can tell, the PyTensor addition to `pytensor.link.jax.dispatch.slinalg` should look like this:

Any feedback on this would be greatly appreciated, as I'd like to open a PR for it.
Reproducible code example:
Error message:
PyTensor version information:
Context for the issue:
Blocker for using CAR on the JAX backend.