Source code for lmflow.pipeline.utils.lisa_trainer_cache

import gc
import logging
import time
from collections import defaultdict
from typing import Union, List, DefaultDict, Any

import numpy as np
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled

from lmflow.utils.debug.common import timer


from deepspeed import comm as dist
from deepspeed.runtime.utils import empty_cache, see_memory_usage
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.accelerator import get_accelerator

if is_sagemaker_mp_enabled():
    # `smp` is referenced in `create_optimizer` when SageMaker model parallelism is enabled.
    import smdistributed.modelparallel.torch as smp

logger = logging.getLogger(__name__)
torch.cuda.memory._record_memory_history(max_entries=100000)
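# Maps a CausalLM class name to the attribute path of its transformer body layers
# (the nn.ModuleList that LISA samples trainable layers from).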
LISA_LAYER_NAME_MAPPING = {
    'LlamaForCausalLM': 'model.layers',
    'Qwen2ForCausalLM': 'model.layers',
    'MistralForCausalLM': 'model.layers',
    'MixtralForCausalLM': 'model.layers',
    'GemmaForCausalLM': 'model.layers',
    'GPT2LMHeadModel': 'transformer.h',
}
LISA_BODY_LAYER_PARAM_GROUPS_IDX = [0, 1]
NON_LISA_LAYER_PARAM_GROUPS_IDX = [2, 3]
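# Assumed param-group layout, matching `_prepare_optimizer_param_group` below:
#   [0] active body layers with weight decay, [1] active body layers without decay,
#   [2] non-body decayed params (e.g. lm_head), [3] non-body non-decayed params (e.g. layer norms).
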
class LISATrainer(Trainer):
    def __init__(
        self,
        n_layers: int,
        interval_steps: int,
        lisa_layer_attr_name: str = None,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        setattr(self.args, '_trainer', self)  # expose the trainer to its callbacks through the training args

        # LISA-specific attributes
        self.n_layers = n_layers
        self.interval_steps = interval_steps

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        model_class_name = opt_model.__class__.__name__
        if model_class_name in LISA_LAYER_NAME_MAPPING:
            self.lisa_layer_attr_name = LISA_LAYER_NAME_MAPPING[model_class_name]
        else:
            assert lisa_layer_attr_name is not None, "Please provide the attribute name for the model layers."
            self.lisa_layer_attr_name = lisa_layer_attr_name

        self.num_body_layers = len(self._get_all_body_layers())
        self.active_layers_indices = []
        self.history_layers_indices = []
        self.active_layers_names = []
        self._optimizer_param_group_initialized = False

    def _get_all_body_layers(self) -> List[nn.Module]:
        '''Fetch all the layers of the model, excluding the head.'''
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        layers = eval('opt_model.' + self.lisa_layer_attr_name)
        return layers

    def _get_active_layers_names(self) -> List[str]:
        if not hasattr(self, 'active_layers_indices'):
            return []

        all_names = []
        layers = self._get_all_body_layers()
        for idx in self.active_layers_indices:
            for name, _ in layers[idx].named_parameters():
                all_names.append(f"{self.lisa_layer_attr_name}.{idx}.{name}")

        return all_names

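    # For reference: with `LlamaForCausalLM` and active indices [3, 7], the method above yields
    # fully qualified names such as 'model.layers.3.self_attn.q_proj.weight' and
    # 'model.layers.7.mlp.gate_proj.weight', which are later matched against `named_parameters()`.
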
    def _update_active_layer_info(self):
        # self.active_layers_indices = [3, 4] if self.active_layers_indices == [1, 2] else [1, 2]
        # self.active_layers_indices = [1, 2]
        self.active_layers_indices = np.random.choice(range(self.num_body_layers), self.n_layers, replace=False)
        self.history_layers_indices.append(list(self.active_layers_indices))
        # self.active_layers_indices.sort()
        self.active_layers_names = self._get_active_layers_names()

        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), flush=True)
        print(f"History of layers: {self.history_layers_indices[:-1]}", flush=True)
        print(f"Layers for the next steps: {self.active_layers_indices}: {self.active_layers_names}", flush=True)

    def _switch_active_layers(self):
        '''
        Switch the active layers for the next interval. Objects that will be updated after calling:
        1. self.active_layers_indices
        2. self.active_layers_names
        3. requires_grad of the parameters
        '''
        # Disable gradients for all layers
        layers = self._get_all_body_layers()
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False

        # Randomly select n_layers to activate
        self._update_active_layer_info()  # update active names and indices

        # Enable gradients only for the selected layers
        layers = self._get_all_body_layers()  # Re-fetch layer references
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True

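    # After `_switch_active_layers`, only the `n_layers` freshly sampled body layers require
    # gradients; parameters outside the body layers (e.g. lm_head, embeddings, final norm) are
    # left untouched, so their `requires_grad` flags keep whatever values they already had.
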
    def maybe_switch_active_layers(self):
        # if self.state.global_step == 0:
        #     torch.cuda.memory._dump_snapshot(f'gs_{self.state.global_step}.pickle')
        if (
            self.state.global_step == 0  # skip since already initialized in `create_optimizer`
            or self.state.global_step % self.interval_steps != 0
        ):
            return

        # Cache the param groups that don't need to switch (lm_head, ln)
        non_lisa_param_groups = [self.optimizer.param_groups[i] for i in NON_LISA_LAYER_PARAM_GROUPS_IDX]

        # Cache the optimizer states of the non-LISA layers
        non_lisa_states: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        for pg in non_lisa_param_groups:
            for param in pg['params']:
                non_lisa_states[param] = self.optimizer.state[param]

        # Drop the optimizer to clear its states
        self.optimizer = None
        if hasattr(self.accelerator, "deepspeed_engine_wrapped"):
            if self.accelerator.deepspeed_engine_wrapped is not None:
                self.accelerator.deepspeed_engine_wrapped.engine.empty_partition_cache()
                self.accelerator.deepspeed_engine_wrapped.engine.destroy()
                self.accelerator.deepspeed_engine_wrapped = None
        gc.collect()
        torch.cuda.empty_cache()

        # Initialize a new optimizer with the newly selected LISA layers
        self.create_optimizer()
        _, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

        # Put back the non-LISA param groups
        self.optimizer.param_groups.extend([non_lisa_param_groups[0], non_lisa_param_groups[1]])
        if hasattr(self.accelerator, "deepspeed_engine_wrapped"):
            self._post_init_deepspeed_zero_optimizer_params(self.accelerator.deepspeed_engine_wrapped.engine.optimizer)

        # Put back the non-LISA optimizer states
        for gindex in NON_LISA_LAYER_PARAM_GROUPS_IDX:
            for param in self.optimizer.param_groups[gindex]['params']:
                self.optimizer.state[param] = non_lisa_states[param]

        del non_lisa_param_groups
        del non_lisa_states
        gc.collect()
        torch.cuda.empty_cache()

        if hasattr(self.accelerator, "deepspeed_engine_wrapped"):
            self.accelerator.deepspeed_engine_wrapped.engine.optimizer._link_all_hp_params()

        torch.cuda.memory._dump_snapshot(f'gs_{self.state.global_step}.pickle')

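    # A minimal sketch (hypothetical, not part of this module) of the callback expected to drive
    # the switching logic; it reaches the trainer through `args._trainer` set in `__init__`:
    #
    #     class LISACallback(TrainerCallback):
    #         def on_step_begin(self, args, state, control, **kwargs):
    #             args._trainer.maybe_switch_active_layers()
    #
    # `TrainerCallback` here refers to `transformers.TrainerCallback`.
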
    def create_optimizer(self):
        """
        Setup the optimizer. Adapted from `transformers.Trainer.create_optimizer`.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            self._switch_active_layers()  # init along with the optimizer
            optimizer_grouped_parameters = self._prepare_optimizer_param_group(opt_model)
            optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid argument conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        print(f"optim after create_optimizer {[len(pg['params']) for pg in self.optimizer.param_groups]=}")
        return self.optimizer

    def _prepare_optimizer_param_group(self, opt_model: nn.Module):
        decay_parameters = self.get_decay_parameter_names(opt_model)
        print(f"{decay_parameters=}")
        if not self._optimizer_param_group_initialized:
            optimizer_grouped_parameters = [
                {
                    # selected body layers with decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # selected body layers without decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
                {
                    # this should always be lmhead:
                    # `requires_grad` and `not in active_layers_names` rules out all body layers
                    # `in decay_parameters` rules out ln
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # this should always be ln (outside of body layers)
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            self._optimizer_param_group_initialized = True
        else:
            optimizer_grouped_parameters = [
                {
                    # selected body layers with decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # selected body layers without decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        return optimizer_grouped_parameters

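    # Note: after the first call, only the two body-layer groups are rebuilt here. The lm_head/ln
    # groups and their optimizer states are cached and re-attached in `maybe_switch_active_layers`,
    # so their states (e.g. Adam moments) survive the optimizer re-creation.
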
    def _post_init_deepspeed_zero_optimizer_params(self, optimizer: DeepSpeedZeroOptimizer):
        optimizer.real_dp_process_group = [
            optimizer.dp_process_group for i in range(len(optimizer.optimizer.param_groups))
        ]
        optimizer.partition_count = [
            dist.get_world_size(group=optimizer.dp_process_group)
            for i in range(len(optimizer.optimizer.param_groups))
        ]

        for i, param_group in enumerate(optimizer.optimizer.param_groups):
            if i in LISA_BODY_LAYER_PARAM_GROUPS_IDX:
                # skip lisa layers
                continue

            partition_id = dist.get_rank(group=optimizer.real_dp_process_group[i])

            # push this group to list before modify
            # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
            trainable_parameters = []
            for param in param_group['params']:
                if param.requires_grad:
                    param.grad_accum = None
                    trainable_parameters.append(param)
            optimizer.bit16_groups.append(trainable_parameters)

            # not sure why apex was cloning the weights before flattening
            # removing cloning here
            see_memory_usage(f"Before moving param group {i} to CPU")

            # move all the parameters to cpu to free up GPU space for creating flat buffer
            # Create temp CPU param copies, free accelerator tensors
            orig_group_numel = 0
            for param in optimizer.bit16_groups[i]:
                orig_group_numel += param.numel()
                param.cpu_data = param.data.cpu()
                param.data = torch.empty(1).to(param.device)

            empty_cache()
            see_memory_usage(f"After moving param group {i} to CPU", force=False)

            # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
            # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
            # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
            # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
            if optimizer.round_robin_gradients:
                round_robin_tensors, round_robin_indices = optimizer._round_robin_reorder(
                    optimizer.bit16_groups[i], dist.get_world_size(group=optimizer.real_dp_process_group[i]))
            else:
                round_robin_tensors = optimizer.bit16_groups[i]
                round_robin_indices = list(range(len(optimizer.bit16_groups[i])))

            optimizer.round_robin_bit16_groups.append(round_robin_tensors)
            optimizer.round_robin_bit16_indices.append(round_robin_indices)

            # Create meta tensors list, ordered according to round_robin_tensors
            meta_tensors = []
            for param in round_robin_tensors:
                meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
            optimizer.round_robin_bit16_meta.append(meta_tensors)

            # create flat buffer in CPU
            flattened_buffer = optimizer.flatten_dense_tensors_aligned(
                optimizer.round_robin_bit16_groups[i],
                optimizer.nccl_start_alignment_factor * dist.get_world_size(group=optimizer.real_dp_process_group[i]),
                use_cpu_data=True)

            # free temp CPU params
            for param in optimizer.bit16_groups[i]:
                del param.cpu_data

            # Move CPU flat tensor to the accelerator memory.
            optimizer.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
            del flattened_buffer

            see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=optimizer.real_dp_process_group[i]) - 1:
                padding = optimizer.bit16_groups_flat[i].numel() - orig_group_numel
            else:
                padding = 0
            optimizer.groups_padding.append(padding)

            if dist.get_rank(group=optimizer.real_dp_process_group[i]) == 0:
                see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

            # set model bit16 weight to slices of flattened buffer
            optimizer._update_model_bit16_weights(i)

            # divide the flat weights into near equal partitions equal to the data parallel degree
            # each process will compute on a different part of the partition
            data_parallel_partitions = optimizer.get_data_parallel_partitions(optimizer.bit16_groups_flat[i], i)
            optimizer.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

            # verify that data partition start locations are 4-byte aligned
            for partitioned_data in data_parallel_partitions:
                assert (partitioned_data.data_ptr() % (2 * optimizer.nccl_start_alignment_factor) == 0)

            # A partition of the fp32 master weights that will be updated by this process.
            # Note that the params in single_partition_of_fp32_groups is cloned and detached
            # from the origin params of the model.
            if not optimizer.fp16_master_weights_and_gradients:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().float().detach()
            else:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().half().detach()

            if optimizer.cpu_offload:
                weights_partition = get_accelerator().pin_memory(weights_partition)

            optimizer.single_partition_of_fp32_groups.append(weights_partition)

            # Set local optimizer to have flat params of its own partition.
            # After this, the local optimizer will only contain its own partition of params.
            # In that case, the local optimizer only saves the states (momentum, variance, etc.)
            # related to its partition's params (ZeRO stage 1).
            optimizer.single_partition_of_fp32_groups[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [optimizer.single_partition_of_fp32_groups[i]]

            partition_size = len(optimizer.bit16_groups_flat[i]) / dist.get_world_size(
                group=optimizer.real_dp_process_group[i])
            params_in_partition, params_not_in_partition, first_offset = optimizer.get_partition_info(
                optimizer.round_robin_bit16_groups[i], partition_size, partition_id)

            optimizer.partition_size.append(partition_size)
            optimizer.params_in_partition.append(params_in_partition)
            optimizer.params_not_in_partition.append(params_not_in_partition)
            optimizer.first_offset.append(first_offset)

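    # The loop above mirrors the per-group setup normally done in `DeepSpeedZeroOptimizer.__init__`
    # (ZeRO stage 1/2): it is re-run only for the re-attached non-LISA groups so that their flat
    # bit16 buffers, fp32 partitions, and partitioning metadata are consistent with the new engine.
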
def tag(info=''):
    time.sleep(10)
    print(info)
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), flush=True)
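
# Example usage (a minimal sketch; `model`, `training_args`, `train_dataset`, and the callback
# wiring are assumed to be built elsewhere, e.g. by LMFlow's model and dataset wrappers):
#
#     trainer = LISATrainer(
#         n_layers=2,
#         interval_steps=20,
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#     )
#     trainer.train()
#
# Every `interval_steps` optimizer steps, a callback is expected to call
# `trainer.maybe_switch_active_layers()` to re-sample the trainable body layers.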