lmflow.pipeline.utils.lisa_trainer_fsdp
Attributes
Classes
Module Contents
- class lmflow.pipeline.utils.lisa_trainer_fsdp.LISATrainer(n_layers: int, interval_steps: int, lisa_layer_attr_name: str = None, *args, **kwargs)
Bases: transformers.Trainer
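The `n_layers` and `interval_steps` parameters suggest a LISA-style schedule in which only a few body layers are trainable at a time and the active set is resampled periodically. The sketch below assumes uniform sampling without replacement at the start of every `interval_steps` window; the helpers `sample_active_layers` and `active_layers_schedule` are hypothetical and are not part of LMFlow.

```python
import random
from typing import Iterator, List, Tuple


def sample_active_layers(num_body_layers: int, n_layers: int, rng: random.Random) -> List[int]:
    """Choose which body layers are trainable for the next interval.

    Hypothetical helper: assumes uniform sampling without replacement.
    """
    if not 0 < n_layers <= num_body_layers:
        raise ValueError("n_layers must be between 1 and num_body_layers")
    return sorted(rng.sample(range(num_body_layers), n_layers))


def active_layers_schedule(
    num_body_layers: int,
    n_layers: int,
    interval_steps: int,
    total_steps: int,
    seed: int = 0,
) -> Iterator[Tuple[int, List[int]]]:
    """Yield (step, active layer indices), resampling every interval_steps steps."""
    rng = random.Random(seed)
    active: List[int] = []
    for step in range(total_steps):
        if step % interval_steps == 0:
            # Start of a new interval: draw a fresh set of trainable layers.
            active = sample_active_layers(num_body_layers, n_layers, rng)
        yield step, active


# Usage: 8 body layers, 2 trainable at a time, resampled every 3 steps.
for step, layers in active_layers_schedule(num_body_layers=8, n_layers=2, interval_steps=3, total_steps=6):
    print(step, layers)
```

In an actual trainer, the sampled indices would decide which layers keep `requires_grad=True` while the others stay frozen; this sketch models only the schedule, not the freezing or optimizer handling.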
- _get_all_body_layers() → List[torch.nn.Module]
Fetch all the layers of the model, excluding the head.
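One plausible way to perform this lookup, assuming `lisa_layer_attr_name` holds a dotted attribute path to the model's list of blocks (for Hugging Face LLaMA causal-LM models, the decoder blocks live at `model.layers`), is to walk the path with `getattr`. `get_body_layers` is a hypothetical helper, and a toy `SimpleNamespace` stands in for a real `torch.nn.Module`.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, List


def get_body_layers(model: Any, layer_attr_name: str) -> List[Any]:
    """Return the blocks found at a dotted attribute path such as "model.layers".

    Hypothetical helper: the LM head sits at a different attribute
    (e.g. lm_head), so it is not included in the returned list.
    """
    # reduce(getattr, ["model", "layers"], m) == getattr(getattr(m, "model"), "layers")
    layers = reduce(getattr, layer_attr_name.split("."), model)
    return list(layers)


# Toy model with three body blocks and a separate head.
toy = SimpleNamespace(
    model=SimpleNamespace(layers=["block0", "block1", "block2"]),
    lm_head="head",
)
print(get_body_layers(toy, "model.layers"))
```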