Remat uses inefficient custom forward kernel #21303

Open · sbodenstein opened this issue May 20, 2024 · 5 comments
Labels: bug Something isn't working

@sbodenstein
Description

When defining custom kernels, there are three distinct kernels for a jax.custom_vjp: f, f_fwd, f_bwd. When inside a jax.vjp and a jax.remat, all three kernels should be called: first f, then f_fwd, then f_bwd. Instead, f_fwd is called twice and f_bwd is called once:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  jax.debug.print('f')
  return jnp.sin(x)

def f_fwd(x):
  jax.debug.print('f_fwd')
  return jnp.sin(x), (x,)

def f_bwd(res, g):
  jax.debug.print('f_bwd')
  x, = res
  cos_x = jnp.cos(x)  # do the cosine here
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  out = jax.remat(f)(x)
  out = out ** 2
  return out

jax.jit(jax.grad(temp))(3.2)

prints

# f_bwd
# f_fwd
# f_fwd

This is not a problem when f and f_fwd are defined using standard JAX functions, due to XLA DCE. With custom kernels, f_fwd will indeed be called twice, which is generally slower than f since it must move residuals/intermediates to GPU DRAM, and that can be very slow.
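For comparison, here is a minimal check (temp_no_remat is just an illustrative name): without jax.remat, each custom-VJP kernel runs exactly once under grad, so the duplication is specific to remat.

def temp_no_remat(x):
  out = f(x)        # same as temp below, but without jax.remat
  return out ** 2

jax.jit(jax.grad(temp_no_remat))(3.2)
# prints:
# f_fwd
# f_bwd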

This impacts all JAX custom kernel libraries.

@mattjj, @chr1sj0nes

System info (python version, jaxlib version, accelerator, etc.)

import jax; jax.print_environment_info()
jax:    0.4.29
jaxlib: 0.4.29
numpy:  1.26.3
python: 3.11.8
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: ...
sbodenstein added the bug label May 20, 2024
mattjj self-assigned this May 22, 2024
@mattjj (Member) commented May 22, 2024

Thanks for raising this!

Something not represented in your minimal repro: because it doesn't use custom primitives/kernels, and instead calls regular JAX functions, the cosine will actually get DCE'd (at the JAX level, not the XLA level):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  jax.debug.print('f')
  return jnp.sin(x)

def f_fwd(x):
  jax.debug.print('f_fwd')
  return jnp.sin(x), (x,)

def f_bwd(res, g):
  jax.debug.print('f_bwd')
  x, = res
  cos_x = jnp.cos(x)  # do the cosine here
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  out = jax.remat(f)(x)
  out = out ** 2
  return out

print(jax.make_jaxpr(jax.grad(temp))(3.2))  # CHANGED
{ lambda ; a:f32[]. let
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x7f5f56355750>
      effect=Debug
    ]
    b:f32[] = sin a
    _:f32[] = integer_pow[y=2] b
    c:f32[] = integer_pow[y=1] b
    d:f32[] = mul 2.0 c
    e:f32[] = mul 1.0 d
    f:f32[] = remat2[
      differentiated=True
      jaxpr={ lambda ; g:f32[] h:f32[]. let
          debug_callback[
            callback=<function debug_callback.<locals>._flat_callback at 0x7f5f56355750>
            effect=Debug
          ]
          _:f32[] = sin g
          debug_callback[
            callback=<function debug_callback.<locals>._flat_callback at 0x7f5f56355bd0>
            effect=Debug
          ]
          i:f32[] = cos g
          j:f32[] = mul i h
        in (j,) }
      policy=None
      prevent_cse=True
    ] a e
  in (f,) }

There aren't two cosines!

In the custom kernel case, that DCE won't happen. At least not by itself.

But a simple fix here might just be to teach JAX's DCE about your custom primitives. In particular, if you have a custom primitive like f_fwd_p, we could set its DCE rule to check if residuals are unused, and if so then just use f_p (i.e. the primal primitive) to compute the needed outputs. DCE rules are very simple to register for primitives.

In your real example, do you have primitives like this? E.g. one for the primal computation f, one for f_fwd, and one for f_bwd? Or is your setup different, e.g. with pallas_calls?

Basically, to see if some fast solutions are feasible, we should make your example program model this aspect of your real program.

The DCE approach is basically a way to clean up the mess after we've made it. We also have an idea to avoid making the mess in the first place, but it requires some bigger JAX autodiff changes. So I'd like to see if a quick DCE solution will get you un-stuck.

WDYT?

@mattjj (Member) commented May 22, 2024

Here's an example of what I mean:

import jax
import jax.numpy as jnp

from jax._src import core

f_p = core.Primitive('f')
f_fwd_p = core.Primitive('f_fwd')
f_bwd_p = core.Primitive('f_bwd')

f_fwd_p.multiple_results = f_bwd_p.multiple_results = True

@f_p.def_abstract_eval
def f(x): return x

@f_fwd_p.def_abstract_eval
def f_fwd(x):
  return x, x

@f_bwd_p.def_abstract_eval
def f_bwd(x, g):
  return g,


@jax.custom_vjp
def f(x):
  return f_p.bind(x)

def f_fwd(x):
  return f_fwd_p.bind(x)

def f_bwd(res, g):
  return (*f_bwd_p.bind(res, g),)

f.defvjp(f_fwd, f_bwd)

def temp(x):
  out = jax.remat(f)(x)
  out = out ** 2
  return out

jaxpr = jax.make_jaxpr(jax.grad(temp))(3.2)
print(jaxpr)
{ lambda ; a:f32[]. let
    b:f32[] _:f32[] = f_fwd a  # <----- NOTE HERE
    _:f32[] = integer_pow[y=2] b
    c:f32[] = integer_pow[y=1] b
    d:f32[] = mul 2.0 c
    e:f32[] = mul 1.0 d
    f:f32[] = remat2[
      differentiated=True
      jaxpr={ lambda ; g:f32[] h:f32[]. let
          _:f32[] i:f32[] = f_fwd g
          j:f32[] = f_bwd i h
        in (j,) }
      policy=None
      prevent_cse=True
    ] a e
  in (f,) }

Notice we're calling f_fwd twice, and in the primal pass call we're dropping its second output.

But here's how we can set up a DCE rule:

from jax._src.interpreters import partial_eval as pe
def f_fwd_dce(used_outs: list[bool], eqn):
  _, res_used = used_outs
  if res_used:
    return [True], eqn
  else:
    new_eqn = eqn.replace(outvars=eqn.outvars[:1], primitive=f_p)
    return [True], new_eqn
pe.dce_rules[f_fwd_p] = f_fwd_dce

print(pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))[0])
{ lambda ; a:f32[]. let
    b:f32[] = f a  # <----- NOTE HERE
    c:f32[] = integer_pow[y=1] b
    d:f32[] = mul 2.0 c
    e:f32[] = mul 1.0 d
    f:f32[] = remat2[
      differentiated=True
      jaxpr={ lambda ; g:f32[] h:f32[]. let
          _:f32[] i:f32[] = f_fwd g
          j:f32[] = f_bwd i h
        in (j,) }
      policy=None
      prevent_cse=True
    ] a e
  in (f,) }

@chr1sj0nes (Member)

That's interesting. I didn't know it was possible. We are currently using triton_call or pallas_call, but it might be possible to wrap them in primitives and use DCE.

A concrete example would be a Pallas kernel equivalent to the following:

import jax

def f(x, w):
  y = x @ w
  z = jax.nn.gelu(y)
  return z

The inference version of the kernel will output z only, but when training, we need to output both y and z.
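To make that concrete, here is a rough sketch of the wiring (plain jnp stands in for the actual Pallas kernels; f_inference and f_training are illustrative names, not our real kernels):

import jax
import jax.numpy as jnp

def f_inference(x, w):        # stand-in for the fused inference kernel
  return jax.nn.gelu(x @ w)   # outputs z only

def f_training(x, w):         # stand-in for the fused training kernel
  y = x @ w
  return jax.nn.gelu(y), y    # outputs z and the residual y

@jax.custom_vjp
def f(x, w):
  return f_inference(x, w)

def f_fwd(x, w):
  z, y = f_training(x, w)
  return z, (x, w, y)         # save residuals for the backward kernel

def f_bwd(res, g):
  x, w, y = res
  _, gelu_vjp = jax.vjp(jax.nn.gelu, y)
  dy, = gelu_vjp(g)           # dL/dy = gelu'(y) * g
  return dy @ w.T, x.T @ dy   # dL/dx, dL/dw

f.defvjp(f_fwd, f_bwd)

Under jax.remat, the primal pass ends up calling the training variant here even though only z is needed, which is exactly the inefficiency reported above.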

@sbodenstein (Author)

triton_call is already wrapped as a primitive. Let's try to fix it at the Pallas/jax-triton level.

@mattjj (Member) commented May 23, 2024
It may also be possible to skip additional primitive-wrapping: we could register a DCE rule for triton_call and pallas_call themselves that dispatches to a table of handlers indexed by the callables passed to those calls, or something along those lines.
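Something like this sketch, say (the 'kernel' param name is a placeholder; the real pallas_call/triton_call params would need to be inspected to recover the user callable, and the primitive import path varies by version):

from jax._src.interpreters import partial_eval as pe

# Registry mapping a user kernel callable to its DCE handler.
dce_handlers = {}

def register_kernel_dce(kernel_fn, handler):
  dce_handlers[kernel_fn] = handler

def dispatching_dce(used_outs, eqn):
  handler = dce_handlers.get(eqn.params.get('kernel'))  # placeholder param name
  if handler is None:
    # Conservative fallback: keep the equation and all of its inputs.
    return [True] * len(eqn.invars), eqn
  return handler(used_outs, eqn)

# Hypothetical registration:
# pe.dce_rules[pallas_call_p] = dispatching_dce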
