JIT constant folding #21300
Comments
I'm aware that |
Some explanation why it depends on the shape: We have a heuristic to not apply constant folding if the operand shape is too large. The cutoff is 45 * 1000 * 1000 elements. In the "fast" cases we don't apply constant folding. |
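To make the arithmetic behind that cutoff concrete, here is a small sketch that counts the elements of a shape and compares the count with 45 * 1000 * 1000. The helper name `exceeds_folding_cutoff` is hypothetical, the strict `>` comparison is an assumption, and this is not XLA's actual code. Which operand XLA inspects depends on the lowered HLO, so treat it only as an illustration of the threshold.

```python
import math

# Cutoff quoted in the comment above: 45 * 1000 * 1000 elements.
FOLDING_CUTOFF_ELEMENTS = 45 * 1000 * 1000


def exceeds_folding_cutoff(shape):
    """Return True if an operand of this shape has more elements than the cutoff.

    Hypothetical helper for illustration only; the strict comparison is assumed.
    """
    return math.prod(shape) > FOLDING_CUTOFF_ELEMENTS


print(math.prod((200000, 10, 10)))               # 20000000
print(exceeds_folding_cutoff((200000, 10, 10)))  # False
print(exceeds_folding_cutoff((500000, 10, 10)))  # True
```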
Thanks for the reply!
import jax
import jax.numpy as jnp
import jax.core
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")
v = jnp.zeros((200000, 10, 10))
def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)
def g():
    return jax.vmap(jnp.sum)(v)
print("f")
jax.jit(f)()
print("g")
jax.jit(g)()
Maybe it would help to clarify what constant folding is used for / where it makes sense to apply it. I'm wondering why this is so slow - intuitively, I would think that constant folding happens approximately at the speed of |
For constant folding, the HloEvaluator is used. It is optimized for correctness rather than speed, because it serves as the reference backend in tests. https://github.com/openxla/xla/blob/main/xla/service/hlo_constant_folding.cc I don't know what the nested jax.vmap would translate to, but I think you can safely assume that a fast runtime means that constant folding is not applied. Constant folding only makes sense if the computation being folded would otherwise run several times. If it runs only once, you are better off without constant folding.
Thanks for clarifying! Would it be a useful addition to For example, you could force this with the current API using Maybe this would be a nice addition for those users (like me) who use many and large static arrays (i.e. constants in the context of jit) but don't want constant folding to slow the compilation down. |
I am not familiar with the JAX side of things. On the XLA side we have a flag that can be used to turn off constant folding: --xla_disable_hlo_passes=constant_folding. This can be set via the XLA_FLAGS environment variable, so something like os.environ['XLA_FLAGS'] = "--xla_disable_hlo_passes=constant_folding" from Python.
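A minimal sketch of setting that flag from Python without clobbering flags that are already present. The helper name `add_xla_flag` is hypothetical. The sketch assumes the variable is set before JAX initializes its backend, and doing it before the first `import jax` is the conservative choice.

```python
import os


def add_xla_flag(flag):
    """Append `flag` to XLA_FLAGS unless it is already there; return the new value.

    Hypothetical helper; only touches the environment of the current process.
    """
    flags = os.environ.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Run this before importing jax so the flag is in place when XLA reads it.
add_xla_flag("--xla_disable_hlo_passes=constant_folding")
print(os.environ["XLA_FLAGS"])

# import jax  # import JAX only after the environment variable is set
```

Appending instead of assigning keeps any flags already exported in the shell.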
Thanks for helping! It would be nice to also have an option like this in |
You can do this via |
That was a fast comment! When trying this, I get the following error: |
Ohh sorry you need |
The code I used:
v = jnp.zeros((200000, 10, 10))
def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)
jax.jit(f).lower().compile(compiler_options={'xla_disable_hlo_passes': 'constant_folding'})
Hmm, this might require some fixes in the jax code. I'll take a look. |
Description
Hi,
I was hoping that someone could help me with this.
Sometimes, when using constants in jitted functions, I get warnings like this one:
These warnings seem to appear at random, for example with the following code:
This code produces "constant folding" warnings on Windows and on Linux. This probably depends on the OS version, CPU type, and so on.
When playing around with array shapes and the number of nested vmaps, these messages appear or do not appear without any clear pattern (at least not clear to me). For example, this is fast:
While this is slow and produces the warning:
Constant folding only happens when compiling with jax.jit; making jaxprs is not affected.
Since jaxprs are perfectly able to catch constants, it is possible to compile them while treating constants as variables.
The following function demonstrates this:
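A minimal sketch of what such a function could look like, assuming the approach is to trace with `jax.make_jaxpr` and then jit an evaluator that receives the captured constants as ordinary arguments. This is a hypothetical reconstruction and not necessarily the exact code from the issue. It handles only flat array arguments and rebuilds no pytree output structure.

```python
import jax
import jax.core
import jax.numpy as jnp


def other_jit(fn):
    """Hypothetical sketch: compile `fn` with its captured constants passed as inputs."""
    def wrapped(*args):
        # Trace to a ClosedJaxpr; arrays that fn closes over end up in `.consts`.
        closed = jax.make_jaxpr(fn)(*args)

        def run(consts, *inner_args):
            # Evaluate the jaxpr with the constants supplied as regular inputs.
            return jax.core.eval_jaxpr(closed.jaxpr, consts, *inner_args)

        # Under jit, `consts` are traced inputs, so XLA does not see them as literals.
        outs = jax.jit(run)(closed.consts, *args)
        # eval_jaxpr returns a flat list of outputs; unwrap the single-output case.
        return outs[0] if len(outs) == 1 else outs
    return wrapped


v = jnp.zeros((2000, 10, 10))  # smaller than the (200000, 10, 10) array in the issue


def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)


result = other_jit(f)()
print(result.shape)  # (2000, 10)
```

Note that this sketch retraces and recompiles on every call. A generalized version would cache the compiled evaluator.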
Now, using other_jit(f)() instead of jax.jit(f)() prevents the issue.
I was wondering if this is intended behavior.
Wouldn't it be a better solution in most cases to always treat constants as variables while compiling, to prevent constant folding from slowing down compilations?
In real-world scenarios, using (a generalized version of) the other_jit function I presented here can significantly reduce compilation times from a few minutes to just seconds.
What's your opinion on this? I would appreciate any help or suggestions.
System info (python version, jaxlib version, accelerator, etc.)
cpu
jax 0.4.28
jaxlib 0.4.28