Generate: end-to-end compilation #30788
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is great 🔥 🚀. Please ping me once ready - I would like to run some benchmarks! Also, it would be best to have an integration test for at least 2 models, probably
I added [WIP] to the title but didn't change it to draft mode. Hope I am doing it right.
@gante Do you have the generation running time numbers of
@ydshieh I'm not sure it will be faster, the biggest win is in terms of compatibility (compiled graphs are more portable and easily consumed by specialized hardware). I will run and share a few benchmarks in any case 🤗
Nice! If you are OK with that, maybe adopt the following small & simple script for an (even just a first) benchmark.
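The script itself did not survive in this excerpt; below is a minimal sketch of the kind of benchmark being suggested. The model id, prompt, and timing details are my assumptions, not ydshieh's actual script.

```python
# Hypothetical benchmark sketch (not the script referenced above): time `generate`
# for a fixed number of new tokens and report tokens/second.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # assumption: any small decoder-only model works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

# warmup run (matters a lot once compilation is involved)
model.generate(**inputs, max_new_tokens=16, do_sample=False)

torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```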
(force-pushed from 7db06d8 to 9c08cec)
# with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt +
# generated tokens)
# TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
if is_torchdynamo_compiling():
    input_ids = input_ids[:, cache_position]

else:
(the `else` is the original code)
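To make the slicing above concrete: under the compiled path, `input_ids` always holds the full sequence (prompt plus tokens generated so far) and `cache_position` selects what the model should process at the current step. A self-contained illustration with made-up token ids:

```python
import torch

# full sequence buffer: a 5-token prompt plus 2 tokens generated so far (made-up ids)
input_ids = torch.tensor([[101, 102, 103, 104, 105, 201, 202]])

# prefill step: cache_position covers the whole prompt
cache_position = torch.arange(5)
print(input_ids[:, cache_position])  # tensor([[101, 102, 103, 104, 105]])

# a later decoding step: cache_position points at the newest position only
cache_position = torch.tensor([6])
print(input_ids[:, cache_position])  # tensor([[202]])
```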
@ydshieh TL;DR not much faster, at least for now :) (benchmarks in the PR header)
Very cool 🔥
# generated tokens)
# TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
if is_torchdynamo_compiling():
    input_ids = input_ids[:, cache_position]
`if input_ids is not None` or compile doesn't work with embeds yet?
> or compile doesn't work with embeds yet?

this one. end-to-end compilation is very very feature limited, only covering the most common use cases for now
    input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
Just wondering, as I've never understood this part. How can this behavior be triggered? If a user uses a static cache and wants to generate several times with different sequence lengths without re-initializing the cache?
@zucchini-nlp It's possible with `SinkCache` :) and now with the `SlidingWindowCache` too
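For context, a rough sketch of the pattern under discussion: one bounded cache object reused across several `generate` calls, so later calls can run up against the maximum cache length and trigger the cropping above. The exact `SinkCache` arguments and the simplified history handling are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "google/gemma-2b"  # assumption: any causal LM
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# one bounded cache reused across generate() calls; its length never exceeds
# window_length, so later inputs can exceed the maximum cache length
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)

for prompt in ["Tell me a joke.", "Now explain it."]:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # simplified: a real chat loop would also manage conversation history / attention masks
    out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64, do_sample=False)
    print(tokenizer.decode(out[0], skip_special_tokens=True))
```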
(seemingly unrelated CI failures)
What does this PR do?
Closes #29374
This PR introduces an MVP for end-to-end `generate` compilation -- the whole `generate` function can be compiled into a single graph. From this state, we can start testing and improving many aspects of our `generate` code wrt compilation, as well as unlock features that rely on `generate` being a single graph.

`generate` compilation in this PR has MANY restrictions, including:
- `torch.multinomial(probs, num_samples=1)` can't be compiled into a cuda graph, so `do_sample=True` is not compatible without a cuda graph break;
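As a rough illustration of what end-to-end compilation means in practice, a sketch of the intended usage follows; the exact flags and supported `generate` kwargs are my assumptions, not a spec from this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # assumption
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Hello", return_tensors="pt").to("cuda")

# compile the whole generate function into a single graph, instead of only model.forward
compiled_generate = torch.compile(model.generate, mode="reduce-overhead", fullgraph=True)

# the restrictions above apply: e.g. greedy decoding (do_sample=False) and a static cache
out = compiled_generate(
    **inputs,
    do_sample=False,
    cache_implementation="static",
    max_new_tokens=256,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```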
Performance

TL;DR: vs the compiled `forward`, ~5% faster. The main benefit is the ability to export a single graph.

Setup:
- `max_new_tokens=256`
- baseline: compiled `forward` but calling `generate` (i.e. there is some overhead from calling the uncompiled `generate`)
- `google/gemma-2b`
- code

A100

RTX3090
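For reference, a sketch of what the "compiled `forward` but calling `generate`" baseline presumably looks like; the compile flags and model id are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # assumption
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
inputs = tokenizer("Hello", return_tensors="pt").to("cuda")

# baseline: only the forward pass is compiled; generate itself runs in eager mode,
# so its Python-side orchestration (logits processing, stopping criteria, ...) is uncompiled overhead
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=256)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```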