Outer JIT compilation time could be optimized #9

Open
EiffL opened this issue Nov 22, 2022 · 9 comments

Comments

@EiffL

EiffL commented Nov 22, 2022

In this example:
https://gist.github.com/EiffL/8e46d261e5d52cd28ca81e233fef9b04

It takes 3 minutes for the first evaluation of the model to run, but just a few seconds for the second run.

@modichirag has also checked that the compilation time grows with the number of time steps. This suggests that the code builds an overly complex computational graph that explicitly includes every step of the N-body integration.

I suspect this is due to the Python for loop in the nbody function. Things would probably improve a lot if it were replaced with a lax.scan.
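To illustrate the suspected cause, here is a minimal sketch assuming a hypothetical `toy_step` (a kick-drift update of a toy oscillator, not pmwd's actual `nbody_step`). A Python `for` loop is unrolled while tracing, so the traced program grows with the number of steps, whereas `lax.scan` traces the step body only once:

```python
import jax
import jax.numpy as jnp

def toy_step(state, dt=0.1):
    # Hypothetical stand-in for one N-body step (kick-drift update of a toy
    # harmonic oscillator); NOT pmwd's actual nbody_step.
    pos, vel = state
    vel = vel - dt * pos  # kick
    pos = pos + dt * vel  # drift
    return pos, vel

def nbody_for(state, n_steps):
    # Python loop: unrolled at trace time, so the traced program
    # contains n_steps copies of toy_step.
    for _ in range(n_steps):
        state = toy_step(state)
    return state

def nbody_scan(state, n_steps):
    # lax.scan: the body is traced once; n_steps is only the loop length.
    def body(carry, _):
        return toy_step(carry), None
    state, _ = jax.lax.scan(body, state, xs=None, length=n_steps)
    return state

def num_eqns(fn, n_steps):
    # Number of top-level operations in the traced program, a rough proxy
    # for how much work XLA has to compile.
    state = (jnp.ones(4), jnp.zeros(4))
    return len(jax.make_jaxpr(lambda s: fn(s, n_steps))(state).jaxpr.eqns)

for n in (10, 100):
    print(n, num_eqns(nbody_for, n), num_eqns(nbody_scan, n))
```

The equation count is only a proxy for compile time, but its linear growth for the unrolled loop mirrors the dependence of compilation time on the number of steps reported above.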

@eelregit
Owner

eelregit commented Nov 22, 2022

You can remove the jit in the gist.

We have discussed this. This is part of the reason that pmwd already does jit for you.

It was a scan before but that was slower than a for loop. Maybe you can try again and see if it gets faster.
With the current implementation, we can easily save a snapshot between two jitted steps.
IO is a side effect, and I'm not sure how it can be done with jit and scan.
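A minimal sketch of the pattern described here, assuming a hypothetical jitted `toy_step` in place of pmwd's `nbody_step`: the step is compiled once and reused, and because the loop itself runs in Python, ordinary IO can happen between two jitted steps.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def toy_step(state, dt):
    # Hypothetical stand-in for nbody_step; compiled on the first call and
    # reused from the cache for later calls with the same shapes and dtypes.
    pos, vel = state
    vel = vel - dt * pos
    pos = pos + dt * vel
    return pos, vel

def nbody_with_snapshots(state, n_steps, dt=0.1, snap_every=5):
    snapshots = []
    for i in range(n_steps):
        state = toy_step(state, dt)
        if (i + 1) % snap_every == 0:
            # Outside jit, so plain Python IO is allowed here (e.g. np.save);
            # this sketch just keeps host copies in a list.
            snapshots.append(np.asarray(state[0]))
    return state, snapshots

state, snaps = nbody_with_snapshots((jnp.ones(4), jnp.zeros(4)), n_steps=10)
```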

The downside is that the backward passes of growth, modes, and lpt are not jitted in this approach.
There's a workaround that splits those and the subsequent nbody into parts 1 and 2 of the model,
but it's a bit cumbersome.
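A rough sketch of that workaround, with hypothetical names: `part1` stands in for growth/modes/lpt and is wrapped in `jax.jit` as a whole, while `part2` keeps the Python loop over a jitted step (so snapshots could still be saved between steps). As far as I understand JAX, differentiating through a jitted function also runs that function's backward pass as a compiled computation; this is an assumption about the intent, not pmwd's actual code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def part1(params):
    # Hypothetical stand-in for growth/modes/lpt: parameters -> initial state.
    pos = jnp.sin(params) * params
    vel = jnp.cos(params)
    return pos, vel

@jax.jit
def toy_step(state, dt):
    # Hypothetical stand-in for nbody_step.
    pos, vel = state
    vel = vel - dt * pos
    pos = pos + dt * vel
    return pos, vel

def part2(state, n_steps=10, dt=0.1):
    # N-body loop in Python over the jitted step; IO could happen here.
    for _ in range(n_steps):
        state = toy_step(state, dt)
    return state

def model(params):
    pos, _ = part2(part1(params))
    return jnp.sum(pos ** 2)

grads = jax.grad(model)(jnp.linspace(0.0, 1.0, 4))
```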

@eelregit
Owner

The quickstart notebook has some timing.

@EiffL
Author

EiffL commented Nov 22, 2022

Oh right, well the problem is that ultimately there will need to be a jit outside of pmwd, for instance in an HMC sampler or as part of a larger simulation model that also includes additional computational layers. And concretely, for me it's good to have my distributed code inside a jit (though not required).

OK, I'll see if I can think of how to do the export with a scan. But I can't imagine it being that much slower. Do you remember by how much it changed things?

@EiffL changed the title from "Very slow JIT compilation time" to "Outer JIT compilation time could be optimized" on Nov 22, 2022
@eelregit
Owner

eelregit commented Nov 22, 2022

I don't see why HMC requires a top-level jit. Maybe it's more convenient that way, but why is it mandatory?

And what do you have in mind for the larger model? If it's a really big model, like almost any NN,
wouldn't top-level jitting run into this problem again even if pmwd were leaner?

I remember it was 20-30% slower. And I think one cannot export (i.e., save snapshots) inside jit.

@eelregit
Owner

eelregit commented Nov 22, 2022

I wonder if it's still slow if you jit nbody_step first, and then the whole model function.
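A sketch of that suggestion, using hypothetical names: `toy_step` is jitted on its own and then called from a model that is itself wrapped in `jax.jit`. With `jax.jit`'s default of not inlining, the outer trace records one call to the inner jitted function per step rather than the step's individual operations. Whether this actually shortens XLA compilation is the open question here; the sketch does not show that it does.

```python
import jax
import jax.numpy as jnp

N_STEPS = 10

@jax.jit
def toy_step(state):
    # Hypothetical stand-in for nbody_step, jitted on its own first.
    pos, vel = state
    vel = vel - 0.1 * pos
    pos = pos + 0.1 * vel
    return pos, vel

def model_body(state):
    # Python loop over the already-jitted step.
    for _ in range(N_STEPS):
        state = toy_step(state)
    return jnp.sum(state[0] ** 2)

# Outer jit around the whole model, as e.g. an HMC sampler would apply.
model = jax.jit(model_body)

state0 = (jnp.ones(4), jnp.zeros(4))
print(len(jax.make_jaxpr(model_body)(state0).jaxpr.eqns))
print(model(state0))
```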

@EiffL
Author

EiffL commented Nov 22, 2022

When you use an external sampler like NumPyro or TFP, it will usually compile all the logic of the HMC kernel, including the evaluation of the log likelihood. It may be possible to disable that (and I agree it's not necessary in principle), but by default in JAX users expect to be able to jit their code without knowing the underlying implementation.

20-30% does sound like a lot, yeah... I guess that was for small problems, but if it's that bad I can see the reason for this tradeoff.

Otherwise, yeah, I agree that saving snapshots to disk from within a jitted function wouldn't be trivial. But if you are doing things on the fly, that's probably not very important. I can see, though, that you might want to avoid the memory cost of storing intermediate snapshots.
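One possible route for saving snapshots from inside jit and scan, sketched here under the assumption that host callbacks are acceptable (this was not tested in the thread): call `jax.debug.callback` inside the `lax.scan` body, so a host function runs at every step even under `jax.jit`. Note that `jax.debug.callback` is intended for debugging; `jax.experimental.io_callback` is JAX's documented alternative for IO. `toy_step` and `save_snapshot` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

snapshots = []  # host-side storage; in practice this could be files on disk

def save_snapshot(step, pos):
    # Executed on the host at runtime, once per scan iteration.
    snapshots.append((int(step), np.asarray(pos)))

def toy_step(state, dt=0.1):
    # Hypothetical stand-in for one N-body step.
    pos, vel = state
    vel = vel - dt * pos
    pos = pos + dt * vel
    return pos, vel

@jax.jit
def nbody_scan_io(state, steps):
    def body(carry, step):
        carry = toy_step(carry)
        # Host-side side effect issued from inside jit + scan.
        jax.debug.callback(save_snapshot, step, carry[0])
        return carry, None
    state, _ = jax.lax.scan(body, state, steps)
    return state

final = nbody_scan_io((jnp.ones(4), jnp.zeros(4)), jnp.arange(5))
jax.effects_barrier()  # wait until all pending callbacks have run
snapshots.sort(key=lambda s: s[0])  # callbacks are unordered by default
```

The snapshots still pass through host memory, so this trades the graph-size problem for host transfers at every step.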

@eelregit
Owner

eelregit commented Nov 22, 2022

Some of those users are likely already used to the compilation speed. I have heard complaints about JAX taking minutes to compile NNs.

I don't know why JAX cannot get cache hits on nbody_step after unrolling the for loop and compiling nbody_step the first time. Here's a related issue: google/jax#284

@eelregit
Owner

A discussion mentioning that adding lower-level jit helps with compilation time:
google/jax#10104

@eelregit
Owner

I think saving snapshots is also important for normal use cases, like generating mocks.


2 participants