lmflow.utils.flash_attention.triton_flash_attention#
Experimental implementation of FlashAttention in Triton. Tested with triton==2.0.0.dev20221202. Triton 2.0 has a new backend (MLIR), but it doesn't yet seem to work for head dimensions other than 64 (see openai/triton). We'll update this implementation to the new Triton backend once this is fixed.
We use the FlashAttention implementation from Phil Tillet (openai/triton) as a starting point.
Changes:
- Implement both causal and non-causal attention.
- Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of small batch size * nheads.
Caution:
- This is an experimental implementation. The forward pass should be quite robust, but I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use a headdim other than 64 or 128, you should test for race conditions (due to the Triton compiler), as done in tests/test_flash_attn.py ("test_flash_attn_triton_race_condition"). I've tested and fixed many race conditions for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident that there are none left for other head dimensions.
Differences between this Triton version and the CUDA version:
- The Triton version doesn't support dropout.
- The Triton forward is generally faster than the CUDA forward, while the Triton backward is generally slower than the CUDA backward. Overall, Triton forward + backward is slightly slower than CUDA forward + backward.
- The Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- The Triton version supports attention bias, while the CUDA version doesn't (a workaround that uses the bias to emulate key-padding masks is sketched below).
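Since ragged batches aren't supported, one common workaround is to pad the batch and fold the key-padding mask into the attention bias. The following is a minimal sketch, not part of the module: it assumes fp16/bf16 CUDA tensors (the usual flash-attention requirement) and that FlashAttnFunc.apply (documented below) returns the attention output with the same (batch, seqlen_q, nheads, headdim) layout as q.

```python
# Illustrative sketch (not part of this module): emulate key-padding masking
# with an additive attention bias, since ragged batches aren't supported.
import torch
from lmflow.utils.flash_attention.triton_flash_attention import FlashAttnFunc

batch, seqlen, nheads, headdim = 2, 512, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# True where a key position is real, False where it is padding.
key_lengths = torch.tensor([512, 300], device="cuda")
key_mask = torch.arange(seqlen, device="cuda")[None, :] < key_lengths[:, None]  # (batch, seqlen)

# Fold the mask into an additive bias broadcastable to (batch, nheads, seqlen_q, seqlen_k).
# A large negative value is used instead of -inf to sidestep possible -inf edge
# cases in the kernel (an assumption, not verified against this implementation).
bias = torch.zeros(batch, 1, 1, seqlen, device="cuda", dtype=q.dtype)
bias.masked_fill_(~key_mask[:, None, None, :], -10000.0)

# Positional args: (q, k, v, bias, causal, softmax_scale); None -> default 1/sqrt(headdim).
out = FlashAttnFunc.apply(q, k, v, bias, False, None)
out.sum().backward()  # exercises the Triton backward kernels as well
```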
Attributes#

Classes#
- FlashAttnQKVPackedFunc
- FlashAttnKVPackedFunc
- FlashAttnFunc

Functions#
- _fwd_kernel
- _bwd_preprocess_do_o_dot
- _bwd_store_dk_dv
- _bwd_kernel_one_col_block
- _bwd_kernel
- _flash_attn_forward
- _flash_attn_backward
Module Contents#
- lmflow.utils.flash_attention.triton_flash_attention._fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: triton.language.constexpr, IS_CAUSAL: triton.language.constexpr, BLOCK_HEADDIM: triton.language.constexpr, EVEN_M: triton.language.constexpr, EVEN_N: triton.language.constexpr, EVEN_HEADDIM: triton.language.constexpr, BLOCK_M: triton.language.constexpr, BLOCK_N: triton.language.constexpr)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._bwd_preprocess_do_o_dot(Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: triton.language.constexpr, BLOCK_HEADDIM: triton.language.constexpr)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: triton.language.constexpr, EVEN_N: triton.language.constexpr, EVEN_HEADDIM: triton.language.constexpr)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._bwd_kernel_one_col_block(start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: triton.language.constexpr, BIAS_TYPE: triton.language.constexpr, IS_CAUSAL: triton.language.constexpr, BLOCK_HEADDIM: triton.language.constexpr, EVEN_M: triton.language.constexpr, EVEN_N: triton.language.constexpr, EVEN_HEADDIM: triton.language.constexpr, BLOCK_M: triton.language.constexpr, BLOCK_N: triton.language.constexpr)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._bwd_kernel(Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: triton.language.constexpr, IS_CAUSAL: triton.language.constexpr, BLOCK_HEADDIM: triton.language.constexpr, SEQUENCE_PARALLEL: triton.language.constexpr, EVEN_M: triton.language.constexpr, EVEN_N: triton.language.constexpr, EVEN_HEADDIM: triton.language.constexpr, BLOCK_M: triton.language.constexpr, BLOCK_N: triton.language.constexpr)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None)[source]#
- lmflow.utils.flash_attention.triton_flash_attention._flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None)[source]#
- class lmflow.utils.flash_attention.triton_flash_attention.FlashAttnQKVPackedFunc[source]#
Bases: torch.autograd.Function
- static forward(ctx, qkv, bias=None, causal=False, softmax_scale=None)[source]#
qkv: (batch, seqlen, 3, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen). For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen); an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen, seqlen).
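To make the (1, nheads, 1, seqlen) shape concrete: under causal masking, ALiBi's -slope * (i - j) penalty differs from a key-position-only bias of +slope * j by a per-row constant, which the softmax cancels, so a vector over key positions is enough. A minimal sketch, not part of the module, assuming the usual fp16/bf16 CUDA tensor requirements and the power-of-two ALiBi slope schedule:

```python
# Hedged sketch: packed-QKV self-attention with a causal ALiBi-style bias of
# shape (1, nheads, 1, seqlen), as described in the docstring above.
import torch
from lmflow.utils.flash_attention.triton_flash_attention import FlashAttnQKVPackedFunc

batch, seqlen, nheads, headdim = 4, 1024, 16, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Standard ALiBi slopes for a power-of-two number of heads: 2^(-8*i/nheads).
slopes = 2.0 ** (-8.0 * torch.arange(1, nheads + 1, device="cuda") / nheads)  # (nheads,)
positions = torch.arange(seqlen, device="cuda")                               # (seqlen,)
# +slope * j over key positions is equivalent (per row, up to a constant) to
# ALiBi's -slope * (i - j) once the causal mask is applied.
bias = (slopes[None, :, None, None] * positions[None, None, None, :]).to(qkv.dtype)  # (1, nheads, 1, seqlen)

# Positional args: (qkv, bias, causal, softmax_scale); None -> default 1/sqrt(headdim).
out = FlashAttnQKVPackedFunc.apply(qkv, bias, True, None)
```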
- class lmflow.utils.flash_attention.triton_flash_attention.FlashAttnKVPackedFunc[source]#
Bases: torch.autograd.Function
- static forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None)[source]#
q: (batch, seqlen_q, nheads, headdim)
kv: (batch, seqlen_k, 2, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen_k); an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen_q, seqlen_k).
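A minimal cross-attention sketch for the packed-KV interface, illustrative only; the tensor shapes follow the docstring above, and it assumes the usual fp16/bf16 CUDA tensor requirements:

```python
# Hedged sketch: cross-attention with packed key/value and seqlen_q != seqlen_k.
import torch
from lmflow.utils.flash_attention.triton_flash_attention import FlashAttnKVPackedFunc

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 128, 2048, 16, 64
q = torch.randn(batch, seqlen_q, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
kv = torch.randn(batch, seqlen_k, 2, nheads, headdim,
                 device="cuda", dtype=torch.float16, requires_grad=True)

# Non-causal cross-attention, no bias, default 1/sqrt(headdim) scaling.
out = FlashAttnKVPackedFunc.apply(q, kv, None, False, None)
```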
- class lmflow.utils.flash_attention.triton_flash_attention.FlashAttnFunc[source]#
Bases: torch.autograd.Function
- static forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None)[source]#
q: (batch_size, seqlen_q, nheads, headdim)
k, v: (batch_size, seqlen_k, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, an ALiBi mask for causal attention would have shape (1, nheads, 1, seqlen_k); an ALiBi mask for non-causal attention would have shape (1, nheads, seqlen_q, seqlen_k).
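In the non-causal case the bias must carry the full pairwise position information, hence the (1, nheads, seqlen_q, seqlen_k) shape. A minimal sketch, illustrative only; the symmetric -slope * |i - j| form is one common non-causal ALiBi variant, not something this module prescribes:

```python
# Hedged sketch: unpacked q/k/v with a non-causal, symmetric ALiBi-style bias
# of shape (1, nheads, seqlen_q, seqlen_k).
import torch
from lmflow.utils.flash_attention.triton_flash_attention import FlashAttnFunc

batch, seqlen_q, seqlen_k, nheads, headdim = 2, 256, 256, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)

slopes = 2.0 ** (-8.0 * torch.arange(1, nheads + 1, device="cuda") / nheads)  # (nheads,)
dist = (torch.arange(seqlen_q, device="cuda")[:, None]
        - torch.arange(seqlen_k, device="cuda")[None, :]).abs()               # (seqlen_q, seqlen_k)
bias = (-slopes[:, None, None] * dist[None, :, :]).unsqueeze(0).to(q.dtype)   # (1, nheads, seqlen_q, seqlen_k)

# Non-causal attention with the full matrix bias and default softmax scale.
out = FlashAttnFunc.apply(q, k, v, bias, False, None)
```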