Extend Fake Tensor Caching to Symints #126411
Comments
I'm not sure if this would be worth it. I instrumented the existing cache and ran:
From what I can tell, the time spent in FakeTensorMode dispatch on ops that were not cached because they contained any kind of sym expression was tiny: of the 151s spent in dispatch, only 0.003s were cache-bypassed due to containing a SymInt, SymFloat, or SymBool. Is there a better benchmark I should use to measure the potential of this change?
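The instrumentation was along these lines (a rough sketch of my own, not the exact patch; it wraps the internal FakeTensorMode.__torch_dispatch__):
```python
# Rough sketch of the instrumentation described above -- not the exact
# patch. Wraps FakeTensorMode.__torch_dispatch__ (an internal API) to
# accumulate total wall time spent in fake-tensor dispatch.
import time
from torch._subclasses.fake_tensor import FakeTensorMode

_orig_dispatch = FakeTensorMode.__torch_dispatch__
totals = {"dispatch_s": 0.0}

def _timed_dispatch(self, func, types, args=(), kwargs=None):
    t0 = time.perf_counter()
    try:
        return _orig_dispatch(self, func, types, args, kwargs)
    finally:
        totals["dispatch_s"] += time.perf_counter() - t0

FakeTensorMode.__torch_dispatch__ = _timed_dispatch
```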
You have to run with
I used that. Running:
gives a total time of 338s (so slower), but still only 0.045s spent dispatching with Sym types. Still fairly insignificant. Or I'm measuring it poorly.
Ok, well, I never claimed that you would expect a speedup here :)
The benchmark runs with
I had both (scroll to the right of the given command line in the comment).
When I run:
and add
to FakeTensorMode, I see:
I also see that in the above benchmark without
So it turns out that the perf measurements work a lot better when you store them with
🚀 The feature, motivation and pitch
We have an existing in-memory fake tensor cache which short-circuits running metas and decompositions when a particular op overload is called with inputs whose metadata has already been observed. Fake tensors get cache hits across multiple code paths in torch.compile, and even within the same graph. Enabling it gave a 10% compilation speedup across the HuggingFace dashboard.
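Conceptually, the cache keys on the op overload plus the metadata of its inputs rather than on tensor identity. A simplified sketch (not the actual cache-key code; names are illustrative):
```python
# Simplified illustration of metadata-based cache keying -- not the real
# implementation. Two distinct FakeTensors with identical metadata map to
# the same key, so the cached output metadata can be reused.
import torch

def cache_key(func, args):
    def describe(a):
        if isinstance(a, torch.Tensor):
            return (a.dtype, tuple(a.shape), tuple(a.stride()), str(a.device))
        return a  # non-tensor args participate in the key directly
    return (str(func), tuple(describe(a) for a in args))
```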
We do not currently serve cache hits for tensors with SymInts, but we should. If you run:
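A minimal sketch, assuming a dynamically-shaped add dispatched twice during tracing (the exact repro may differ):
```python
# Hypothetical repro (a sketch, not the exact original): the same add is
# dispatched twice on identical symint-sized fake inputs during tracing,
# so the second fake-tensor dispatch could be served from the cache.
import torch

@torch.compile(dynamic=True)
def f(x, y):
    a = x + y  # first dispatch of aten.add.Tensor on (s0,) + (1,)
    b = x + y  # second dispatch: same op, same inputs, same symint ids
    return a, b

f(torch.randn(8), torch.randn(1))
```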
In the second invocation, where we are running the same exact op with the same inputs and symint ids, any guards that might be added must necessarily have been added in the prior run.
According to @ezyang:
"my main fear is you can't literally cache on torch.SymInt id, since these are not interned. so you need to do a structural hash on the sympy expression itself, which has some cost"
There is an existing symint hasher which might be useful: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/dedupe_symint_uses.py#L9
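Building on that, structural hashing of sym types might look like the following sketch (assuming SymInt.node.expr exposes the underlying sympy expression, which hashes structurally):
```python
# Sketch of structurally hashing sym types for a cache key -- an assumed
# approach, not shipped code. sympy expressions implement structural
# __hash__/__eq__, so equal expressions yield equal keys even though the
# SymInt wrappers themselves are not interned.
import torch

def hash_sym(arg):
    if isinstance(arg, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        return hash(arg.node.expr)  # hash the underlying sympy expression
    return hash(arg)
```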
For this repro, we should only compute
aten.add.Tensor, (FakeTensor(..., size=(s0,)), FakeTensor(..., size=(1,))), {}
once; we currently do it 6 times.
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @aorenste @masnesral
Alternatives
No response
Additional context
No response