lmflow.utils.flash_attention.llama_flash_attention#

Functions#

forward(self, hidden_states, ...) → Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]

_prepare_decoder_attention_mask(self, attention_mask, ...)

replace_llama_attn_with_flash_attn()

Module Contents#

lmflow.utils.flash_attention.llama_flash_attention.forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_value: Tuple[torch.Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False) → Tuple[torch.Tensor, torch.Tensor | None, Tuple[torch.Tensor] | None]
lmflow.utils.flash_attention.llama_flash_attention._prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length)
lmflow.utils.flash_attention.llama_flash_attention.replace_llama_attn_with_flash_attn()