sdp::SDPBackend::flash_attention support PrivateUse1 #126392
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/126392
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 481a3f9 with merge base 0ff2f8b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -14685,6 +14685,11 @@
   CPU: _scaled_dot_product_flash_attention_cpu
   tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_overrideable(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
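For context, here is a hedged sketch of how an out-of-tree (PrivateUse1) backend could override the op declared in the schema above via `TORCH_LIBRARY_IMPL`. The kernel body, the `my_backend_flash_attention` name, and the exact C++ types (`double` for `float`, `std::optional` for `float?`) are illustrative assumptions, not part of this PR; note the review below also asks to rename the op.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#include <optional>
#include <tuple>

// Return type mirrors the 9-output schema in the yaml entry above.
using FlashOut = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor,
                            c10::SymInt, c10::SymInt,
                            at::Tensor, at::Tensor, at::Tensor>;

FlashOut my_backend_flash_attention(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    double dropout_p, bool is_causal, bool return_debug_mask,
    std::optional<double> scale) {
  // A real backend would call its fused kernel here; this is shape-correct filler.
  at::Tensor output = at::empty_like(query);
  at::Tensor logsumexp = at::empty({0}, query.options().dtype(at::kFloat));
  at::Tensor empty = at::empty({0}, query.options());
  return std::make_tuple(output, logsumexp, empty, empty,
                         query.sym_size(-2), key.sym_size(-2),
                         empty, empty, empty);
}

// Register the override for the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_scaled_dot_product_flash_attention_overrideable",
         &my_backend_flash_attention);
}
```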
Let's rename this to '_scaled_dot_product_fused_attention_overrideable'.
ok
@@ -673,7 +677,7 @@ Tensor scaled_dot_product_attention(
        return std::get<0>(out_lse_softmax);
      }
    case sdp::SDPBackend::flash_attention: {
-     if(query_.device().type() == DeviceType::CUDA){
+     if(query_.device().type() == DeviceType::CUDA) {
nit: remove
ok
The current structure of this op looks like:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
| |
| |-- if PrivateUse1:
| | |-- handle_private_use(...)
| |-- else:
| |-- _fused_sdp_choice_stub(...)
|
|-- switch (backend)
|
|-- case cudnn_attention:
| |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
|
|-- case flash_attention:
| |-- if CUDA:
| | |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
| |-- else (CPU):
| |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
|
|-- case efficient_attention:
| |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
|
|-- case math:
| |-- return at::_scaled_dot_product_attention_math(...)
|
|-- default:
|-- TORCH_CHECK(false, "No viable backend found.")
|-- return Tensor()

I spoke with Alban offline about this, and we came to the conclusion that we want this structure:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
| |-- if stub_registered():
| | |-- _fused_sdp_choice_stub(...)
| |-- else:
| |-- use math as choice
|
|-- switch (backend)
|
|-- case cudnn_attention:
| |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
|
|-- case flash_attention:
| |-- if CUDA:
| | |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
| |-- else (CPU):
| |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
|
|-- case efficient_attention:
| |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
|
|-- case overridable:
| |-- return at::_scaled_dot_product_attention_overridable(...)
|
|-- case math:
| |-- return at::_scaled_dot_product_attention_math(...)
|
|-- default:
|-- TORCH_CHECK(false, "No viable backend found.")
|-- return Tensor()

So what does that mean for this PR? The structure looks pretty good. I made some changes in #126832 that should enable this, so once that lands we can land your updates. The dispatching logic for the kernels will then be driven by whether a choice stub is registered for the device.
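To make the agreed flow concrete, here is a small self-contained C++ sketch of the selection logic above, with the ATen calls replaced by stand-ins. The `Device` enum, the stand-in function names, and the printed strings are illustrative only; the real op and enum names come from this PR and #126832.

```cpp
#include <cstdio>
#include <stdexcept>

// Stand-in for at::DeviceType; PrivateUse1 plays the custom-backend role.
enum class Device { CPU, PrivateUse1 };

enum class SDPBackend {
  error, math, flash_attention, efficient_attention,
  cudnn_attention, overridable
};

// Stand-in for the new DispatchStub registration query from #126832.
bool stub_registered(Device d) { return d == Device::PrivateUse1; }

// Stand-in for _fused_sdp_choice_stub(...); a PrivateUse1 author's
// registration would return the overridable backend.
SDPBackend fused_sdp_choice(Device) { return SDPBackend::overridable; }

SDPBackend select_backend(Device d) {
  // New selection rule: use the per-device choice stub when one is
  // registered, otherwise fall back to the math backend.
  return stub_registered(d) ? fused_sdp_choice(d) : SDPBackend::math;
}

int scaled_dot_product_attention(Device d) {
  switch (select_backend(d)) {
    case SDPBackend::cudnn_attention:     return std::puts("cudnn kernel");
    case SDPBackend::flash_attention:     return std::puts("flash kernel");
    case SDPBackend::efficient_attention: return std::puts("mem-efficient kernel");
    case SDPBackend::overridable:         return std::puts("overridable op");
    case SDPBackend::math:                return std::puts("math fallback");
    default: throw std::runtime_error("No viable backend found.");
  }
}

int main() {
  scaled_dot_product_attention(Device::PrivateUse1);  // stub registered -> overridable
  scaled_dot_product_attention(Device::CPU);          // no stub -> math
}
```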
# Summary
Adds a public method to DispatchStub to check if a fn has been registered for a device. We use this new function to clean up the dispatching logic for SDPA, as well as to make the PrivateUse1 dispatching simpler: #126392

Pull Request resolved: #126832
Approved by: https://github.com/ezyang, https://github.com/albanD
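Since the commit message is terse, here is a self-contained toy analogue of the idea: a per-device function table that can be queried for registration before dispatching. This is not PyTorch's actual DispatchStub; the `is_registered` name and everything else here are invented for illustration.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

enum class Device : std::size_t { CPU, CUDA, PrivateUse1, kCount };

using ChoiceFn = int (*)();

// Toy per-device dispatch table with the new "is anything registered?" query.
struct ToyDispatchStub {
  std::array<ChoiceFn, static_cast<std::size_t>(Device::kCount)> table{};

  void register_fn(Device d, ChoiceFn fn) {
    table[static_cast<std::size_t>(d)] = fn;
  }
  // The capability the commit adds: query registration without calling.
  bool is_registered(Device d) const {
    return table[static_cast<std::size_t>(d)] != nullptr;
  }
  int call(Device d) const { return table[static_cast<std::size_t>(d)](); }
};

int private_use_choice() { return 4; }  // e.g. an "overridable" backend id

int main() {
  ToyDispatchStub stub;
  stub.register_fn(Device::PrivateUse1, &private_use_choice);

  if (stub.is_registered(Device::PrivateUse1)) {
    std::printf("choice = %d\n", stub.call(Device::PrivateUse1));
  }
  if (!stub.is_registered(Device::CUDA)) {
    std::printf("nothing registered; fall back to math\n");
  }
}
```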
@drisspg
The PR I referenced above has landed; can you rebase?
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased c276523 to 761a79d.
@drisspg Rebased and fixed some CI issues.
@@ -680,10 +684,15 @@ Tensor scaled_dot_product_attention(
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
    } else if (query_.device().type() == DeviceType::PrivateUse1) {
This doesn't look right to me. It should now just be one more case in the switch; you will need to add the overridable backend:

```cpp
case sdp::SDPBackend::overridable:
  return std::get<0>(at::_scaled_dot_product_attention_overridable(
      ...));
```

Private-use authors would thus register a dispatch to the stub and have it return the overridable backend; by default they would be routed to the math backend.
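A hedged sketch of the registration a PrivateUse1 backend author might add so the selector routes to the overridable backend is below. The `REGISTER_PRIVATEUSE1_DISPATCH` macro name, the header paths, and the choice-function signature are assumptions based on ATen's DispatchStub conventions; check DispatchStub.h and sdp_utils_cpp.h in your PyTorch version for the exact forms.

```cpp
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

#include <optional>

namespace {

int64_t my_fused_sdp_choice(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const std::optional<at::Tensor>& attn_mask, double dropout_p,
    bool is_causal, std::optional<double> scale) {
  // Route everything to the overridable entry point; a real backend could
  // inspect shapes/dtypes here and return math for unsupported cases.
  return static_cast<int64_t>(sdp::SDPBackend::overridable);
}

}  // namespace

// Register the choice function for the PrivateUse1 device key; this is
// assumed to live where _fused_sdp_choice_stub is visible (e.g. in
// namespace at::native).
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &my_fused_sdp_choice);
```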
Fixes #124271
cc @cpuhrsch @drisspg @albanD @soulitzer