Source code for lmflow.utils.flash_attention.bloom_flash_attention

from typing import List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

import transformers
from transformers.models.bloom.modeling_bloom import dropout_add

from einops import rearrange

from .triton_flash_attention import flash_attn_qkvpacked_func

def forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
):
    dtype = hidden_states.dtype
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape
    bsz, q_len = batch_size, q_length

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        #  - key: [batch_size * self.num_heads, head_dim, kv_length]
        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=2)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    reshaped_alibi = rearrange(alibi, '(b h) one s -> b h one s', h=self.num_heads)
    reshaped_alibi = reshaped_alibi * self.beta

    attention_mask = (1.0 - attention_mask)
    attention_mask = attention_mask[:, None, None, :].bool()
    reshaped_alibi_masked = reshaped_alibi.masked_fill(attention_mask, -1e9)

    reshaped_query_layer = query_layer
    reshaped_key_layer = key_layer
    reshaped_value_layer = value_layer

    qkv = torch.concat(
        [
            reshaped_query_layer.unsqueeze(2),
            reshaped_key_layer.unsqueeze(2),
            reshaped_value_layer.unsqueeze(2),
        ],
        dim=2,
    )
    output = flash_attn_qkvpacked_func(
        qkv, reshaped_alibi_masked, True, self.inv_norm_factor
    )

    output = rearrange(output, 'b s h d -> (b h) s d')

    # change view [batch_size, num_heads, q_length, head_dim]
    context_layer = self._merge_heads(output)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (context_layer,)

    return outputs
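# Illustrative shape sketch (not part of the original module; the helper name and sizes
# below are hypothetical): the packing above turns three [batch, seq, heads, head_dim]
# tensors into one [batch, seq, 3, heads, head_dim] tensor, the layout that
# flash_attn_qkvpacked_func consumes, while the alibi bias is reshaped from
# [(batch * heads), 1, seq] to [batch, heads, 1, seq] and padded positions are filled
# with -1e9 so they are ignored by the attention kernel.
def _qkv_packing_shape_demo():
    b, s, h, d = 2, 5, 4, 8
    q = torch.randn(b, s, h, d)
    k = torch.randn(b, s, h, d)
    v = torch.randn(b, s, h, d)
    qkv = torch.concat([q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)], dim=2)
    assert qkv.shape == (b, s, 3, h, d)

    alibi = torch.zeros(b * h, 1, s)
    padding_mask = torch.ones(b, s)   # 1 = real token, 0 = padding
    padding_mask[:, -1] = 0           # pretend the last position of each row is padding
    reshaped_alibi = rearrange(alibi, '(b h) one s -> b h one s', h=h)
    masked_alibi = reshaped_alibi.masked_fill(
        (1.0 - padding_mask)[:, None, None, :].bool(), -1e9
    )
    assert masked_alibi.shape == (b, h, 1, s)
    return qkv, masked_alibi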
# Disable the transformation of the attention mask in BloomModel, as flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_attn_mask(
    self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
    return attention_mask
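# Minimal pass-through sketch (hypothetical helper, not part of the original module):
# with this override the mask handed to the attention layer stays in its 2D
# key_padding_mask form, [batch, kv_length] with 1 for real tokens and 0 for padding,
# rather than being expanded into the 4D causal mask the stock BloomModel would
# normally build; the patched forward() above then does its own masking on the alibi bias.
def _prepare_attn_mask_passthrough_demo():
    padding_mask = torch.tensor([[1, 1, 1, 0],
                                 [1, 1, 0, 0]])
    out = _prepare_attn_mask(None, padding_mask, input_shape=(2, 4), past_key_values_length=0)
    assert out is padding_mask  # identity: the 2D mask comes back unchanged
    return out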
def replace_bloom_attn_with_flash_attn():
    transformers.models.bloom.modeling_bloom.BloomModel._prepare_attn_mask = (
        _prepare_attn_mask
    )
    transformers.models.bloom.modeling_bloom.BloomAttention.forward = forward
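# Usage sketch (illustrative, not part of the original module): the patch should be
# applied before the model is instantiated so that BloomAttention.forward and
# BloomModel._prepare_attn_mask are already replaced when the layers are built.
# The checkpoint name is only an example, and the triton kernel assumes a CUDA
# device with half-precision weights.
def _example_usage():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    replace_bloom_attn_with_flash_attn()
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype=torch.float16
    ).cuda()
    inputs = tokenizer("Hello, world!", return_tensors="pt").to("cuda")
    return model(**inputs)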