Source code for lmflow.optim.utils

from typing import Optional, Any, Tuple

from transformers import PreTrainedModel
from transformers.utils import is_sagemaker_mp_enabled

import lmflow.optim.optimizers as optim
from lmflow.args import OptimizerNames, TrainingArguments

# `smp` is referenced in `create_optimizer` below; it is only available (and
# only needed) when training under SageMaker model parallelism.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

def create_customized_optimizer(base_trainer_class, model_args):
    class CustomizedOptimTrainer(base_trainer_class):

        @staticmethod
        def get_optimizer_cls_and_kwargs(
            args: TrainingArguments,
            model: Optional[PreTrainedModel] = None,
        ) -> Tuple[Any, Any]:
            # parse args.optim_args
            optim_args = {}
            if args.customized_optim_args:
                for mapping in args.customized_optim_args.replace(" ", "").split(","):
                    key, value = mapping.split("=")
                    optim_args[key] = value

            optimizer_kwargs = {"lr": args.learning_rate}

            if args.customized_optim == OptimizerNames.DUMMY:
                optimizer_cls = optim.Dummy
                dummy_kwargs = {
                    "betas": (args.optim_dummy_beta1, args.optim_dummy_beta2),
                }
                optimizer_kwargs.update(dummy_kwargs)
            elif args.customized_optim == OptimizerNames.ADABELIEF:
                optimizer_cls = optim.AdaBelief
                adabelief_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adabelief_kwargs)
            elif args.customized_optim == OptimizerNames.ADABOUND:
                optimizer_cls = optim.AdaBound
                adabound_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adabound_kwargs)
            elif args.customized_optim == OptimizerNames.LARS:
                optimizer_cls = optim.LARS
                lars_kwargs = {
                    "momentum": args.optim_momentum,
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(lars_kwargs)
            elif args.customized_optim == OptimizerNames.LAMB:
                optimizer_cls = optim.Lamb
                lamb_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(lamb_kwargs)
            elif args.customized_optim == OptimizerNames.ADAMAX:
                optimizer_cls = optim.Adamax
                adamax_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adamax_kwargs)
            elif args.customized_optim == OptimizerNames.NADAM:
                optimizer_cls = optim.NAdam
                nadam_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(nadam_kwargs)
            elif args.customized_optim == OptimizerNames.RADAM:
                optimizer_cls = optim.RAdam
                radam_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(radam_kwargs)
            elif args.customized_optim == OptimizerNames.ADAMP:
                optimizer_cls = optim.AdamP
                adamp_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adamp_kwargs)
            elif args.customized_optim == OptimizerNames.SGDP:
                optimizer_cls = optim.SGDP
                sgdp_kwargs = {
                    "momentum": args.optim_momentum,
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(sgdp_kwargs)
            elif args.customized_optim == OptimizerNames.YOGI:
                optimizer_cls = optim.Yogi
                yogi_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(yogi_kwargs)
            elif args.customized_optim == OptimizerNames.SOPHIA:
                optimizer_cls = optim.SophiaG
                sophia_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(sophia_kwargs)
            elif args.customized_optim == OptimizerNames.ADAM:
                optimizer_cls = optim.Adam
                adam_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                }
                optimizer_kwargs.update(adam_kwargs)
            elif args.customized_optim == OptimizerNames.NOVOGRAD:
                optimizer_cls = optim.NovoGrad
                novograd_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(novograd_kwargs)
            elif args.customized_optim == OptimizerNames.ADADELTA:
                optimizer_cls = optim.Adadelta
                adadelta_kwargs = {}
                optimizer_kwargs.update(adadelta_kwargs)
            elif args.customized_optim == OptimizerNames.ADAGRAD:
                optimizer_cls = optim.AdaGrad
                adagrad_kwargs = {}
                optimizer_kwargs.update(adagrad_kwargs)
            elif args.customized_optim == OptimizerNames.ADAMW_SCHEDULE_FREE:
                optimizer_cls = optim.AdamWScheduleFree
                adamw_schedule_free_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adamw_schedule_free_kwargs)
            elif args.customized_optim == OptimizerNames.SGD_SCHEDULE_FREE:
                optimizer_cls = optim.SGDScheduleFree
                sgd_schedule_free_kwargs = {
                    "momentum": args.optim_momentum,
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(sgd_schedule_free_kwargs)
            elif args.customized_optim == OptimizerNames.ADAN:
                optimizer_cls = optim.Adan
                adan_kwargs = {
                    "betas": (args.optim_beta1, args.optim_beta2, args.optim_beta3),
                    "weight_decay": args.optim_weight_decay,
                }
                optimizer_kwargs.update(adan_kwargs)
            else:
                raise ValueError(
                    f"Trainer cannot instantiate unsupported optimizer:"
                    f" {args.customized_optim}"
                )
            return optimizer_cls, optimizer_kwargs

        def create_optimizer(self):
            opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

            if self.optimizer is None:
                decay_parameters = self.get_decay_parameter_names(opt_model)
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

                optimizer_cls, optimizer_kwargs = CustomizedOptimTrainer.get_optimizer_cls_and_kwargs(
                    self.args, opt_model
                )

                # Overwrite `params` in case it's created by
                # `get_optimizer_cls_and_kwargs`, e.g. for GaLore optimizer.
                if "params" in optimizer_kwargs:
                    optimizer_grouped_parameters = optimizer_kwargs.pop("params")

                # For layer-wise dummy optimizers we overwrite
                # optimizer_grouped_parameters with `optimizer_dict` to
                # avoid argument conflicts.
                if "optimizer_dict" in optimizer_kwargs:
                    optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

                self.optimizer = optimizer_cls(
                    optimizer_grouped_parameters, **optimizer_kwargs
                )

            if is_sagemaker_mp_enabled():
                self.optimizer = smp.DistributedOptimizer(self.optimizer)

    return CustomizedOptimTrainer
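
A minimal usage sketch follows, for illustration only. It assumes `transformers.Trainer` as the base trainer class, passes `model_args=None` (the factory does not use it in the code above), and assumes the optimizer-related fields (`customized_optim`, `optim_beta1`, ...) are accepted as constructor keywords by LMFlow's `TrainingArguments`; the chosen values are placeholders.

    from transformers import Trainer

    from lmflow.args import OptimizerNames, TrainingArguments
    from lmflow.optim.utils import create_customized_optimizer

    # Wrap the stock Trainer so that `create_optimizer` builds the
    # optimizer selected by `customized_optim`.
    CustomizedOptimTrainer = create_customized_optimizer(Trainer, model_args=None)

    args = TrainingArguments(
        output_dir="out",
        learning_rate=1e-4,
        customized_optim=OptimizerNames.LAMB,
        optim_beta1=0.9,
        optim_beta2=0.999,
        optim_weight_decay=0.01,
    )

    # Resolve the optimizer class and keyword arguments without building a model.
    optimizer_cls, optimizer_kwargs = CustomizedOptimTrainer.get_optimizer_cls_and_kwargs(args)
    # Following the dispatch logic above, this yields roughly:
    #   optimizer_cls   -> optim.Lamb
    #   optimizer_kwargs -> {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 0.01}

The returned `CustomizedOptimTrainer` can then be instantiated like a regular `Trainer` (with a model, datasets, etc.); its `create_optimizer` applies the usual weight-decay/no-decay parameter grouping before constructing the selected optimizer.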