odeint is slow #5

Open
modichirag opened this issue May 5, 2022 · 16 comments

@modichirag
Collaborator

I was experimenting with some timing tests and found that the odeint call used to calculate the growth functions is quite slow.
I tried hacking it and replacing it with rk4 integration inside the growth function itself, which seems to be much faster.

    from jax import jit
    from jax.lax import scan

    ode_jit = jit(ode)

    def rk4_ode_jit(carry, t):
        # one classical RK4 step from t_prev to t; carry = (y, t_prev)
        y, t_prev = carry
        h = t - t_prev
        k1 = ode_jit(y, t_prev, cosmo)
        k2 = ode_jit(y + h * k1 / 2, t_prev + h / 2, cosmo)
        k3 = ode_jit(y + h * k2 / 2, t_prev + h / 2, cosmo)
        k4 = ode_jit(y + h * k3, t, cosmo)
        y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t), y

    (yf, _), G = scan(rk4_ode_jit, (G_ic, lna[0]), lna)

Then I ran timing tests for a 64^3 simulation, in which I pass the cosmology parameters and initial modes as input and measure the time for different outputs (just the boltzmann solve vs boltzmann + LPT).

@jit
def simulate_boltz(modes, omegam, conf):
    '''Evaluate growth & transfer function with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate_boltz_rk4(modes, omegam, conf):
    '''Evaluate growth & transfer function with custom rk4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer function with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_rk4(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer function with custom rk4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo


@jit
def simulate_nbody(modes, cosmo):
    '''Run LPT simulation without evaluating growth & transfer function
    '''
    ptcl, obsvbl = lpt(modes, cosmo)
    conf = cosmo.conf
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

The time taken for each of these is

Time taken for boltzmann: 0.5971660375595093
Time taken for boltzmann rk4: 0.007928729057312012
Time taken for LPT: 0.0041596412658691405
Time taken for simulation (Boltzmann + LPT): 0.463437557220459
Time taken for simulation rk4 (Boltzmann + LPT): 0.04284675121307373

rk4 seems to be much faster than using odeint to generate the growth functions.

If what I am doing in running the simulations is sensible and the timing numbers portray an accurate picture,
then we should figure out a better (jaxified) way to code this.

I have attached the full script as a txt file (copy it into the pmwd/pmwd folder, convert it to .py, and it should run):
test_growth.txt

@eelregit
Owner

eelregit commented May 5, 2022

@modichirag timeit can be dominated by the compilation time, as suggested (in your other email) by the fact that the std's are of the same order as the means. I usually do time to get the compilation + computation time before timeit.

Also output.block_until_ready should be added to the returned results (https://jax.readthedocs.io/en/latest/async_dispatch.html).

Can you check if these help?

The experimental odeint has more problems. I wanted to try replacing it with the diffrax package.
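
For reference, a minimal sketch of what a diffrax-based growth integration could look like (not pmwd's actual implementation), reusing the ode, G_ic, lna, and cosmo names from the snippet above; diffrax expects a vector field of the form f(t, y, args), so the odeint-style signature is wrapped:

    import diffrax

    term = diffrax.ODETerm(lambda t, y, args: ode(y, t, args))  # wrap odeint-style ode(y, t, cosmo)
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(),
        t0=lna[0], t1=lna[-1], dt0=None, y0=G_ic, args=cosmo,
        saveat=diffrax.SaveAt(ts=lna),
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-3),
    )
    G = sol.ys  # growth solution evaluated at the lna grid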

@EiffL

EiffL commented May 5, 2022

diffrax is nice for sure, but I don't think it will make a big difference (I may be wrong).

It does give you more control over the ODE solver, and returns interesting info like how many steps were actually needed.

@eelregit
Owner

eelregit commented May 5, 2022

Okay, the reason rk4 is about 20x faster here could be that @modichirag is scanning rk4 over the nbody time integration steps (here probably 64 in total),
whereas to get accurate results the adaptive solvers probably used much finer steps.

I guess the bottleneck is due to the slowness of GPUs at solving small but sequential problems,
and maybe also partly the experimental odeint and/or my growth_integ being hard to jit.

With odeint I sometimes get NaN. I wanted to see if diffrax can be more stable.

@modichirag
Collaborator Author

modichirag commented May 5, 2022

For odeint, I had tried reducing the tolerances (atol and rtol) to 1e-3 instead of 1e-8. It made a difference of a factor of 2, rather than 20.
In either case, I compared the results of odeint with rk4 and found them to be in agreement up to ~1e-4, which I guess is good enough.
I can try setting up diffrax tomorrow.
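
For reference, loosening the tolerances is just a matter of passing rtol/atol to JAX's experimental odeint; a minimal sketch, reusing the ode, G_ic, lna, and cosmo names from the snippet above:

    from jax.experimental.ode import odeint

    # default tolerances are ~1.4e-8; looser values trade accuracy for speed
    G = odeint(ode, G_ic, lna, cosmo, rtol=1e-3, atol=1e-3)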

@eelregit
Owner

eelregit commented May 6, 2022

How long are the 1e-8 and 1e-3 odeint timeit numbers now?
For me with 1e-8 tol it was ~170ms, still a lot slower than your rk4 number.

Nice. I am surprised that 64 steps is good enough for growth factors.

@modichirag
Collaborator Author

I set up diffrax. It seems to be slower than odeint. For 1e-3 tolerance, it takes 1.27s for Boltzmann.

The times for 1e-8 and 1e-3 with odeint are now 300ms and 250ms respectively. I think this differs from the previous numbers because I had other jobs running on the GPU. Some other jobs are still running, so this number can go down a bit more on a free GPU, but it will certainly remain more time consuming than the LPT step by a factor of 5-10x, which is bad.

@eelregit
Owner

eelregit commented May 7, 2022

What's the correct rk4 timing you got? @modichirag

@eelregit
Owner

For me with 1e-8 tol it was ~170ms, still a lot slower than your rk4 number.

@modichirag Actually I misremembered this number: 170ms is for (512^3) 2LPT; the growth integration takes 10ms.
But even with perfect scaling to 128^3, lpt will still be a lot faster than boltzmann.
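
(A rough check, assuming ideal scaling with particle number: 170 ms × (128/512)^3 ≈ 2.7 ms for LPT at 128^3, versus ~10 ms for the growth integration.)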

@modichirag
Collaborator Author

modichirag commented May 15, 2022 via email

@modichirag
Collaborator Author

modichirag commented May 15, 2022 via email

@eelregit
Owner

@modichirag timeit can be dominated by the compilation time, as suggested (in your other email) by the fact that the std's are of the same order as the means. I usually do time to get the compilation + computation time before timeit.

Also output.block_until_ready should be added to the returned results (https://jax.readthedocs.io/en/latest/async_dispatch.html).

Can you check if these help?

@modichirag Okay... have you tried to separate the jit time from the running time, as I suggested earlier?

I simply did the following in Jupyter

>>> %time boltzmann(cosmo).growth.block_until_ready()

>>> %timeit boltzmann(cosmo).growth.block_until_ready()

CPU times: user 824 ms, sys: 22.4 ms, total: 846 ms
Wall time: 1.18 s
10.2 ms ± 7.33 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@modichirag
Collaborator Author

modichirag commented May 16, 2022

So I am not using Jupyter but instead have a Python script, so there is no timeit; I was running a loop of 50 iterations and timing it with time.time().
Yes, I have been compiling the function before timing so that it does not factor into the running cost. This is what you call jit time.
In fact the compile time is much longer for rk4, so it's good we don't have to compile again and again.

Here are the relevant lines of code from the script I sent earlier (I was not explicitly printing the compile time there, but was compiling nevertheless).

    start = time.time()
    _ = boltzmann(cosmo).growth.block_until_ready()
    print("Time to compile boltzmann: ", (time.time() - start)*1000)

    start = time.time()
    _ = boltzmann_rk4(cosmo2, growth_a2).growth.block_until_ready()
    print("Time to compile boltzmann rk4: ", (time.time() - start)*1000)

    start = time.time()
    [boltzmann(cosmo).growth.block_until_ready() for _ in range(niter)]
    print("Time taken for boltzmann: ", (time.time() - start)/niter*1000)

    start = time.time()
    [boltzmann_rk4(cosmo2, growth_a2).growth.block_until_ready() for _ in range(niter)]
    print("Time taken for boltzmann rk4: ", (time.time() - start)/niter*1000)

And the output is

Time to compile boltzmann:  13.38815689086914
Time to compile boltzmann rk4:  646.0096836090088
Time taken for boltzmann:  12.846832275390625
Time taken for boltzmann rk4:  1.648721694946289

Interestingly, block_until_ready() does not seem to make that big a difference: it's 1.4ms vs 1.6ms and 11.5ms vs 12.5ms.

@eelregit
Owner

Okay. So you also get boltzmann at ~10ms. I think this will be fine for most of our target use cases.
I can try different diffrax algorithms and find the best one, but it's not of high priority right now.

@eelregit
Owner

@Yucheng-Zhang found that it's 10x faster to run odeint on CPUs (including moving the results to GPU afterwards).

[timing screenshot]

@Yucheng-Zhang
Collaborator

@Yucheng-Zhang found that it's 10x faster to run odeint on CPUs (including moving the results to GPU afterwards).

[timing screenshot]

Same evaluation on GPU: [timing screenshot]

@eelregit
Owner

eelregit commented Nov 10, 2023

To speed up the growth function integration, I did the following test:

def boltz_factory(backend):
    return jax.jit(boltzmann, device=jax.devices(backend)[0])

boltz_gpu = boltz_factory('gpu')
boltz_cpu = boltz_factory('cpu')

conf = Configuration(1., (2,) * 3)
cosmo = SimpleLCDM(conf)
cosmo = boltzmann(cosmo, conf)

%time jax.block_until_ready(boltz_gpu(cosmo, conf))
%timeit jax.block_until_ready(boltz_gpu(cosmo, conf))
%time jax.block_until_ready(jax.device_put(boltz_cpu(cosmo, conf), device=jax.devices('gpu')[0]))
%timeit jax.block_until_ready(jax.device_put(boltz_cpu(cosmo, conf), device=jax.devices('gpu')[0]))

The time and timeit results are:

CPU times: user 10.8 ms, sys: 0 ns, total: 10.8 ms
Wall time: 10.7 ms
9.31 ms ± 6.62 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
CPU times: user 4.19 ms, sys: 11 µs, total: 4.2 ms
Wall time: 3.53 ms
1.85 ms ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

So @adrianbayer you can probably do something like the following:

boltz = jax.jit(boltzmann, device=jax.devices('cpu')[0])

cosmo = boltz(cosmo, conf)
cosmo = jax.device_put(cosmo, device=jax.devices('gpu')[0])
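
(The idea being that the growth integration is a small, sequential problem where the GPU is slow, so it runs on the CPU and the result is copied over to the GPU once; the rest of the simulation stays on the GPU.)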
