Source code for lmflow.pipeline.utils.peft_trainer

#!/usr/bin/env python
# coding=utf-8
"""Trainer for Peft models
"""

from __future__ import absolute_import
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_callback import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.training_args import TrainingArguments
import os
import numpy as np

class PeftTrainer(Trainer):

    def _save_checkpoint(self, _, trial, metrics=None):
        """Don't save the base model, optimizer, etc., but create the
        checkpoint folder (needed for saving the adapter)."""
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        # Track the best metric and best checkpoint, mirroring the logic of
        # the stock Trainer._save_checkpoint.
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Create the (otherwise empty) checkpoint folder so that
        # PeftSavingCallback can write the adapter into it.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
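The best-checkpoint bookkeeping above is driven by standard Hugging Face TrainingArguments fields. A minimal sketch of the relevant settings (the field names are real TrainingArguments arguments; the values and output_dir are illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./peft_output",      # illustrative path
        evaluation_strategy="steps",
        eval_steps=200,
        save_steps=200,
        metric_for_best_model="loss",    # compared as "eval_loss" in _save_checkpoint
        greater_is_better=False,         # selects np.less as the comparison operator
    )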
class PeftSavingCallback(TrainerCallback):
    """Correctly save the PEFT adapter rather than the full model."""

    def _save(self, model, folder):
        if folder is None:
            folder = ""
        peft_model_path = os.path.join(folder, "adapter_model")
        model.save_pretrained(peft_model_path)
    def on_train_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save the final best model adapter."""
        self._save(kwargs["model"], state.best_model_checkpoint)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save intermediate model adapters in case training is interrupted."""
        folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        self._save(kwargs["model"], folder)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save the adapter into the checkpoint folder created by PeftTrainer."""
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        self._save(kwargs["model"], checkpoint_folder)
        return control
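A minimal end-to-end usage sketch (not part of the library source), assuming the peft package's LoraConfig / get_peft_model / PeftModel API, a hypothetical GPT-2 base model, and a train_dataset prepared elsewhere:

    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoModelForCausalLM, TrainingArguments

    base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical base model
    model = get_peft_model(
        base_model,
        LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"],
                   task_type="CAUSAL_LM"),
    )

    trainer = PeftTrainer(
        model=model,
        args=TrainingArguments(output_dir="./peft_output",
                               num_train_epochs=1, save_strategy="epoch"),
        train_dataset=train_dataset,  # assumed: a tokenized dataset built elsewhere
        callbacks=[PeftSavingCallback()],
    )
    trainer.train()

    # Each checkpoint folder contains only the "adapter_model" directory written
    # by the callback; reload it onto a fresh copy of the same base model.
    reloaded = PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained("gpt2"),
        "./peft_output/checkpoint-500/adapter_model",  # step number is illustrative
    )

Because PeftTrainer._save_checkpoint skips the base model and optimizer state, these checkpoints stay small: only the adapter weights are written to disk.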