import gc
import logging
import time
from typing import Union, List
import numpy as np
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

logger = logging.getLogger(__name__)

# Record the CUDA allocation history so memory snapshots can be dumped during training (debug aid).
torch.cuda.memory._record_memory_history(max_entries=100000)

# Maps a model class name to the dotted attribute path of its decoder (body) layers.
LISA_LAYER_NAME_MAPPING = {
    'LlamaForCausalLM': 'model.layers',
    'Qwen2ForCausalLM': 'model.layers',
    'MistralForCausalLM': 'model.layers',
    'MixtralForCausalLM': 'model.layers',
    'GemmaForCausalLM': 'model.layers',
    'GPT2LMHeadModel': 'transformer.h',
}
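
# Illustrative note (not part of the original mapping): for an architecture that is not listed
# above, pass the dotted path of its decoder blocks explicitly when constructing the trainer, e.g.
#     LISATrainer(n_layers=2, interval_steps=20, lisa_layer_attr_name='transformer.blocks', ...)
# where 'transformer.blocks' is a hypothetical path, not a value taken from this module.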

# Indices of the optimizer param groups that hold the currently active body layers
# (see `create_optimizer` and `maybe_switch_active_layers`).
LISA_BODY_LAYER_PARAM_GROUPS_IDX = [2, 3]


class LISATrainer(Trainer):

    def __init__(
        self,
        n_layers: int,
        interval_steps: int,
        lisa_layer_attr_name: str = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        setattr(self.args, '_trainer', self)  # make the trainer (and its attributes) accessible to callbacks via `self.args`

        # LISA-specific attributes
        self.n_layers = n_layers
        self.interval_steps = interval_steps

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        model_class_name = opt_model.__class__.__name__
        if model_class_name in LISA_LAYER_NAME_MAPPING:
            self.lisa_layer_attr_name = LISA_LAYER_NAME_MAPPING[model_class_name]
        else:
            assert lisa_layer_attr_name is not None, "Please provide the attribute name for the model layers."
            self.lisa_layer_attr_name = lisa_layer_attr_name

        self.num_body_layers = len(self._get_all_body_layers())
        self.active_layers_indices = []
        self.history_layers_indices = []
        self.active_layers_names = []

    def _get_all_body_layers(self) -> List[nn.Module]:
        """Fetch all body layers of the model, excluding the head."""
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        # Resolve the dotted attribute path (e.g. 'model.layers') to the ModuleList of decoder blocks.
        layers = opt_model.get_submodule(self.lisa_layer_attr_name)
        return layers
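
    # For example, with a LlamaForCausalLM the resolved path is 'model.layers', so the call above
    # returns the nn.ModuleList of decoder blocks; modules outside that path (embeddings, final
    # norm, LM head) are never frozen by the switching logic below.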

    def _get_active_layers_names(self) -> List[str]:
        if not hasattr(self, 'active_layers_indices'):
            return []
        all_names = []
        layers = self._get_all_body_layers()
        for idx in self.active_layers_indices:
            for name, _ in layers[idx].named_parameters():
                all_names.append(f"{self.lisa_layer_attr_name}.{idx}.{name}")
        return all_names

    def _update_active_layer_info(self):
        # self.active_layers_indices = [3, 4] if self.active_layers_indices == [1, 2] else [1, 2]
        # self.active_layers_indices = [1, 2]
        self.active_layers_indices = np.random.choice(range(self.num_body_layers), self.n_layers, replace=False)
        self.history_layers_indices.append(list(self.active_layers_indices))
        # self.active_layers_indices.sort()
        self.active_layers_names = self._get_active_layers_names()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), flush=True)
        print(f"History of layers: {self.history_layers_indices[:-1]}", flush=True)
        print(f"Layers for the next steps: {self.active_layers_indices}: {self.active_layers_names}", flush=True)

    def _switch_active_layers(self):
        """
        Switch the active layers for the next interval. Objects updated by this call:
        1. self.active_layers_indices
        2. self.active_layers_names
        3. requires_grad of the parameters
        """
        # Disable gradients for all body layers
        layers = self._get_all_body_layers()
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False

        # Randomly select n_layers to activate
        self._update_active_layer_info()  # update the active layer names and indices

        # Enable gradients only for the selected layers
        layers = self._get_all_body_layers()  # re-fetch layer references
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True

    def maybe_switch_active_layers(self):
        """Re-sample the active body layers every `interval_steps` optimizer steps.

        Drops the optimizer state of the previously active layers and repoints the body-layer
        parameter groups at the newly selected layers. Intended to be called once per training
        step, e.g. from a callback that reaches the trainer through `args._trainer`.
        """
        if (
            self.state.global_step == 0  # skip since already initialized in `create_optimizer`
            or self.state.global_step % self.interval_steps != 0
        ):
            return

        # Release the optimizer state (e.g. Adam moments) of the layers that are about to be frozen.
        layers = self._get_all_body_layers()
        for active_layer_idx in self.active_layers_indices:
            for name, param in layers[active_layer_idx].named_parameters():
                print(f"{name=}")
                del self.optimizer.state[param]

        self._switch_active_layers()

        # Repoint the optimizer param groups so that the newly selected layers get their state
        # initialized in optimizer.step(). Groups 2 and 3 (LISA_BODY_LAYER_PARAM_GROUPS_IDX) hold
        # the active body layers with and without weight decay, respectively.
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        decay_parameters = self.get_decay_parameter_names(opt_model)
        self.optimizer.param_groups[2]['params'] = [
            p for n, p in opt_model.named_parameters()
            if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
        ]
        self.optimizer.param_groups[3]['params'] = [
            p for n, p in opt_model.named_parameters()
            if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
        ]

        # Debug aid: dump CUDA memory snapshots for the first 20 steps (pairs with the
        # _record_memory_history call at module import).
        if self.state.global_step <= 20:
            torch.cuda.memory._dump_snapshot(f'gs_{self.state.global_step}.pickle')

    def create_optimizer(self):
        """
        Set up the optimizer. Adapted from transformers.Trainer.create_optimizer.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            self._switch_active_layers()  # initialize the active layers along with the optimizer
            decay_parameters = self.get_decay_parameter_names(opt_model)
            optimizer_grouped_parameters = [
                {
                    # This group should always be the LM head:
                    # `requires_grad` and `not in active_layers_names` rule out all body layers;
                    # `in decay_parameters` rules out layer norms.
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # Layer norms outside of the body layers.
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
                {
                    # Selected body layers, with weight decay.
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # Selected body layers, without weight decay.
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it was created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it was created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with
            # `optimizer_dict` to avoid argument conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
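
# Illustrative usage sketch (not part of the original module). It assumes `model`, `training_args`,
# and `train_dataset` are created by the surrounding training script, and that the periodic layer
# switching is driven by a callback that reaches the trainer through the `args._trainer` handle
# attached in `LISATrainer.__init__`:
#
#     from transformers import TrainerCallback
#
#     class LISASwitchCallback(TrainerCallback):
#         def on_step_begin(self, args, state, control, **kwargs):
#             args._trainer.maybe_switch_active_layers()
#
#     trainer = LISATrainer(
#         n_layers=2,           # number of body layers trained at a time
#         interval_steps=20,    # re-sample the active layers every 20 steps
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#     )
#     trainer.add_callback(LISASwitchCallback())
#     trainer.train()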