lmflow.utils.flash_attention.bloom_flash_attention

Functions

forward(self, hidden_states, residual, alibi, ...[, ...])

_prepare_attn_mask(→ torch.BoolTensor)

replace_bloom_attn_with_flash_attn()

Module Contents

lmflow.utils.flash_attention.bloom_flash_attention.forward(self, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, layer_past: Tuple[torch.Tensor, torch.Tensor] | None = None, head_mask: torch.Tensor | None = None, use_cache: bool = False, output_attentions: bool = False) [source]
lmflow.utils.flash_attention.bloom_flash_attention._prepare_attn_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int) → torch.BoolTensor [source]
lmflow.utils.flash_attention.bloom_flash_attention.replace_bloom_attn_with_flash_attn() [source]