Source code for lmflow.utils.flash_attention.llama_flash_attention

from typing import List, Optional, Tuple

import torch
from torch import nn
import math

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb,_make_causal_mask,_expand_mask

from einops import rearrange

# Prefer the flash-attn 2.x entry point; fall back to the flash-attn 1.x API
# when that name cannot be imported.
# NOTE(review): the 1.x fallback is a packed-QKV function; its call signature
# presumably differs from the 2.x `flash_attn_func` -- confirm that the callers
# in this module actually work with it.
try:
    from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
    # Only a failed import should trigger the fallback; anything else
    # (KeyboardInterrupt, SystemExit, ...) must propagate.
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_func,
    )

from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward pass that delegates the attention kernel to flash-attn.

    Intended to be installed as ``LlamaAttention.forward`` (see
    ``replace_llama_attn_with_flash_attn``), so ``self`` is the attention module.

    Args:
        hidden_states: Input of shape ``(bsz, q_len, hidden)``.
        attention_mask: Accepted for interface compatibility only; it is never
            read here, so padded positions are not masked out.
        position_ids: Forwarded to ``apply_rotary_pos_emb``.
        past_key_value: Optional cached ``(key, value)`` pair that is
            concatenated in front of the freshly projected keys/values along
            the sequence dimension.
        output_attentions: Must be ``False``; attention weights are not
            available from the flash-attention kernel.
        use_cache: When true, the (rotated) keys and values are returned so the
            caller can pass them back as ``past_key_value``.

    Returns:
        ``(attn_output, None, past_key_value)`` where ``attn_output`` has shape
        ``(bsz, q_len, hidden_size)`` and ``past_key_value`` is ``None`` unless
        ``use_cache`` is set.

    Raises:
        ValueError: If ``config.pretraining_tp > 1`` or the kernel output has an
            unexpected shape.
        NotImplementedError: If ``output_attentions`` is requested.
    """
    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        raise ValueError("pretraining_tp > 1 is not supported for flash attention")
    # Reject unsupported options before running projections and the kernel
    # rather than after all the work has been done.
    if output_attentions:
        raise NotImplementedError("`output_attentions` is not supported when `use_flash_attn` is True")

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # The rotary embedding must cover the cached positions as well as the new ones.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # flash-attn is called with (batch, seqlen, nheads, headdim) tensors.
    query_states, key_states, value_states = [
        rearrange(x, "b h s d -> b s h d") for x in [query_states, key_states, value_states]
    ]

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        # Handle the case where the model is quantized
        if hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # A single query token (cached decoding) may attend to every key, so no
    # causal mask is needed. Skipping it also avoids the top-left mask alignment
    # that flash-attn releases before 2.1 apply when q_len != kv_len, which would
    # restrict the new token to the first key only.
    causal = q_len > 1

    # below output will have shape (batch_size, seqlen, nheads, headdim)
    attn_output = flash_attn_func(query_states, key_states, value_states, causal=causal)

    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Attention weights are never materialised on this path.
    attn_weights = None

    return attn_output, attn_weights, past_key_value


# Disable the transformation of the attention mask in LlamaModel, because flash attention
# requires the attention mask to be the same as the key_padding_mask.
[docs] def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): # [bsz, seq_len] if input_shape[-1] > 1 and past_key_values_length == 0: # encode return attention_mask # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask
def replace_llama_attn_with_flash_attn():
    """Monkey-patch the transformers LLaMA classes to use this module's functions.

    Rebinds ``LlamaModel._prepare_decoder_attention_mask`` and
    ``LlamaAttention.forward`` in ``transformers.models.llama.modeling_llama``
    to the replacements defined in this module. The change is applied to the
    classes themselves.
    """
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    llama_module.LlamaAttention.forward = forward