
Multihead arch with cutlass fused multihead attention #1976

Draft
wants to merge 70 commits into master

Conversation

almaudoh (Contributor)

CUTLASS implementation of the fused multihead attention layer, giving about a 10% speedup on A100.
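
For context, "fused" here means the whole softmax(QKᵀ/√d)·V chain runs as one kernel, instead of separate GEMM and softmax launches that round-trip the seq_len × seq_len score matrix through global memory. A minimal reference of the per-head math the fused kernel computes (illustrative only, not the CUTLASS code in this PR; names and shapes are assumptions):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference semantics of one attention head: out = softmax(Q K^T / sqrt(d)) V.
// Q, K, V, and out are [seq_len x d], row-major. Shapes are illustrative.
void attention_head_ref(const float* Q, const float* K, const float* V,
                        float* out, int seq_len, int d) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> scores(seq_len);
  for (int i = 0; i < seq_len; ++i) {
    // Row i of the score matrix: Q[i] . K[j] / sqrt(d).
    float max_s = -INFINITY;
    for (int j = 0; j < seq_len; ++j) {
      float s = 0.0f;
      for (int k = 0; k < d; ++k) s += Q[i * d + k] * K[j * d + k];
      scores[j] = s * scale;
      max_s = std::max(max_s, scores[j]);
    }
    // Numerically stable softmax over the row.
    float sum = 0.0f;
    for (int j = 0; j < seq_len; ++j) {
      scores[j] = std::exp(scores[j] - max_s);
      sum += scores[j];
    }
    // Weighted sum of V rows.
    for (int k = 0; k < d; ++k) {
      float acc = 0.0f;
      for (int j = 0; j < seq_len; ++j) acc += scores[j] * V[j * d + k];
      out[i * d + k] = acc / sum;
    }
  }
}
```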

@todo (a sketch of the fusion pattern follows this list):

- Fusing biases
- Fusing layernorms
- Fusing FFNs
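
The items above all follow the same epilogue-fusion pattern. A minimal sketch of what fusing a bias (the first item) means, using illustrative naive kernels rather than the actual CUTLASS epilogue:

```cuda
#include <cuda_runtime.h>

// Unfused: a separate elementwise pass over Y after the GEMM. This costs an
// extra kernel launch plus an extra global-memory round trip of Y.
// Y is [rows x cols] row-major, bias is per column.
__global__ void add_bias(float* y, const float* bias, int rows, int cols) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < rows * cols) y[idx] += bias[idx % cols];
}

// Fused: a naive GEMM whose epilogue applies the bias while the accumulator
// is still in a register, so Y is written to global memory exactly once.
// A is [m x k], B is [k x n], Y is [m x n], all row-major.
__global__ void gemm_bias(const float* a, const float* b, const float* bias,
                          float* y, int m, int n, int k) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= m || col >= n) return;
  float acc = 0.0f;
  for (int i = 0; i < k; ++i) acc += a[row * k + i] * b[i * n + col];
  y[row * n + col] = acc + bias[col];  // epilogue: bias fused here
}
```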

ankan-ban and others added 30 commits March 22, 2022 22:28
- Skip connection add before layer norm now has a scaling factor (alpha); see the sketch after this list.
- Replace the conv layer of the value and MLH heads with an embedding layer when the attention body is used.
- Will be removed once it's fixed.
- Also fix the scratch space calculation: a factor of sizeof(DataType) was missing.
- To handle bigger/wider networks.
- 1.3% improvement in BT2 on RTX 4090.
- 15.6% improvement on a test BT3 network with 64 heads.
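
On the first commit in the list: a minimal sketch of an alpha-scaled skip add feeding layer norm, assuming alpha scales the skip branch (the commit message doesn't pin down which input is scaled; all names here are illustrative, not the PR's actual kernel):

```cuda
#include <cuda_runtime.h>
#include <math.h>

// out = LayerNorm(alpha * skip + x), one thread block per row.
// Launch as: skip_add_layernorm<<<rows, 256, cols * sizeof(float)>>>(...)
__global__ void skip_add_layernorm(const float* x, const float* skip,
                                   const float* gamma, const float* beta,
                                   float* out, int cols, float alpha,
                                   float eps) {
  extern __shared__ float row[];  // one row of (alpha * skip + x)
  int r = blockIdx.x;
  // Scaled skip-connection add, staged in shared memory.
  for (int c = threadIdx.x; c < cols; c += blockDim.x)
    row[c] = alpha * skip[r * cols + c] + x[r * cols + c];
  __syncthreads();
  // Single-thread mean/variance for clarity; a real kernel would tree-reduce.
  __shared__ float mean, rstd;
  if (threadIdx.x == 0) {
    float sum = 0.0f, sq = 0.0f;
    for (int c = 0; c < cols; ++c) { sum += row[c]; sq += row[c] * row[c]; }
    mean = sum / cols;
    rstd = rsqrtf(sq / cols - mean * mean + eps);
  }
  __syncthreads();
  // Normalize and apply the learned scale/shift.
  for (int c = threadIdx.x; c < cols; c += blockDim.x)
    out[r * cols + c] = gamma[c] * (row[c] - mean) * rstd + beta[c];
}
```
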
almaudoh-1 and others added 27 commits September 29, 2023 23:46
almaudoh changed the title from "Multihead arch cuda cutlass fmha" to "Multihead arch cuda cutlass fused multihead attention" on Mar 11, 2024.
almaudoh changed the title from "Multihead arch cuda cutlass fused multihead attention" to "Multihead arch with cutlass fused multihead attention" on Mar 11, 2024.