sdp::SDPBackend::flash_attention support PrivateUse1 #126392

Open · wants to merge 3 commits into base: main
Conversation

@1274085042 1274085042 commented May 16, 2024


pytorch-bot bot commented May 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126392

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 481a3f9 with merge base 0ff2f8b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@albanD albanD requested review from drisspg and removed request for albanD May 16, 2024 13:25
@@ -14685,6 +14685,11 @@
CPU: _scaled_dot_product_flash_attention_cpu
tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_overrideable(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
Contributor

Let's rename this to `_scaled_dot_product_fused_attention_overrideable`

Contributor Author

ok
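
For context, here is a minimal sketch of how dispatcher-side code could invoke the op declared in the hunk above, assuming the generated C++ API keeps the name and signature from this yaml entry (the rename suggested in this thread would change the name). `call_overrideable_sdpa` is an illustrative wrapper, not code from this PR:

```cpp
// Sketch only: invokes the overrideable fused-attention op from this PR's
// native_functions.yaml entry and keeps just the attention output, mirroring
// how scaled_dot_product_attention unpacks the other fused backends.
#include <ATen/ATen.h>
#include <tuple>

at::Tensor call_overrideable_sdpa(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    double dropout_p = 0.0,
    bool is_causal = false) {
  // The op returns a 9-element tuple:
  // (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
  //  philox_seed, philox_offset, debug_attn_mask).
  auto outputs = at::_scaled_dot_product_flash_attention_overrideable(
      query, key, value, dropout_p, is_causal, /*return_debug_mask=*/false);
  return std::get<0>(outputs);
}
```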

@@ -673,7 +677,7 @@ Tensor scaled_dot_product_attention(
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::flash_attention: {
if(query_.device().type() == DeviceType::CUDA){
if(query_.device().type() == DeviceType::CUDA) {
Contributor

nit: remove

Contributor Author

ok

@1274085042 1274085042 requested a review from drisspg May 21, 2024 02:55
@mikaylagawarecki mikaylagawarecki added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label May 21, 2024
@drisspg drisspg requested a review from jainapurva May 21, 2024 23:59

drisspg commented May 22, 2024

The current structure of this op looks like:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
|    |
|    |-- if PrivateUse1:
|    |      |-- handle_private_use(...)
|    |-- else:
|          |-- _fused_sdp_choice_stub(...)
|
|-- switch (backend)
     |
     |-- case cudnn_attention:
     |      |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
     |
     |-- case flash_attention:
     |      |-- if CUDA:
     |      |      |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
     |      |-- else (CPU):
     |            |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
     |
     |-- case efficient_attention:
     |      |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
     |
     |-- case math:
     |      |-- return at::_scaled_dot_product_attention_math(...)
     |
     |-- default:
            |-- TORCH_CHECK(false, "No viable backend found.")
            |-- return Tensor()

I spoke with Alban offline about this, and we came to the conclusion that we want this structure:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
|    |-- if stub_registered():
|    |      |-- _fused_sdp_choice_stub(...)
|    |-- else:
|          |-- use math as the choice
|
|-- switch (backend)
     |
     |-- case cudnn_attention:
     |      |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
     |
     |-- case flash_attention:
     |      |-- if CUDA:
     |      |      |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
     |      |-- else (CPU):
     |            |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
     |
     |-- case efficient_attention:
     |      |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
     |
     |-- case overridable:
     |      |-- return at::_scaled_dot_product_attention_overridable(...)
     |
     |-- case math:
     |      |-- return at::_scaled_dot_product_attention_math(...)
     |
     |-- default:
            |-- TORCH_CHECK(false, "No viable backend found.")
            |-- return Tensor()

So what does that mean for this PR? The structure looks pretty good. I made some changes here that should enable this, so once that lands we can land your updates: #126832

The dispatching logic for the kernels will be (see the sketch below):

  • The default choice is math: if a device doesn't register a stub, it will get routed to math.
  • If a choice stub is registered, devices have the option to go to an overridable op that this PR provides. That op should do no preprocessing, but it will be run through `validate_sdpa` and have attn_mask converted from bool to float.
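
A self-contained sketch of that flow, using simplified stand-in types (`Backend`, `DeviceType`, and `choice_registry` are illustrative only, not the real `DispatchStub` machinery): a device that registers a choice function gets its pick honored, everything else defaults to math, and the switch gains an overridable case.

```cpp
// Toy model of the proposed selection logic; not the actual aten code.
#include <functional>
#include <stdexcept>
#include <unordered_map>

enum class Backend { math, flash_attention, efficient_attention, cudnn_attention, overridable };
enum class DeviceType { CPU, CUDA, PrivateUse1 };

using ChoiceFn = std::function<Backend()>;

// Stand-in for the per-device _fused_sdp_choice_stub registrations.
std::unordered_map<DeviceType, ChoiceFn>& choice_registry() {
  static std::unordered_map<DeviceType, ChoiceFn> registry;
  return registry;
}

Backend select_backend(DeviceType device) {
  auto it = choice_registry().find(device);
  // A registered stub wins; otherwise the default choice is math.
  return it != choice_registry().end() ? it->second() : Backend::math;
}

const char* run_sdpa(DeviceType device) {
  switch (select_backend(device)) {
    case Backend::overridable:
      return "overridable op";   // what a PrivateUse1 backend would hit
    case Backend::math:
      return "math fallback";
    default:
      throw std::runtime_error("No viable backend found.");
  }
}

int main() {
  // A PrivateUse1 backend registers a choice that routes to "overridable".
  choice_registry()[DeviceType::PrivateUse1] = [] { return Backend::overridable; };
  run_sdpa(DeviceType::CPU);          // no stub registered -> "math fallback"
  run_sdpa(DeviceType::PrivateUse1);  // -> "overridable op"
}
```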

pytorchmergebot pushed a commit that referenced this pull request May 25, 2024
# Summary

Adds a public method to dispatchstub to check if a fn has been registered for a device. We use this new function to clean up the dispatching logic for SDPA, as well as make the private use dispatching simpler:
#126392
Pull Request resolved: #126832
Approved by: https://github.com/ezyang, https://github.com/albanD
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this pull request May 28, 2024
# Summary

Adds a public method to dispatchstub to check if a fn has been registered for a device. We use this new function to clean up the dispatching logic for SDPA, as well as make the private use dispatching simpler:
pytorch#126392
Pull Request resolved: pytorch#126832
Approved by: https://github.com/ezyang, https://github.com/albanD
@1274085042 (Contributor Author) commented:

@drisspg Could this update be landed?


drisspg commented May 29, 2024

The PR I referenced above has landed; can you rebase?

@1274085042 (Contributor Author) commented:

@pytorchmergebot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:

Successfully rebased flash_attention_overrideable onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout flash_attention_overrideable && git pull --rebase)

@1274085042 (Contributor Author) commented:

@drisspg Rebased and fixed some CI issues

@@ -680,10 +684,15 @@ Tensor scaled_dot_product_attention(
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
}
} else if (query_.device().type() == DeviceType::PrivateUse1) {
Contributor

This doesn't look right to me.

It should now just be one more switch case entry. You will need to add the overridable backend:

```cpp
case sdp::SDPBackend::overridable:
  return std::get<0>(at::_scaled_dot_product_attention_overridable(
      ...));
```

Private-use authors would thus register a dispatch to the stub and have it return the overridable backend; by default they would be routed to the math backend.
@drisspg left a comment

Labels: open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects: None yet

Development

Successfully merging this pull request may close these issues.

Add logic about PrivateUse1 in sdp::SDPBackend::flash_attention
6 participants