
Add logic about PrivateUse1 in sdp::SDPBackend::flash_attention #124271

Open
1274085042 opened this issue Apr 17, 2024 · 4 comments · May be fixed by #126392
Assignees
Labels
module: multi-headed-attention · module: PrivateUse1 (private use) · oncall: transformer/mha (Issues related to Transformers and MultiheadAttention) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@1274085042
Contributor

1274085042 commented Apr 17, 2024

🚀 The feature, motivation and pitch

The dispatch logic at https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L675 only covers CUDA and CPU.
Logic for third-party CUDA-like devices (registered as PrivateUse1) should be added so that those devices can adapt scaled_dot_product_attention more easily.
cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

My code snippet

    case sdp::SDPBackend::flash_attention: {
      if (query_.device().type() == DeviceType::CUDA) {
        c10::SymInt og_size = query_.sym_size(-1);
        Tensor query_padded = pad_last_dim<8, false>(query_);
        Tensor key_padded = pad_last_dim<8, false>(key);
        Tensor value_padded = pad_last_dim<8, false>(value);
        // We need to calculate the scale based off the OG head dim size
        auto og_scale = sdp::calculate_scale(query_, scale);
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      } else if (query_.device().type() == DeviceType::PrivateUse1) {
        // Third-party (PrivateUse1) devices dispatch straight to their own fused
        // kernel and do not need the last-dim padding used by the CUDA path.
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
        return std::get<0>(out_lse_softmax);
      } else {
        // For the CPU case we do not need to pad the last dim
        return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu(
            query_, key, value, dropout_p, is_causal, attn_mask, scale));
      }
    }

Alternatives

Additional context

No response

@soulitzer added the triaged, oncall: transformer/mha, module: multi-headed-attention, and module: PrivateUse1 labels on Apr 17, 2024
@albanD
Collaborator

albanD commented Apr 30, 2024

But flash_attention is really a CUDA-specific kernel for implementing this.
Should we just have an SDPBackend::overrideable, like we do with convolution, and you would extend that?

cc @drisspg
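
For context, here is a minimal sketch of how an out-of-tree PrivateUse1 backend could then register its own kernel for such an overrideable entry point through the regular dispatcher. Everything below is an assumption for illustration: the op name, signature, and return values are placeholders, not an existing PyTorch schema.

    #include <ATen/ATen.h>
    #include <torch/library.h>
    #include <tuple>

    // Sketch only: a placeholder kernel for a proposed overrideable SDPA op.
    std::tuple<at::Tensor, at::Tensor> my_sdpa_flash_overrideable(
        const at::Tensor& query,
        const at::Tensor& key,
        const at::Tensor& value,
        double dropout_p,
        bool is_causal,
        bool return_debug_mask,
        c10::optional<double> scale) {
      // A real backend would call its fused attention kernel here and return
      // (output, logsumexp); empty tensors keep the sketch short.
      auto output = at::empty_like(query);
      auto logsumexp = at::empty({0}, query.options());
      return std::make_tuple(output, logsumexp);
    }

    TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
      // This registration only works once the overrideable op is actually
      // defined in native_functions.yaml (assumed here).
      m.impl("_scaled_dot_product_flash_attention_overrideable",
             TORCH_FN(my_sdpa_flash_overrideable));
    }

With that in place, attention.cpp would only need to pick the overrideable backend; the dispatcher routes the call to whichever device registered a kernel for it.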

@drisspg
Contributor

drisspg commented Apr 30, 2024

Yes, I think this is a good idea

@1274085042
Contributor Author

    case sdp::SDPBackend::flash_attention: {
      if (query_.device().type() == DeviceType::CUDA) {
        c10::SymInt og_size = query_.sym_size(-1);
        Tensor query_padded = pad_last_dim<8, false>(query_);
        Tensor key_padded = pad_last_dim<8, false>(key);
        Tensor value_padded = pad_last_dim<8, false>(value);
        // We need to calculate the scale based off the OG head dim size
        auto og_scale = sdp::calculate_scale(query_, scale);
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
      // For the CPU case we do not need to pad the last dim
      return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu(
          query_, key, value, dropout_p, is_causal, attn_mask, scale));
    }
    case sdp::SDPBackend::overrideable: {
      auto out_lse_softmax = at::_scaled_dot_product_flash_attention_overrideable(
          query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::efficient_attention: {
    ... ...

@albanD @drisspg Is this code what you expected?

@drisspg
Contributor

drisspg commented May 20, 2024

Let's actually discuss more here. I don't like that we have to override both the _fused_sdp_choice and then dispatch to an overrideable kernel only for the flash_attention case.

I think the PR you linked is pretty good; the only change would be that we always dispatch to the overrideable case for PrivateUse1 devices.
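
For illustration, a rough sketch of that idea in the backend-selection path (the function name, header path, and the overrideable enum value are assumptions layered on the existing sdp::SDPBackend machinery, not the actual _fused_sdp_choice implementation):

    #include <ATen/ATen.h>
    #include <ATen/native/transformers/sdp_utils_cpp.h>  // sdp::SDPBackend (header path assumed)

    // Sketch: inputs on a PrivateUse1 device always route to the overrideable
    // backend, so the out-of-tree device decides which fused kernel to run.
    sdp::SDPBackend select_sdp_backend_sketch(const at::Tensor& query) {
      if (query.device().type() == at::DeviceType::PrivateUse1) {
        // Assumes an SDPBackend::overrideable value has been added to the enum.
        return sdp::SDPBackend::overrideable;
      }
      // ... the existing CUDA/CPU selection heuristics would run here ...
      return sdp::SDPBackend::math;
    }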

@drisspg drisspg self-assigned this May 20, 2024