
Extend Fake Tensor Caching to Symints #126411

Open
eellison opened this issue May 16, 2024 · 8 comments
Assignees
Labels
oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@eellison
Contributor

eellison commented May 16, 2024

🚀 The feature, motivation and pitch

We have an existing in-memory fake tensor cache which short-circuits running metas and decompositions when a particular op overload has inputs whose metadata has already been observed. Fake tensors get cache hits across multiple code paths in torch.compile, and even within the same graph. Enabling it gave a 10% compilation speedup across the huggingface dashboard.
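For illustration, here is a rough sketch of that caching idea; it is not the actual FakeTensor cache implementation, and cached_meta / run_meta are made-up names:

import torch

# Key the cache on the op overload plus the observable metadata of each
# argument, and skip re-running the meta function / decomposition on a hit.
# The real cache is more involved (bypass conditions, output validation, etc.).
_meta_cache: dict = {}

def _arg_key(a):
    # Tensor args contribute their metadata; everything else contributes its value.
    if isinstance(a, torch.Tensor):
        return ("tensor", tuple(a.shape), a.dtype, a.device, a.layout)
    return ("const", a)

def cached_meta(op, args, run_meta):
    key = (op, tuple(_arg_key(a) for a in args))
    if key not in _meta_cache:
        _meta_cache[key] = run_meta(op, args)  # compute output metadata once
    return _meta_cache[key]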

We do not currently serve cache hits for tensors with symints, but we should. If you run:

torch.ops.aten.add(FakeTensor[2, 3, s0], 2)
torch.ops.aten.add(FakeTensor[2, 3, s0], 2)

In the second invocation, where we are running the exact same op with the same inputs and symint ids, any guards that might be added must already have been added in the prior run.

According to @ezyang:

"my main fear is you can't literally cache on torch.SymInt id, since these are not interned. so you need to do a structural hash on the sympy expression itself, which has some cost"

There is an existing symint hasher which might be useful: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/dedupe_symint_uses.py#L9
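As a minimal sketch of what keying on the sympy expression rather than the SymInt's Python id could look like (assuming v.node.expr exposes the underlying sympy expression; _symint_cache_key is a hypothetical helper, not part of the proposal):

import torch

def _symint_cache_key(v):
    # SymInts are not interned, so two SymInt objects backed by the same
    # expression (e.g. s0) can have different ids. sympy expressions hash
    # structurally, so hashing on v.node.expr makes both map to the same
    # cache key, at the cost of computing that structural hash.
    if isinstance(v, torch.SymInt):
        return ("symint", v.node.expr)
    return ("const", v)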

For this repro:

import torch

@torch.compile(backend="aot_eager", dynamic=True)
def foo(x):
    t = torch.rand([1])
    t2 = torch.rand([1])
    return x + t, x + t2


inp = torch.rand([20])
foo(inp)

We should only compute aten.add.Tensor, (FakeTensor(..., size=(s0,)), FakeTensor(..., size=(1,))), {} once; currently we do it 6 times.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @aorenste , @masnesral

Alternatives

No response

Additional context

No response

@aorenste aorenste self-assigned this May 16, 2024
@eellison eellison added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module oncall: pt2 labels May 16, 2024
@aorenste
Contributor

I'm not sure if this would be worth it. I instrumented the existing cache and ran:

python benchmarks/dynamo/torchbench.py --performance --inference --amp --backend inductor --disable-cudagraphs --device cuda

From what I can tell, the time spent in FakeTensorMode dispatch on ops that were not cached because they contained any kind of sym expr was tiny. Of 151s spent in dispatch, only 0.003s was in calls that bypassed the cache due to containing a SymInt, SymFloat, or SymBool.
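Roughly the kind of accounting this involves, as a sketch (timed_dispatch and dispatch are hypothetical stand-ins, not the actual instrumentation):

import time
import torch

# Accumulate dispatch time, split by whether a call would bypass the cache
# because an argument is a SymInt / SymFloat / SymBool.
totals = {"sym_bypass": 0.0, "cacheable": 0.0}

def timed_dispatch(dispatch, op, args, kwargs):
    has_sym = any(isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool))
                  for a in args)
    start = time.perf_counter()
    out = dispatch(op, args, kwargs)
    totals["sym_bypass" if has_sym else "cacheable"] += time.perf_counter() - start
    return out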

Is there a better benchmark I should use to measure the potential of this change?

@ezyang
Contributor

ezyang commented May 28, 2024

You have to run with --dynamic-shapes --dynamic-batch-size-only to actually trigger dynamic shapes in the benchmark suite

@aorenste
Contributor

You have to run with --dynamic-shapes --dynamic-batch-size-only to actually trigger dynamic shapes in the benchmark suite

I used --dynamic-batch-only because --dynamic-batch-size-only wasn't accepted.

Running:

python benchmarks/dynamo/torchbench.py --performance --inference --amp --backend inductor --disable-cudagraphs --device cuda --dynamic-shapes --dynamic-batch-only

Gives a total time of 338s (so slower), but still only 0.045s spent dispatching with Sym types. Still fairly insignificant, or I'm measuring it poorly.

@ezyang
Contributor

ezyang commented May 28, 2024

ok, well, I never claimed that you would expect a speedup here :)

@eellison
Contributor Author

eellison commented May 28, 2024

The benchmark runs with --dynamic-shapes --dynamic-batch-only, not just --dynamic-batch-only. Maybe you need both?

@aorenste
Contributor

--dynamic-shapes --dynamic-batch-size-only

I had both (scroll to the right of the given command line in the comment).

@eellison
Contributor Author

eellison commented May 28, 2024

When I run:

python benchmarks/dynamo/huggingface.py --performance --training --amp --backend aot_eager --device cuda --only BertForQuestionAnswering --print-compilation-time --dynamic-batch-only

And add

    def __del__(self):
        # Print cache statistics when this FakeTensorMode instance is destroyed.
        print(self.cache_info())

to FakeTensorMode I see:

DispatchCacheInfo(hits=19258, misses=189, bypasses={'symbolic shape': 22386, 'dynamic output shape': 1, 'CompositeImplicitAutograd': 697, 'non-fake tensor': 54, 'non-FakeTensor output': 51}, size=189)

I also see that in the above benchmark, without dynamic-batch-only, disabling the fake tensor cache causes a 5 second slowdown. It's possible you're only looking at sym_type inputs but not fake tensor inputs with symints. About half of the ops are bypassed due to symints, so I would expect a couple seconds of improvement.

@aorenste
Contributor

So it turns out that the perf measurements work a lot better when you store them with += instead of =. Once I do that, the symint stuff pops out as quite a bit more expensive (a significant percentage of the dispatch time for some of the benchmarks).
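A hypothetical illustration of the difference (made-up op name and timings):

samples = [("aten.add.Tensor", 0.004), ("aten.add.Tensor", 0.006)]

timings = {}
for op_name, elapsed in samples:
    # buggy: timings[op_name] = elapsed  -- keeps only the latest sample
    timings[op_name] = timings.get(op_name, 0.0) + elapsed  # accumulates

print(timings)  # total per op, e.g. ~0.01s for aten.add.Tensor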
