Source code for lmflow.pipeline.utils.lisa_trainer_del

import gc
import logging
import time
from typing import Union, List

import numpy as np
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled

from lmflow.utils.debug.common import timer


from deepspeed import comm as dist
from deepspeed.runtime.utils import empty_cache, see_memory_usage
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.accelerator import get_accelerator

if is_sagemaker_mp_enabled():
    # `smp` is referenced in `create_optimizer` when running under SageMaker model parallelism.
    import smdistributed.modelparallel.torch as smp


logger = logging.getLogger(__name__)

# Debugging aid: record CUDA allocation history so that memory snapshots can be dumped during training.
torch.cuda.memory._record_memory_history(max_entries=100000)

LISA_LAYER_NAME_MAPPING = {
    'LlamaForCausalLM': 'model.layers',
    'Qwen2ForCausalLM': 'model.layers',
    'MistralForCausalLM': 'model.layers',
    'MixtralForCausalLM': 'model.layers',
    'GemmaForCausalLM': 'model.layers',
    'GPT2LMHeadModel': 'transformer.h',
}

LISA_BODY_LAYER_PARAM_GROUPS_IDX = [2, 3]
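# These are the indices of the optimizer param groups (built in `LISATrainer.create_optimizer`)
# that hold the currently active LISA body layers: group 2 carries weight decay, group 3 does not;
# groups 0 and 1 hold the always-trainable non-body parameters (lm_head, layer norms).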


class LISATrainer(Trainer):
    def __init__(
        self,
        n_layers: int,
        interval_steps: int,
        lisa_layer_attr_name: str = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        setattr(self.args, '_trainer', self)  # make the trainer itself accessible to callbacks via the training args

        # LISA-specific attributes
        self.n_layers = n_layers
        self.interval_steps = interval_steps

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        model_class_name = opt_model.__class__.__name__
        if model_class_name in LISA_LAYER_NAME_MAPPING:
            self.lisa_layer_attr_name = LISA_LAYER_NAME_MAPPING[model_class_name]
        else:
            assert lisa_layer_attr_name is not None, "Please provide the attribute name for the model layers."
            self.lisa_layer_attr_name = lisa_layer_attr_name

        self.num_body_layers = len(self._get_all_body_layers())
        self.active_layers_indices = []
        self.history_layers_indices = []
        self.active_layers_names = []

    def _get_all_body_layers(self) -> List[nn.Module]:
        '''Fetch all the layers of the model, excluding the head.'''
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        layers = eval('opt_model.' + self.lisa_layer_attr_name)
        return layers
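    # Note: `eval` resolves the dotted attribute path stored in `lisa_layer_attr_name`. An eval-free
    # alternative (a sketch, not used by this module; it would require `import operator`) is
    # `operator.attrgetter(self.lisa_layer_attr_name)(opt_model)`.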

    def _get_active_layers_names(self) -> List[str]:
        if not hasattr(self, 'active_layers_indices'):
            return []

        all_names = []
        layers = self._get_all_body_layers()
        for idx in self.active_layers_indices:
            for name, _ in layers[idx].named_parameters():
                all_names.append(f"{self.lisa_layer_attr_name}.{idx}.{name}")
        return all_names
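    # For a LLaMA-style model with active_layers_indices == [3, 7], for example, this yields fully
    # qualified names such as 'model.layers.3.self_attn.q_proj.weight' and
    # 'model.layers.7.mlp.down_proj.weight', matching the keys of `opt_model.named_parameters()`.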

    def _update_active_layer_info(self):
        # self.active_layers_indices = [3, 4] if self.active_layers_indices == [1, 2] else [1, 2]
        # self.active_layers_indices = [1, 2]
        self.active_layers_indices = np.random.choice(range(self.num_body_layers), self.n_layers, replace=False)
        self.history_layers_indices.append(list(self.active_layers_indices))
        # self.active_layers_indices.sort()
        self.active_layers_names = self._get_active_layers_names()

        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), flush=True)
        print(f"History of layers: {self.history_layers_indices[:-1]}", flush=True)
        print(f"Layers for the next steps: {self.active_layers_indices}: {self.active_layers_names}", flush=True)

    def _switch_active_layers(self):
        '''
        Switch the active layers for the next interval.

        Objects that will be updated after calling:
        1. self.active_layers_indices
        2. self.active_layers_names
        3. requires_grad of the parameters
        '''
        # Disable gradients for all body layers
        layers = self._get_all_body_layers()
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False

        # Randomly select n_layers to activate
        self._update_active_layer_info()  # updates active indices and names

        # Enable gradients only for the selected layers
        layers = self._get_all_body_layers()  # re-fetch layer references
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True
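    # After a switch, only the parameters of the freshly sampled body layers have requires_grad set
    # to True; the non-body parameters (lm_head, layer norms) are never touched here and keep
    # whatever requires_grad state they already had.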

    def maybe_switch_active_layers(self):
        if self.state.global_step == 0:
            torch.cuda.memory._dump_snapshot(f'gs_{self.state.global_step}.pickle')

        if (
            self.state.global_step == 0  # skip: already initialized in `create_optimizer`
            or self.state.global_step % self.interval_steps != 0
        ):
            return

        # Drop optimizer state for the body layers that were active during the last interval.
        if hasattr(self.accelerator, "deepspeed_engine_wrapped"):
            keys_to_remove = [
                statekey
                for statekey_idx, statekey in enumerate(self.optimizer.state.keys())
                if statekey_idx in LISA_BODY_LAYER_PARAM_GROUPS_IDX
            ]
            for key in keys_to_remove:
                del self.optimizer.state[key]
        else:
            layers = self._get_all_body_layers()
            for active_layer_idx in self.active_layers_indices:
                for name, param in layers[active_layer_idx].named_parameters():
                    print(f"{name=}")
                    del self.optimizer.state[param]

        self._switch_active_layers()

        # Update the optimizer param groups so that the newly selected layers are initialized in optimizer.step()
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        decay_parameters = self.get_decay_parameter_names(opt_model)
        self.optimizer.param_groups[2]['params'] = [
            p for n, p in opt_model.named_parameters()
            if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
        ]
        self.optimizer.param_groups[3]['params'] = [
            p for n, p in opt_model.named_parameters()
            if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
        ]

        if hasattr(self.accelerator, "deepspeed_engine_wrapped"):
            self._reinit_deepspeed_zero_optimizer_params(self.accelerator.deepspeed_engine_wrapped.engine.optimizer)

        if self.state.global_step <= 20:
            torch.cuda.memory._dump_snapshot(f'gs_{self.state.global_step}.pickle')
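    # Intended call pattern (a sketch; the hook itself is not defined in this module): call
    # `trainer.maybe_switch_active_layers()` once per optimizer step, e.g. from a TrainerCallback
    # that reaches the trainer via `args._trainer`, so the active layer set is resampled every
    # `interval_steps` steps.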

    def create_optimizer(self):
        """
        Setup the optimizer. Adapted from `transformers.Trainer.create_optimizer`.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            self._switch_active_layers()  # initialize the active layers along with the optimizer
            decay_parameters = self.get_decay_parameter_names(opt_model)
            optimizer_grouped_parameters = [
                {
                    # this should always be lm_head:
                    # `requires_grad` and `not in active_layers_names` rule out all body layers,
                    # `in decay_parameters` rules out layer norms
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # this should always be layer norms (outside of body layers)
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n not in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
                {
                    # selected body layers with decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # selected body layers without decay
                    "params": [
                        p for n, p in opt_model.named_parameters()
                        if (n in self.active_layers_names and n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`,
            # e.g. for the LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid argument conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        see_memory_usage('[mem usage] after creating optimizer', True)

        return self.optimizer
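    # Resulting param-group layout (relied upon by LISA_BODY_LAYER_PARAM_GROUPS_IDX = [2, 3]):
    #   0: non-body parameters with weight decay (lm_head)
    #   1: non-body parameters without weight decay (layer norms)
    #   2: currently active body layers with weight decay
    #   3: currently active body layers without weight decay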

    def _reinit_deepspeed_zero_optimizer_params(self, optimizer: DeepSpeedZeroOptimizer):
        num_non_lisa_body_layer_pgs = len(self.optimizer.param_groups) - len(LISA_BODY_LAYER_PARAM_GROUPS_IDX)

        # Drop the stale bookkeeping entries for the LISA body-layer param groups.
        objs = [
            optimizer.bit16_groups,
            optimizer.round_robin_bit16_groups,
            optimizer.round_robin_bit16_indices,
            optimizer.round_robin_bit16_meta,
            optimizer.bit16_groups_flat,
            optimizer.groups_padding,
            optimizer.parallel_partitioned_bit16_groups,
            optimizer.single_partition_of_fp32_groups,
            optimizer.partition_size,
            optimizer.params_in_partition,
            optimizer.params_not_in_partition,
            optimizer.first_offset,
        ]
        for obj in objs:
            del obj[num_non_lisa_body_layer_pgs:]

        empty_cache()
        torch.cuda.empty_cache()
        gc.collect()

        for i, param_group in enumerate(optimizer.optimizer.param_groups):
            if i in range(num_non_lisa_body_layer_pgs):
                # skip lm_head, layer norms, etc.
                continue

            partition_id = dist.get_rank(group=optimizer.real_dp_process_group[i])

            # push this group to the list before modifying it
            # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
            trainable_parameters = []
            for param in param_group['params']:
                if param.requires_grad:
                    param.grad_accum = None
                    trainable_parameters.append(param)
            optimizer.bit16_groups.append(trainable_parameters)

            # not sure why apex was cloning the weights before flattening
            # removing cloning here
            see_memory_usage(f"Before moving param group {i} to CPU")

            # move all the parameters to cpu to free up GPU space for creating flat buffer
            # Create temp CPU param copies, free accelerator tensors
            orig_group_numel = 0
            for param in optimizer.bit16_groups[i]:
                orig_group_numel += param.numel()
                param.cpu_data = param.data.cpu()
                param.data = torch.empty(1).to(param.device)

            empty_cache()
            see_memory_usage(f"After moving param group {i} to CPU", force=False)

            # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
            # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
            # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
            # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
            if optimizer.round_robin_gradients:
                round_robin_tensors, round_robin_indices = optimizer._round_robin_reorder(
                    optimizer.bit16_groups[i], dist.get_world_size(group=optimizer.real_dp_process_group[i]))
            else:
                round_robin_tensors = optimizer.bit16_groups[i]
                round_robin_indices = list(range(len(optimizer.bit16_groups[i])))

            optimizer.round_robin_bit16_groups.append(round_robin_tensors)
            optimizer.round_robin_bit16_indices.append(round_robin_indices)

            # Create meta tensors list, ordered according to round_robin_tensors
            meta_tensors = []
            for param in round_robin_tensors:
                meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
            optimizer.round_robin_bit16_meta.append(meta_tensors)

            # create flat buffer in CPU
            flattened_buffer = optimizer.flatten_dense_tensors_aligned(
                optimizer.round_robin_bit16_groups[i],
                optimizer.nccl_start_alignment_factor * dist.get_world_size(group=optimizer.real_dp_process_group[i]),
                use_cpu_data=True)

            # free temp CPU params
            for param in optimizer.bit16_groups[i]:
                del param.cpu_data

            # Move CPU flat tensor to the accelerator memory.
            optimizer.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
            del flattened_buffer
            see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=optimizer.real_dp_process_group[i]) - 1:
                padding = optimizer.bit16_groups_flat[i].numel() - orig_group_numel
            else:
                padding = 0
            optimizer.groups_padding.append(padding)

            if dist.get_rank(group=optimizer.real_dp_process_group[i]) == 0:
                see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

            # set model bit16 weight to slices of flattened buffer
            optimizer._update_model_bit16_weights(i)

            # divide the flat weights into near equal partitions equal to the data parallel degree
            # each process will compute on a different part of the partition
            data_parallel_partitions = optimizer.get_data_parallel_partitions(optimizer.bit16_groups_flat[i], i)
            optimizer.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

            # verify that data partition start locations are 4-byte aligned
            for partitioned_data in data_parallel_partitions:
                assert (partitioned_data.data_ptr() % (2 * optimizer.nccl_start_alignment_factor) == 0)

            # A partition of the fp32 master weights that will be updated by this process.
            # Note that the params in single_partition_of_fp32_groups are cloned and detached
            # from the original params of the model.
            if not optimizer.fp16_master_weights_and_gradients:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().float().detach()
            else:
                weights_partition = optimizer.parallel_partitioned_bit16_groups[i][partition_id].to(
                    optimizer.device).clone().half().detach()

            if optimizer.cpu_offload:
                weights_partition = get_accelerator().pin_memory(weights_partition)

            optimizer.single_partition_of_fp32_groups.append(weights_partition)

            # Set the local optimizer to have flat params of its own partition.
            # After this, the local optimizer will only contain its own partition of params.
            # In that case, the local optimizer only saves the states (momentum, variance, etc.)
            # related to its partition's params (ZeRO stage 1).
            optimizer.single_partition_of_fp32_groups[i].requires_grad = True  # keep this in case the internal optimizer uses it
            param_group['params'] = [optimizer.single_partition_of_fp32_groups[i]]

            partition_size = len(optimizer.bit16_groups_flat[i]) / dist.get_world_size(
                group=optimizer.real_dp_process_group[i])
            params_in_partition, params_not_in_partition, first_offset = optimizer.get_partition_info(
                optimizer.round_robin_bit16_groups[i], partition_size, partition_id)

            optimizer.partition_size.append(partition_size)
            optimizer.params_in_partition.append(params_in_partition)
            optimizer.params_not_in_partition.append(params_not_in_partition)
            optimizer.first_offset.append(first_offset)
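
# Example usage (a sketch; `model`, `training_args`, and `train_dataset` are assumed to come from a
# standard HuggingFace setup and are not defined in this module):
#
#     trainer = LISATrainer(
#         n_layers=2,          # number of body layers active at a time
#         interval_steps=20,   # resample the active layers every 20 steps
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#     )
#     trainer.train()
#
# `lisa_layer_attr_name` only needs to be passed explicitly for architectures not listed in
# LISA_LAYER_NAME_MAPPING.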


def tag(info=''):
    '''Debug helper: pause for 10 seconds, then print a label and the current timestamp.'''
    time.sleep(10)
    print(info)
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), flush=True)