
Confused about JIT-compilation of nested functions #21297

Answered by dfm
dfdx asked this question in Q&A

That's what I expected! Yeah, like you say, the usual advice here would be to move the jit as high up the call stack as you can. For example, in the Flax examples that you link to, jit is applied to the whole training step, e.g.:

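import jax
import jax.numpy as jnp

@jax.jit
def outer(params, x):
  # inner is traced inline as part of outer's single compilation
  def inner(y):
    return jnp.tanh(params * y)  # placeholder computation
  return inner(x)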

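In that case, inner is traced as part of outer, so it is compiled only once, together with outer (for a given set of input shapes and dtypes)!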
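To see why hoisting the jit matters, here is a minimal sketch (not from the original thread) that counts how often JAX traces the inner function. The names step_nested, step_hoisted, scale and trace_count are hypothetical. The counter is a Python side effect, so it only increments while JAX is tracing. Applying jax.jit to a closure defined inside a non-jitted function creates a new function object on every call, so JAX cannot reuse the earlier trace; jitting the outer function instead traces inner just once.

```python
import jax
import jax.numpy as jnp

# Hypothetical counters: incremented only while JAX traces the function.
trace_count = {"nested": 0, "hoisted": 0}

def step_nested(scale, x):
  # A fresh closure (a new Python function object) is created on every call,
  # so jax.jit has no cached trace for it and must trace it again.
  @jax.jit
  def inner(y):
    trace_count["nested"] += 1
    return jnp.tanh(scale * y)
  return inner(x)

@jax.jit
def step_hoisted(scale, x):
  # inner is an ordinary Python function here; it is inlined into the
  # trace of step_hoisted, which is traced once per input signature.
  def inner(y):
    trace_count["hoisted"] += 1
    return jnp.tanh(scale * y)
  return inner(x)

x = jnp.ones(3)
for _ in range(3):
  step_nested(2.0, x)
  step_hoisted(2.0, x)

print(trace_count)
```

With identical input shapes on every call, the nested counter grows with each call while the hoisted one stays at 1.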

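But there are cases where this won't work well (e.g. when jitting the whole outer function leads to long compile times). In that case, you could try converting the closure into a jitted function defined at the global (module) level, which takes the values it previously closed over as static arguments (via static_argnums or static_argnames; static values must be hashable). Since that function object is created only once, repeated calls with the same static values should hit the compilation cache instead of recompiling.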
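A minimal sketch of the static-argument approach, assuming the closure only captured hashable configuration values. apply_layers, activation, depth and trace_count are hypothetical names standing in for whatever the original closure captured; static_argnames is a real jax.jit parameter.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only while tracing

# Module-level jitted function: the values a closure used to capture
# (here `activation` and `depth`) are passed as static arguments instead.
@functools.partial(jax.jit, static_argnames=("activation", "depth"))
def apply_layers(x, activation, depth):
  global trace_count
  trace_count += 1  # Python side effect, so it runs only on a cache miss
  fn = {"tanh": jnp.tanh, "relu": jax.nn.relu}[activation]
  for _ in range(depth):
    x = fn(x)
  return x

x = jnp.ones(4)
for _ in range(3):
  # Same static values each time: traced once, then served from the cache.
  apply_layers(x, activation="tanh", depth=2)
# A new static value triggers a fresh trace and compilation.
apply_layers(x, activation="relu", depth=2)

print(trace_count)
```

The first three calls share one trace and the "relu" call adds a second, so trace_count ends at 2. Because every distinct static value compiles a new version, this pattern fits best when the captured values take only a few distinct values.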

Answer selected by dfdx