Source code for lmflow.pipeline.utils.dpov2_dataprocessor

from dataclasses import dataclass
import logging
from typing import Optional, Union, Dict, List, Any

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


logger = logging.getLogger(__name__)


@dataclass
class PreferenceDataCollatorWithPadding:
    tokenizer: PreTrainedTokenizerBase
    model: Optional[PreTrainedModel] = None
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    label_pad_token_id: int = -100
    padding_value: int = 0
    truncation_mode: str = "keep_end"
    is_encoder_decoder: Optional[bool] = False
    max_target_length: Optional[int] = None
    mask_prompt: Optional[bool] = False
    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
    ) -> Dict:
        """Tokenize a single batch element.

        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
        in case the prompt + chosen or prompt + rejected responses is/are too long. First
        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

        We also create the labels for the chosen/rejected responses, which are of length equal to
        the sum of the length of the prompt and the chosen/rejected response, with
        label_pad_token_id for the prompt tokens.
        """
        batch = {}

        if self.is_encoder_decoder:
            raise NotImplementedError

        chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
        rejected_tokens = self.tokenizer(rejected, add_special_tokens=False)
        prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

        eos_token_id = self.tokenizer.eos_token_id
        # Get indices in list prompt_tokens["input_ids"] that equal the EOS token (often 0)
        eos_indices_prompt = [i for i, x in enumerate(prompt_tokens["input_ids"]) if x == eos_token_id]
        # attention mask these indices to eos_token_id
        if self.mask_prompt:
            new_attention_mask = [0 for i, p in enumerate(prompt_tokens["attention_mask"])]
        else:
            new_attention_mask = [
                0 if i in eos_indices_prompt else p for i, p in enumerate(prompt_tokens["attention_mask"])
            ]
        prompt_tokens["attention_mask"] = new_attention_mask

        # do the same for chosen and rejected
        eos_indices_chosen = [i for i, x in enumerate(chosen_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_c = [
            0 if i in eos_indices_chosen else p for i, p in enumerate(chosen_tokens["attention_mask"])
        ]
        chosen_tokens["attention_mask"] = new_attention_mask_c

        eos_indices_rejected = [i for i, x in enumerate(rejected_tokens["input_ids"]) if x == eos_token_id]
        new_attention_mask_r = [
            0 if i in eos_indices_rejected else p for i, p in enumerate(rejected_tokens["attention_mask"])
        ]
        rejected_tokens["attention_mask"] = new_attention_mask_r

        # add EOS token to end of prompt
        chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        chosen_tokens["attention_mask"].append(1)

        rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
        rejected_tokens["attention_mask"].append(1)

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            if self.truncation_mode == "keep_start":
                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
            elif self.truncation_mode == "keep_end":
                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
            chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
            rejected_tokens = {
                k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()
            }

        # Create labels
        chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
        rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
            prompt_tokens["input_ids"]
        )

        for k, toks in {
            "chosen": chosen_sequence_tokens,
            "rejected": rejected_sequence_tokens,
            "prompt": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}_{type_key}"] = tokens

        batch["prompt"] = prompt
        batch["chosen"] = prompt + chosen
        batch["rejected"] = prompt + rejected
        batch["chosen_response_only"] = chosen
        batch["rejected_response_only"] = rejected

        return batch
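    # For a single element, tokenize_batch_element returns list-valued keys
    # ("chosen_input_ids", "chosen_attention_mask", "chosen_labels", the matching
    # "rejected_*" keys, and "prompt_input_ids" / "prompt_attention_mask") together
    # with the raw strings "prompt", "chosen", "rejected", "chosen_response_only",
    # and "rejected_response_only". collate() below pads the list-valued keys into
    # tensors and passes every other field through as a Python list.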
    def collate(self, batch):
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
                if self.is_encoder_decoder:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]

                    if (k.startswith("prompt")) and (k.endswith("input_ids")):
                        padding_value = self.tokenizer.pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = 0
                    elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
                        padding_value = self.label_pad_token_id
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")
                    padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                else:
                    # adapted from https://stackoverflow.com/questions/73256206
                    if "prompt" in k:
                        to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                    else:
                        to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                    if k.endswith("_input_ids"):
                        padding_value = self.tokenizer.pad_token_id
                    elif k.endswith("_labels"):
                        padding_value = self.label_pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = self.padding_value
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")

                    padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                    # for the prompt, flip back so padding is on left side
                    if "prompt" in k:
                        padded_batch[k] = padded_batch[k].flip(dims=[1])
            else:
                padded_batch[k] = [ex[k] for ex in batch]

        return padded_batch
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        tokenized_batch = []

        for feature in features:
            prompt = feature["prompt"]
            chosen = feature["chosen"]
            rejected = feature["rejected"]

            batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
            batch_element["margin"] = feature["margin"]
            tokenized_batch.append(batch_element)

        # return collated batch
        return self.collate(tokenized_batch)
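

# Usage sketch (illustrative only, not part of the module): one way this collator
# could be wired into a torch DataLoader. The tokenizer name ("gpt2"), the toy
# features, the margins, and the length limits below are assumptions chosen for
# the example, not values taken from LMFlow.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    collator = PreferenceDataCollatorWithPadding(
        tokenizer=tokenizer,
        max_length=512,
        max_prompt_length=128,
    )

    # Each feature holds a prompt, a preferred (chosen) and a dispreferred (rejected)
    # response, and a scalar margin that is carried through to the collated batch.
    features = [
        {"prompt": "What is 2 + 2?", "chosen": " 4.", "rejected": " 5.", "margin": 1.0},
        {"prompt": "Name a prime number.", "chosen": " 7.", "rejected": " 9.", "margin": 0.5},
    ]

    loader = DataLoader(features, batch_size=2, collate_fn=collator)
    batch = next(iter(loader))

    # Tensor fields (e.g. "chosen_input_ids", "rejected_labels") are right-padded,
    # except the prompt tensors, which are flipped so padding ends up on the left.
    print({k: v.shape for k, v in batch.items() if isinstance(v, torch.Tensor)})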