
Can we call a non-jit function from a jit function? #21287

Answered by jakevdp
Qazalbash asked this question in Q&A

To answer the question in your title, yes this is possible via jax.pure_callback; see Exploring pure_callback for an example of how you might use it. The host-side computation there can be anything, including a call to a non-jit-compiled JAX function.
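Here is a minimal sketch of the `jax.pure_callback` route. It assumes a reasonably recent JAX version, and `host_side_fn` and `jitted_fn` are hypothetical names. The host function uses NumPy and a Python `if` on concrete values, which would fail on a tracer inside `jit`. The answer says the callback could equally call un-jitted JAX code. The second argument to `pure_callback` declares the shape and dtype of the callback's result, because XLA needs them at trace time.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_side_fn(x):
    # Hypothetical non-jit function: runs on the host with concrete values,
    # so data-dependent Python control flow is allowed here.
    x = np.asarray(x)
    if x.sum() > 0:
        return (x * 2).astype(x.dtype)
    return (x - 1).astype(x.dtype)

@jax.jit
def jitted_fn(x):
    y = jnp.sin(x)
    # Tell JAX the shape/dtype the host callback will return.
    result_shape = jax.ShapeDtypeStruct(y.shape, y.dtype)
    z = jax.pure_callback(host_side_fn, result_shape, y)
    return z + 1.0

print(jitted_fn(jnp.array([0.5, 1.0, 2.0], dtype=jnp.float32)))
```

The callback runs on the host, so each call costs a device-to-host round trip. `pure_callback` also assumes the function is side-effect free, so JAX may skip the call entirely if its result is unused.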

Answering your broader question: depending on the details of your code, I suspect a better approach here would be to avoid JIT-compiling your entire workflow, and instead compile small sections of it, and then execute those JIT-compiled sections in sequence.
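A minimal sketch of that structure follows, using hypothetical names (`update_step`, `residual`, `run`). Two small jitted pieces are called from an ordinary Python loop. That lets the convergence check use a concrete value and plain Python control flow, which would not be possible inside one jit-compiled function.

```python
import jax
import jax.numpy as jnp

@jax.jit
def update_step(x):
    # Small, self-contained numerical kernel; compiled once per shape/dtype.
    return x - 0.1 * jnp.tanh(x)

@jax.jit
def residual(x):
    return jnp.linalg.norm(x)

def run(x, tol=1e-3, max_iters=1000):
    # Outer driver is plain (non-jit) Python: it can branch on concrete
    # values, log, or call any other non-jit code between compiled steps.
    n_iters = 0
    while n_iters < max_iters:
        x = update_step(x)
        n_iters += 1
        # float(...) pulls a concrete scalar back to the host.
        if float(residual(x)) < tol:
            break
    return x, n_iters

x_final, n = run(jnp.array([1.0, 2.0, 3.0]))
print(n, x_final)
```

Each jitted piece is compiled on first use and reused on later iterations. The trade-off is a small dispatch overhead per call, plus the host sync at each `float(...)`.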

Answer selected by Qazalbash
Category
Q&A
2 participants