Source code for lmflow.pipeline.utils.dpov2_trainer

import logging
from typing import Optional, Union, Dict, List, Any, Tuple, Callable, Literal

from datasets import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    PreTrainedModel, 
    PreTrainedTokenizerBase, 
    DataCollator, 
    TrainingArguments, 
    TrainerCallback
)
from transformers.trainer_utils import EvalLoopOutput

from lmflow.pipeline.utils.dpov2_dataprocessor import PreferenceDataCollatorWithPadding
from lmflow.utils.versioning import is_trl_available

if is_trl_available():
    from trl import DPOTrainer
else:
    raise ImportError("Please install the trl package to use dpov2_trainer.py")


logger = logging.getLogger(__name__)


class DPOv2Trainer(DPOTrainer):
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        beta: float = 0.1,
        loss_type: Literal["sigmoid", "hinge", "cross_entropy", "kl", "rev_kl", "raft"] = "rev_kl",
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        truncation_mode: str = "keep_end",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        max_prompt_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
        is_encoder_decoder: Optional[bool] = None,
        disable_dropout: bool = True,
        generate_during_eval: bool = False,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        mask_prompt: Optional[bool] = False,
        len_penalty: float = 0,
        preprocessing_num_workers: int = 1,
    ):
        if data_collator is None:
            data_collator = PreferenceDataCollatorWithPadding(
                tokenizer,
                max_length=max_length,
                max_prompt_length=max_prompt_length,
                label_pad_token_id=label_pad_token_id,
                padding_value=padding_value,
                truncation_mode=truncation_mode,
                is_encoder_decoder=False,
                max_target_length=max_target_length,
                mask_prompt=mask_prompt,
            )
        super().__init__(
            model=model,
            ref_model=ref_model,
            beta=beta,
            loss_type=loss_type,
            args=args,
            data_collator=data_collator,
            label_pad_token_id=label_pad_token_id,
            padding_value=padding_value,
            truncation_mode=truncation_mode,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            max_length=max_length,
            max_prompt_length=max_prompt_length,
            max_target_length=max_target_length,
            peft_config=peft_config,
            is_encoder_decoder=is_encoder_decoder,
            disable_dropout=disable_dropout,
            generate_during_eval=generate_during_eval,
            compute_metrics=compute_metrics,
            dataset_num_proc=preprocessing_num_workers,
        )
        self.use_dpo_data_collator = True
        self.len_penalty = len_penalty
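
    # In addition to trl's standard "sigmoid" and "hinge" losses, `dpo_loss` below supports
    # margin-aware variants ("kl", "tv", "hellinger", "rev_kl"): the per-example margin is
    # mapped to a target preference probability p_gt (roughly sigmoid(margin)), and the loss
    # penalizes the discrepancy between p_gt and the model-implied preference probability
    # p = sigmoid(beta * (pi_logratios - ref_logratios)). When p_gt == 1, the default
    # "rev_kl" loss reduces to the standard DPO objective -logsigmoid(beta * logits).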

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
        margin: Optional[torch.FloatTensor] = None,
        len_penalty: float = 0,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        The temperature is read from self.beta (typically in the range 0.1 to 0.5); the
        reference model is effectively ignored as beta -> 0.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model
                that assigns equal probability to all responses.
            margin: Per-example preference margins used by the margin-based loss types
                ("kl", "tv", "hellinger", "rev_kl"). Shape: (batch_size,)
            len_penalty: Length penalty added to the reference log-ratios.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and
            rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps + len_penalty
        if reference_free:
            ref_logratios = 0

        if self.loss_type == "sigmoid":
            logits = pi_logratios - ref_logratios
            losses = -F.logsigmoid(self.beta * logits)
        elif self.loss_type == "hinge":
            logits = pi_logratios - ref_logratios
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "cross_entropy":
            logits = policy_chosen_logps - reference_chosen_logps
            losses = -F.logsigmoid(self.beta * logits)
        elif self.loss_type == "raft":
            losses = -policy_chosen_logps
        elif self.loss_type == "ipo":
            logits = pi_logratios - ref_logratios
            # Eqn. (17) of the IPO paper, where beta is the regularization parameter (tau in the paper).
            losses = (logits - 1 / (2 * self.beta)) ** 2
        elif self.loss_type == "kl":
            logits = pi_logratios - ref_logratios
            p = F.sigmoid(self.beta * logits)
            # Clamp p away from 1 for numerical stability of log(1 - p).
            p = torch.minimum(p, torch.ones_like(p) * 0.999)
            p_gt = torch.exp(margin) / (1 + torch.exp(margin) + 1e-3)
            losses = p * (torch.log(p) - torch.log(p_gt)) + (1 - p) * (torch.log(1 - p) - torch.log(1 - p_gt))
        elif self.loss_type == "tv":
            logits = pi_logratios - ref_logratios
            p = F.sigmoid(self.beta * logits)
            p_gt = torch.exp(margin) / (1 + torch.exp(margin))
            losses = torch.abs(p - p_gt)
        elif self.loss_type == "hellinger":
            logits = pi_logratios - ref_logratios
            p = F.sigmoid(self.beta * logits)
            p = torch.minimum(p, torch.ones_like(p) * 0.999)
            p_gt = torch.exp(margin) / (1 + torch.exp(margin))
            losses = 0.5 * ((p**0.5 - p_gt**0.5) ** 2 + ((1 - p) ** 0.5 - (1 - p_gt) ** 0.5) ** 2)
        elif self.loss_type == "rev_kl":
            logits = pi_logratios - ref_logratios
            logp = F.logsigmoid(self.beta * logits)
            logp_neg = F.logsigmoid(-self.beta * logits)
            p_gt = F.sigmoid(margin)
            losses = -p_gt * logp - (1 - p_gt) * logp_neg
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}.")

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards
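
    # Recent trl releases call `get_batch_loss_metrics` from their training loop; the thin
    # wrapper below forwards that call to the overridden `get_batch_metrics` implementation,
    # so either entry point uses the margin- and length-penalty-aware logic.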
    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        return self.get_batch_metrics(model, batch, train_eval)

    def get_batch_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        with torch.no_grad():
            if self.ref_model is None:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, batch)

        if self.len_penalty > 0:
            chosen_len = batch["chosen_input_ids"].shape[1] * self.len_penalty
            rejected_len = batch["rejected_input_ids"].shape[1] * self.len_penalty
            len_penalty = chosen_len - rejected_len
        else:
            chosen_len = 1
            rejected_len = 1
            len_penalty = 0

        margin = torch.tensor(batch["margin"], dtype=policy_chosen_logps.dtype).to(self.accelerator.device)
        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            margin=margin,
            len_penalty=len_penalty,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()

        return losses.mean(), metrics
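

# Minimal sanity-check sketch (illustrative only, not used by the pipeline): it exercises the
# "rev_kl" branch of `dpo_loss` on toy log probabilities without building a model, tokenizer,
# or dataset, and assumes lmflow and trl are installed. The `SimpleNamespace` object is a
# hypothetical stand-in for a trainer instance that only supplies the `beta` and `loss_type`
# attributes `dpo_loss` reads. With a large margin, p_gt ~= 1 and the result should closely
# match the standard DPO objective -logsigmoid(beta * logits).
if __name__ == "__main__":
    from types import SimpleNamespace

    fake_trainer = SimpleNamespace(beta=0.1, loss_type="rev_kl")
    policy_chosen = torch.tensor([-10.0, -12.0])
    policy_rejected = torch.tensor([-14.0, -11.0])
    reference_chosen = torch.tensor([-11.0, -11.5])
    reference_rejected = torch.tensor([-13.0, -11.5])
    margin = torch.tensor([20.0, 20.0])  # large margin => p_gt ~= 1

    losses, chosen_rewards, rejected_rewards = DPOv2Trainer.dpo_loss(
        fake_trainer,
        policy_chosen,
        policy_rejected,
        reference_chosen,
        reference_rejected,
        margin=margin,
        len_penalty=0,
    )

    # Reference value: the standard DPO sigmoid loss on the same log-ratios.
    logits = (policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)
    expected = -F.logsigmoid(0.1 * logits)
    print("rev_kl losses:", losses)
    print("standard DPO losses:", expected)
    print("rewards (chosen, rejected):", chosen_rewards, rejected_rewards)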