Source code for lmflow.utils.flash_attention.gpt_neo_flash_attention

from typing import List, Optional, Tuple

import torch
import transformers
from einops import rearrange

# Try to import flash_attn 2.x.x; if unavailable, fall back to flash_attn 1.x.x.
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

from flash_attn.bert_padding import unpad_input, pad_input

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    # query/key/value: (batch, head, seq_length, head_features)
    # Flash attention kernels require fp16/bf16 inputs, so cast all three.
    query = query.to(torch.bfloat16)
    key = key.to(torch.bfloat16)
    value = value.to(torch.bfloat16)

    # GPT-Neo does not scale attention scores, while flash_attn applies a default
    # softmax scale of 1/sqrt(head_dim); pre-multiplying the query cancels that out.
    query = query * torch.sqrt(torch.tensor(self.head_dim))

    qkv = torch.stack([query, key, value], dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    bsz = qkv.shape[0]
    q_len = qkv.shape[1]

    if attention_mask is not None:
        # The additive mask uses 0.0 for tokens to keep; convert it to a boolean
        # key-padding mask of shape [bsz, q_len].
        attention_mask = torch.where(attention_mask == -0.0, True, False)
        key_padding_mask = rearrange(attention_mask, "b () () s -> b s")
    else:
        key_padding_mask = None

    if key_padding_mask is None:
        # No padding: flatten the batch and run the packed-qkv kernel directly.
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_unpadded_qkvpacked_func(
            qkv,
            cu_q_lens,
            max_s,
            self.attn_dropout.p if self.training else 0.0,
            softmax_scale=None,
            causal=True,
        )  # attention compute
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        # Padded input: strip padding tokens, run the kernel on the packed tokens,
        # then scatter the output back into the padded layout.
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad,
            cu_q_lens,
            max_s,
            self.attn_dropout.p if self.training else 0.0,
            softmax_scale=None,
            causal=True,
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )

    # Attention weights are not materialized by flash attention, hence None.
    return output, None
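# Illustration (not part of the patch): the varlen/unpadded kernel consumes a packed
# qkv tensor of shape (total_tokens, 3, nheads, head_dim) plus cumulative sequence
# lengths. For example, with bsz=2 and q_len=4 and no padding, `cu_q_lens` above
# evaluates to tensor([0, 4, 8], dtype=torch.int32), i.e. sequence i occupies rows
# cu_q_lens[i]:cu_q_lens[i+1] of the packed qkv.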
def forward(
    self,
    hidden_states,
    attention_mask=None,
    layer_past=None,
    head_mask=None,
    use_cache=False,
    output_attentions=False,
):
    assert head_mask is None, "head_mask is not supported"
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    query = self._split_heads(query, self.num_heads, self.head_dim)
    key = self._split_heads(key, self.num_heads, self.head_dim)
    value = self._split_heads(value, self.num_heads, self.head_dim)

    if layer_past is not None:
        past_key = layer_past[0]
        past_value = layer_past[1]
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)

    present = None

    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

    new_shape = attn_output.size()[:-2] + (self.num_heads * self.head_dim,)
    attn_output = attn_output.view(new_shape)
    attn_output = self.out_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    outputs = (attn_output, present)

    return outputs  # a, present, (attentions)
def replace_gpt_neo_attn_with_flash_attn():
    transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._attn = _attn
    transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention.forward = forward
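if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original module):
    # apply the patch, then load a GPT-Neo checkpoint in bfloat16 on a CUDA device,
    # since the flash attention kernels require fp16/bf16 GPU tensors. The
    # checkpoint name below is only an example.
    from transformers import AutoModelForCausalLM

    replace_gpt_neo_attn_with_flash_attn()
    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/gpt-neo-1.3B", torch_dtype=torch.bfloat16
    ).cuda()
    print("GPT-Neo self-attention patched with flash attention")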