lmflow.pipeline.utils.lisa_trainer_del
Module Contents
- class lmflow.pipeline.utils.lisa_trainer_del.LISATrainer(n_layers: int, interval_steps: int, lisa_layer_attr_name: str = None, *args, **kwargs)
Bases: transformers.Trainer
- _get_all_body_layers() -> List[torch.nn.Module]
Fetch all the layers of the model, excluding the head.
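One way such a lookup could work is to resolve a dotted attribute path on the model and return whatever container it points to as a list. The sketch below is a dependency-free illustration, not the LMFlow implementation: the helper name `get_all_body_layers`, the path `"model.layers"`, and the toy model built from `SimpleNamespace` are all assumptions made for the example.

```python
from functools import reduce
from types import SimpleNamespace


def get_all_body_layers(model, layer_attr_name):
    """Resolve a dotted attribute path on `model` and return the layers as a list.

    Hypothetical helper: `layer_attr_name` (e.g. "model.layers") is assumed to
    point at an iterable container that holds only the body layers.
    """
    container = reduce(getattr, layer_attr_name.split("."), model)
    return list(container)


# Toy stand-in for a decoder-only model: the body layers live under
# `model.layers`, while the head is stored separately and is never returned.
toy_model = SimpleNamespace(
    model=SimpleNamespace(layers=["block_0", "block_1", "block_2"]),
    lm_head="head",
)

body_layers = get_all_body_layers(toy_model, "model.layers")
```

With a real `torch.nn.Module`, the same path resolution would typically land on an `nn.ModuleList`, which is also iterable.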
- _switch_active_layers()
Switch the active layers for the next interval. Calling this method updates:
1. self.active_layers_indices
2. self.active_layers_names
3. requires_grad of the parameters
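The three updates listed above can be sketched as a single function: freeze every body layer, sample `n_layers` indices without replacement, then re-enable gradients only for the sampled layers. This is a minimal, dependency-free sketch under stated assumptions, not the library's code: `ToyParam`, `ToyLayer`, `switch_active_layers`, the uniform sampling, and the `"<attr_name>.<index>"` naming scheme are all hypothetical. The function is duck-typed, so any layer object exposing `parameters()` whose items carry a `requires_grad` flag should work the same way.

```python
import random


class ToyParam:
    """Stand-in for a parameter; only carries the `requires_grad` flag."""

    def __init__(self):
        self.requires_grad = True


class ToyLayer:
    """Stand-in for a layer that exposes `parameters()` like a torch module."""

    def __init__(self, n_params=2):
        self._params = [ToyParam() for _ in range(n_params)]

    def parameters(self):
        return iter(self._params)


def switch_active_layers(layers, n_layers, layer_attr_name, rng=random):
    """Freeze all layers, then unfreeze `n_layers` randomly chosen ones.

    Returns the chosen indices and the matching names (assumed naming scheme).
    """
    # Step 1: disable gradients everywhere.
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad = False

    # Step 2: sample distinct layer indices for the next interval.
    indices = sorted(rng.sample(range(len(layers)), n_layers))
    names = [f"{layer_attr_name}.{i}" for i in indices]

    # Step 3: re-enable gradients only for the sampled layers.
    for i in indices:
        for param in layers[i].parameters():
            param.requires_grad = True

    return indices, names


layers = [ToyLayer() for _ in range(8)]
active_indices, active_names = switch_active_layers(
    layers, n_layers=2, layer_attr_name="model.layers", rng=random.Random(0)
)
```

In a trainer, a function like this would presumably be called once every `interval_steps` optimizer steps, so that a different subset of layers is trained in each interval.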