lmflow.pipeline.utils.dpov2_trainer

Attributes

Classes

Module Contents

lmflow.pipeline.utils.dpov2_trainer.logger
class lmflow.pipeline.utils.dpov2_trainer.DPOv2Trainer(model: transformers.PreTrainedModel | torch.nn.Module = None, ref_model: transformers.PreTrainedModel | torch.nn.Module | None = None, beta: float = 0.1, loss_type: Literal['sigmoid', 'hinge', 'cross_entropy', 'kl', 'rev_kl', 'raft'] = 'rev_kl', args: transformers.TrainingArguments = None, data_collator: transformers.DataCollator | None = None, label_pad_token_id: int = -100, padding_value: int = 0, truncation_mode: str = 'keep_end', train_dataset: datasets.Dataset | None = None, eval_dataset: datasets.Dataset | Dict[str, datasets.Dataset] | None = None, tokenizer: transformers.PreTrainedTokenizerBase | None = None, model_init: Callable[[], transformers.PreTrainedModel] | None = None, callbacks: List[transformers.trainer_callback.TrainerCallback] | None = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, max_length: int | None = None, max_prompt_length: int | None = None, max_target_length: int | None = None, peft_config: Dict | None = None, is_encoder_decoder: bool | None = None, disable_dropout: bool = True, generate_during_eval: bool = False, compute_metrics: Callable[[transformers.trainer_utils.EvalLoopOutput], Dict] | None = None, mask_prompt: bool | None = False, len_penalty: float = 0, preprocessing_num_workers: int = 1)

Bases: trl.DPOTrainer

use_dpo_data_collator = True
len_penalty = 0
dpo_loss(policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, reference_chosen_logps: torch.FloatTensor, reference_rejected_logps: torch.FloatTensor, reference_free: bool = False, margin: torch.FloatTensor | None = None, len_penalty: float = 0) → Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]

Compute the DPO loss for a batch of policy and reference model log probabilities.
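For orientation, the standard sigmoid form of the DPO objective is written out below, assuming that the 'sigmoid' option of loss_type follows the usual DPO formulation as it does in trl. The other loss types accepted by the constructor, including the default 'rev_kl', use different formulas that are not reproduced here. The symbol mapping is an editorial reading of the parameter names: log π_θ(y_w | x) and log π_θ(y_l | x) correspond to policy_chosen_logps and policy_rejected_logps, the π_ref terms to the reference_* arguments, and r_w, r_l to chosen_rewards and rejected_rewards.

```latex
\mathcal{L}_i = -\log \sigma\!\Big(\beta\big[\big(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\big) - \big(\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big)\big]\Big)

r_w = \beta\,\big(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)\big), \qquad
r_l = \beta\,\big(\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big)
```

With reference_free=True, the reference model assigns equal probability to every response, so its log-ratio \log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x) is 0 and only the policy term remains inside the sigmoid.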

Args:

policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
beta: Temperature parameter for the DPO loss, typically between 0.1 and 0.5. The reference model is ignored as beta -> 0. Note that beta is not an argument of dpo_loss; its value comes from the beta argument of the DPOv2Trainer constructor.
reference_free: If True, ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

Returns:

A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the DPO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
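To make the returned triple concrete, here is a minimal pure-Python sketch of the sigmoid DPO loss on plain lists of floats. It is not the DPOv2Trainer implementation: the real method works on torch tensors, reads beta from the trainer, and dispatches on loss_type (default 'rev_kl'). The function names sigmoid_dpo_loss and log_sigmoid are hypothetical, and the margin and len_penalty arguments are deliberately left out because their exact effect is not documented here.

```python
import math
from typing import List, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_dpo_loss(
    policy_chosen_logps: List[float],
    policy_rejected_logps: List[float],
    reference_chosen_logps: List[float],
    reference_rejected_logps: List[float],
    beta: float = 0.1,
    reference_free: bool = False,
) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical sketch: per-example sigmoid DPO losses and implicit rewards."""
    losses, chosen_rewards, rejected_rewards = [], [], []
    for pc, pr, rc, rr in zip(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    ):
        pi_logratio = pc - pr
        # A reference model that gives every response equal probability
        # has a chosen-vs-rejected log-ratio of exactly 0.
        ref_logratio = 0.0 if reference_free else rc - rr
        losses.append(-log_sigmoid(beta * (pi_logratio - ref_logratio)))
        # Implicit rewards: beta-scaled policy/reference log-ratio per response.
        chosen_rewards.append(beta * (pc - rc))
        rejected_rewards.append(beta * (pr - rr))
    return losses, chosen_rewards, rejected_rewards
```

When the policy and reference log-ratios are equal, the loss is log 2 (about 0.693) for any beta; it drops below log 2 as the policy prefers the chosen response more strongly than the reference does, which is the direction gradient descent pushes it.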

get_batch_loss_metrics(model, batch: Dict[str, List | torch.LongTensor], train_eval: Literal['train', 'eval'] = 'train')
get_batch_metrics(model, batch: Dict[str, List | torch.LongTensor], train_eval: Literal['train', 'eval'] = 'train')

Compute the DPO loss and other metrics for the given batch of inputs, during either training (train_eval='train') or evaluation (train_eval='eval').
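As a rough illustration of the per-batch statistics such a method typically reports, the sketch below summarizes the chosen and rejected rewards returned by dpo_loss. The helper summarize_reward_metrics is hypothetical, and its metric keys and eval_ prefix are modeled on the convention of trl's DPOTrainer; DPOv2Trainer may log different or additional metrics (for example, log-probabilities or logits).

```python
from typing import Dict, List


def summarize_reward_metrics(
    chosen_rewards: List[float],
    rejected_rewards: List[float],
    train_eval: str = "train",
) -> Dict[str, float]:
    """Hypothetical helper: batch-mean reward statistics keyed by split."""
    if train_eval not in ("train", "eval"):
        raise ValueError(f"train_eval must be 'train' or 'eval', got {train_eval!r}")
    if not chosen_rewards or len(chosen_rewards) != len(rejected_rewards):
        raise ValueError("reward lists must be non-empty and of equal length")

    def mean(xs: List[float]) -> float:
        return sum(xs) / len(xs)

    # Evaluation metrics get an "eval_" prefix so they do not collide with training ones.
    prefix = "eval_" if train_eval == "eval" else ""
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]
    # An example counts as correct when its chosen reward beats its rejected reward.
    accuracies = [1.0 if c > r else 0.0 for c, r in zip(chosen_rewards, rejected_rewards)]
    return {
        f"{prefix}rewards/chosen": mean(chosen_rewards),
        f"{prefix}rewards/rejected": mean(rejected_rewards),
        f"{prefix}rewards/accuracies": mean(accuracies),
        f"{prefix}rewards/margins": mean(margins),
    }
```

Keying by split lets one logger record training and evaluation curves side by side; the reward accuracy is a cheap proxy for how often the policy already ranks the preferred response higher than the reference does.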