Source code for lmflow.pipeline.utils.lisa_trainer
from typing import Optional
import numpy as np
from transformers import PreTrainedModel
from transformers.trainer_callback import TrainerCallback
class DynamicLayerActivationCallback(TrainerCallback):
    """Trainer callback implementing LISA-style layerwise training: all model
    layers are frozen, and every `interval_steps` optimizer steps a random
    subset of `n_layers` layers is made trainable again."""

    def __init__(
        self,
        n_layers: int,
        interval_steps: int,
        model: PreTrainedModel,
        lisa_layers_attribute: Optional[str] = None,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.interval_steps = interval_steps
        self.model = model  # Store the model; the `self.model...` attribute paths below depend on it
        # Determine the way to access layers based on the model type
        class_to_layers_map = {
            'LlamaForCausalLM': 'model.model.layers',
            'Qwen2ForCausalLM': 'model.model.layers',
            'MistralForCausalLM': 'model.model.layers',
            'MixtralForCausalLM': 'model.model.layers',
            'GemmaForCausalLM': 'model.model.layers',
            'GPT2LMHeadModel': 'model.transformer.h',
            'HymbaForCausalLM': 'model.model.layers',
        }
        model_class_name = self.model.__class__.__name__
        if model_class_name in class_to_layers_map:
            self.layers_attribute = class_to_layers_map[model_class_name]
        else:
            assert lisa_layers_attribute is not None, (
                "Please provide the attribute to access the layers of the model."
            )
            self.layers_attribute = lisa_layers_attribute
        self.total_layers = len(eval('self.' + self.layers_attribute))  # Dynamically execute to get the number of layers
        self.active_layers_indices = []
    def freeze_all_layers(self):
        layers = eval('self.' + self.layers_attribute)  # Dynamically execute to get layers
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False
    def on_step_begin(self, args, state, control, **kwargs):
        # Check if it's time to switch active layers, including at step 0
        if state.global_step % self.interval_steps == 0:
            self.switch_active_layers()
    def switch_active_layers(self):
        # First, disable gradients for all layers
        self.freeze_all_layers()

        # Randomly select n_layers to activate
        layers = eval('self.' + self.layers_attribute)  # Re-fetch layer references
        self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
        print(f"Activating layers at indices: {self.active_layers_indices} for the next steps.", flush=True)

        # Enable gradients only for the selected layers
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True
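
A minimal usage sketch, not part of the original module: it assumes a supported causal LM and a standard transformers Trainer, and the checkpoint name, hyperparameter values, and output_dir are illustrative placeholders.

# --- Usage sketch (illustrative, not part of lmflow) ---
# Attach the callback to a standard transformers Trainer; `on_step_begin`
# fires with global_step == 0, so the first random subset of layers is
# activated before the first optimizer step.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint

lisa_callback = DynamicLayerActivationCallback(
    n_layers=2,          # keep 2 layers trainable at a time (placeholder value)
    interval_steps=20,    # re-sample the active layers every 20 steps (placeholder value)
    model=model,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),  # plus dataset/collator arguments as needed
    callbacks=[lisa_callback],
)
# trainer.train()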