Remat uses inefficient custom forward kernel #21303
Comments
Thanks for raising this! Something not represented in your minimal repro: because it doesn't actually use custom primitives/kernels, and instead calls regular JAX functions, the cosine will actually get DCE'd (at the JAX level, not the XLA level):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  jax.debug.print('f')
  return jnp.sin(x)

def f_fwd(x):
  jax.debug.print('f_fwd')
  return jnp.sin(x), (x,)

def f_bwd(res, g):
  jax.debug.print('f_bwd')
  x, = res
  cos_x = jnp.cos(x)  # do the cosine here
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  out = jax.remat(f)(x)
  out = out ** 2
  return out

print(jax.make_jaxpr(jax.grad(temp))(3.2))  # CHANGED
```
There aren't two cosines! In the custom kernel case, that DCE won't happen, at least not by itself. But a simple fix here might just be to teach JAX's DCE about your custom primitives. In your real example, do you have custom primitives like this, e.g. one for the primal computation and a separate one for the forward pass that also produces residuals? Basically, to see if some fast solutions are feasible, we should make your example program model this aspect of your real program. The DCE approach is basically a way to clean up the mess after we've made it. We also have an idea to avoid making the mess in the first place, but it requires some bigger JAX autodiff changes. So I'd like to see if a quick DCE solution will get you un-stuck. WDYT?
Here's an example of what I mean:

```python
import jax
import jax.numpy as jnp
from jax._src import core

# Separate primitives for the primal, forward, and backward computations,
# standing in for custom kernels.
f_p = core.Primitive('f')
f_fwd_p = core.Primitive('f_fwd')
f_bwd_p = core.Primitive('f_bwd')
f_fwd_p.multiple_results = f_bwd_p.multiple_results = True

@f_p.def_abstract_eval
def f(x): return x

@f_fwd_p.def_abstract_eval
def f_fwd(x):
  return x, x

@f_bwd_p.def_abstract_eval
def f_bwd(x, g):
  return g,

@jax.custom_vjp
def f(x):
  return f_p.bind(x)

def f_fwd(x):
  return f_fwd_p.bind(x)

def f_bwd(res, g):
  return (*f_bwd_p.bind(res, g),)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  out = jax.remat(f)(x)
  out = out ** 2
  return out

jaxpr = jax.make_jaxpr(jax.grad(temp))(3.2)
print(jaxpr)
```
Notice we're calling `f_fwd` twice in the resulting jaxpr, even though the residual output of the first call is unused. But here's how we can set up a DCE rule:

```python
from jax._src.interpreters import partial_eval as pe

def f_fwd_dce(used_outs: list[bool], eqn):
  _, res_used = used_outs
  if res_used:
    return [True], eqn
  else:
    # Residuals are unused: rewrite the call to the cheaper primal-only primitive.
    new_eqn = eqn.replace(outvars=eqn.outvars[:1], primitive=f_p)
    return [True], new_eqn

pe.dce_rules[f_fwd_p] = f_fwd_dce

print(pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))[0])
```
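A minimal sketch, assuming the same `f_p`/`f_fwd_p` setup as above, of a slightly fuller rule: besides rewriting to the primal-only primitive, a DCE rule may return `None` for the equation to drop the call entirely when none of its outputs are needed.

```python
def f_fwd_dce_full(used_outs: list[bool], eqn):
  out_used, res_used = used_outs
  if res_used:
    # Residuals are needed: keep the fused forward kernel as-is.
    return [True], eqn
  elif out_used:
    # Only the primal output is needed: fall back to the cheap primal primitive.
    new_eqn = eqn.replace(outvars=eqn.outvars[:1], primitive=f_p)
    return [True], new_eqn
  else:
    # Nothing is needed: drop the equation and mark the input as unused.
    return [False], None

pe.dce_rules[f_fwd_p] = f_fwd_dce_full
```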
That's interesting, I didn't know that was possible. A concrete example of what we are currently using would be a Pallas kernel equivalent to the following:
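An illustrative sketch, assuming a toy elementwise `sin` kernel (the kernel bodies and names here are placeholders, not the real kernels):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def f_kernel(x_ref, o_ref):
  # Inference kernel: computes only the primal output.
  o_ref[...] = jnp.sin(x_ref[...])

def f_fwd_kernel(x_ref, o_ref, res_ref):
  # Training forward kernel: also materializes residuals for the backward pass.
  x = x_ref[...]
  o_ref[...] = jnp.sin(x)
  res_ref[...] = jnp.cos(x)

@jax.custom_vjp
def f(x):
  return pl.pallas_call(
      f_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

def f_fwd(x):
  out, cos_x = pl.pallas_call(
      f_fwd_kernel,
      out_shape=(jax.ShapeDtypeStruct(x.shape, x.dtype),
                 jax.ShapeDtypeStruct(x.shape, x.dtype)))(x)
  return out, (cos_x,)

def f_bwd(res, g):
  cos_x, = res
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
```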
The inference version of the kernel outputs only the primal result, while the forward version also has to write the residuals out to memory for the backward pass.
It may also be possible to skip the additional primitive-wrapping; we could instead register a DCE rule directly for the kernel-call primitive itself.
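A sketch of that idea, assuming (hypothetically) a single fused forward primitive `fused_fwd_p` that carries a boolean `with_residuals` parameter controlling whether it emits residuals; this is not an existing JAX or Pallas API, just one way the idea could be wired up so the existing primitive plays both roles without a separate `f_p` wrapper:

```python
from jax._src.interpreters import partial_eval as pe

# fused_fwd_p: the (hypothetical) kernel-call primitive bound directly by f_fwd.
def fused_fwd_dce(used_outs: list[bool], eqn):
  out_used, res_used = used_outs
  if res_used:
    return [True], eqn
  if not out_used:
    # Nothing is needed at all: drop the kernel call.
    return [False], None
  # Only the primal output is needed: keep the same primitive, but ask it not
  # to produce residuals, and drop the residual outvar accordingly.
  new_params = dict(eqn.params, with_residuals=False)
  new_eqn = eqn.replace(outvars=eqn.outvars[:1], params=new_params)
  return [True], new_eqn

pe.dce_rules[fused_fwd_p] = fused_fwd_dce
```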
Description
When defining custom kernels, there are three distinct kernels for a `jax.custom_vjp`: `f`, `f_fwd`, and `f_bwd`. When inside a `jax.vjp` and `jax.remat`, all three kernels should be called: first `f`, then `f_fwd`, then `f_bwd`. Instead, `f_fwd` is called twice and `f_bwd` is called once (see the prints from the sketch below).
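A minimal sketch of this kind of repro, using plain JAX functions with debug prints standing in for the custom kernels (names follow the description above):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  jax.debug.print('f')      # primal-only kernel
  return jnp.sin(x)

def f_fwd(x):
  jax.debug.print('f_fwd')  # forward kernel that also produces residuals
  return jnp.sin(x), (jnp.cos(x),)

def f_bwd(res, g):
  jax.debug.print('f_bwd')  # backward kernel
  cos_x, = res
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  return jax.remat(f)(x) ** 2

# As reported: this prints 'f_fwd' twice and 'f_bwd' once,
# rather than 'f', then 'f_fwd', then 'f_bwd'.
jax.grad(temp)(3.2)
```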
This is not a problem when `f`, `f_fwd` are defined using standard JAX functions, due to XLA DCE. With custom kernels, `f_fwd` will indeed be called twice, which is generally slower than `f` as it must move residuals/intermediates to GPU DRAM, which can be very slow.

This impacts all JAX custom kernel libraries.
@mattjj, @chr1sj0nes
System info (python version, jaxlib version, accelerator, etc.)