lmflow.utils.flash_attention.gpt2_flash_attention#

Functions#

forward(...) → Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: flash-attention replacement for GPT2Attention.forward.

_prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): prepares the decoder attention mask used by the patched attention.

replace_gpt2_attn_with_flash_attn(): monkey-patches transformers' GPT-2 attention with the flash-attention forward above.

Module Contents#

lmflow.utils.flash_attention.gpt2_flash_attention.forward(self, hidden_states: Tuple[torch.FloatTensor] | None, layer_past: Tuple[torch.Tensor] | None = None, attention_mask: torch.FloatTensor | None = None, head_mask: torch.FloatTensor | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.FloatTensor | None = None, use_cache: bool | None = False, output_attentions: bool | None = False) → Tuple[torch.Tensor | Tuple[torch.Tensor], ...][source]#
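Drop-in replacement for GPT2Attention.forward that routes the attention computation through a flash-attention kernel instead of the standard softmax attention. A minimal sketch of the idea, assuming flash-attn 2.x's flash_attn_func as the kernel entry point (the module itself may call a different flash-attn API, and the helper name here is illustrative):

```python
import torch
from flash_attn import flash_attn_func  # assumption: flash-attn 2.x is installed

def flash_forward_sketch(attn_module, hidden_states):
    """Illustrative sketch: run a GPT2Attention module's projections through flash attention."""
    # hidden_states: (batch, seq_len, hidden_dim); kernels require fp16/bf16 on CUDA.
    bsz, seq_len, _ = hidden_states.size()
    num_heads, head_dim = attn_module.num_heads, attn_module.head_dim

    # GPT-2 computes Q, K, V with one fused Conv1D (c_attn), then splits the result.
    qkv = attn_module.c_attn(hidden_states)
    q, k, v = qkv.split(attn_module.split_size, dim=2)

    # flash_attn_func expects (batch, seq_len, num_heads, head_dim).
    q = q.view(bsz, seq_len, num_heads, head_dim)
    k = k.view(bsz, seq_len, num_heads, head_dim)
    v = v.view(bsz, seq_len, num_heads, head_dim)

    # causal=True reproduces GPT-2's autoregressive mask inside the kernel,
    # so no explicit attention-mask tensor is materialized.
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

    # Merge heads and apply the output projection.
    out = out.reshape(bsz, seq_len, num_heads * head_dim)
    return attn_module.c_proj(out)
```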
lmflow.utils.flash_attention.gpt2_flash_attention._prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length)[source]#
lmflow.utils.flash_attention.gpt2_flash_attention.replace_gpt2_attn_with_flash_attn()[source]#
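The patch is typically applied before the model is instantiated, so that GPT2Attention.forward already points at the flash-attention implementation when the weights are loaded. A minimal usage sketch, assuming a standard transformers loading flow (model name and dtype are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmflow.utils.flash_attention.gpt2_flash_attention import (
    replace_gpt2_attn_with_flash_attn,
)

# Patch GPT-2 attention with the flash-attention forward before loading the model.
replace_gpt2_attn_with_flash_attn()

# Flash-attention kernels run on CUDA in fp16/bf16.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", torch_dtype=torch.float16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer(
    "Flash attention makes long contexts cheaper.", return_tensors="pt"
).to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)
```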