Source code for lmflow.tokenization.hf_text_regression_model

#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.

import logging
from logging import Logger
from typing import Dict, List, Union

import transformers
from transformers.testing_utils import CaptureLogger
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from lmflow.utils.conversation_template import ConversationTemplate
from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
from lmflow.args import DatasetArguments


logger = logging.getLogger(__name__)
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

def blocking_paired(
    token_dict: Dict,
    column_names: List,
    block_size: int,
    model_max_length: int,
    pad_token_id: int,
    padding_side: str,
    truncation_side: str = 'right',
) -> Dict:
    block_size_warning_num = 0
    num_example = len(token_dict[list(token_dict.keys())[0]])
    for i in range(num_example):
        for column_name in column_names:
            max_length = min(block_size, model_max_length)
            pad_length = max_length - len(token_dict[f"input_ids_{column_name}"][i])
            if block_size < model_max_length:
                block_size_warning_num += 1
            if pad_length < 0:
                # Truncates too long samples
                for key in [f"input_ids_{column_name}", f"attention_mask_{column_name}"]:
                    if truncation_side == 'right':
                        token_dict[key][i] = token_dict[key][i][:pad_length]
                    elif truncation_side == 'left':
                        token_dict[key][i] = token_dict[key][i][-pad_length:]
                    else:
                        raise ValueError(
                            f"truncation_side should be either 'right' or 'left', got {truncation_side}"
                        )
            else:
                if padding_side == 'right':
                    # Pads too short samples
                    token_dict[f"input_ids_{column_name}"][i].extend(
                        [pad_token_id for _ in range(pad_length)]
                    )
                    token_dict[f"attention_mask_{column_name}"][i].extend(
                        [0 for _ in range(pad_length)]
                    )
                elif padding_side == 'left':
                    # Pads too short samples
                    token_dict[f"input_ids_{column_name}"][i] = (
                        [pad_token_id for _ in range(pad_length)]
                        + token_dict[f"input_ids_{column_name}"][i]
                    )
                    token_dict[f"attention_mask_{column_name}"][i] = (
                        [0 for _ in range(pad_length)]
                        + token_dict[f"attention_mask_{column_name}"][i]
                    )
                else:
                    raise ValueError(
                        f"padding_side should be either 'right' or 'left', got {padding_side}"
                    )

    if block_size_warning_num > 0:
        logger.warning(
            f"There are {block_size_warning_num} of {num_example} samples where"
            f" block_size {block_size} < model_max_length"
            f" {model_max_length}, use block_size"
            " for maximum tokenized sequence length"
        )

    return token_dict

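# Illustrative usage sketch (not part of the original module): pad a toy paired
# batch with hypothetical values. With block_size=4, padding_side='right' and
# pad_token_id=0, the shorter "rejected" sequence is padded to length 4 while
# "chosen" is already full length and stays unchanged.
#
#   paired = {
#       "input_ids_chosen": [[5, 6, 7, 8]],
#       "attention_mask_chosen": [[1, 1, 1, 1]],
#       "input_ids_rejected": [[5, 6]],
#       "attention_mask_rejected": [[1, 1]],
#   }
#   paired = blocking_paired(
#       token_dict=paired,
#       column_names=["chosen", "rejected"],
#       block_size=4,
#       model_max_length=512,
#       pad_token_id=0,
#       padding_side="right",
#   )
#   # paired["input_ids_rejected"][0]      -> [5, 6, 0, 0]
#   # paired["attention_mask_rejected"][0] -> [1, 1, 0, 0]
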
def blocking(
    token_dict: Dict,
    block_size: int,
    model_max_length: int,
    pad_token_id: int,
    padding_side: str,
    truncation_side: str = 'right',
) -> Dict:
    block_size_warning_num = 0
    num_example = len(token_dict[list(token_dict.keys())[0]])
    for i in range(num_example):
        max_length = min(block_size, model_max_length)
        pad_length = max_length - len(token_dict["input_ids"][i])
        if block_size < model_max_length:
            block_size_warning_num += 1
        if pad_length < 0:
            # Truncates too long samples
            for key in ["input_ids", "attention_mask", "labels"]:
                if truncation_side == 'right':
                    token_dict[key][i] = token_dict[key][i][:pad_length]
                elif truncation_side == 'left':
                    token_dict[key][i] = token_dict[key][i][-pad_length:]
                else:
                    raise ValueError(
                        f"truncation_side should be either 'right' or 'left', got {truncation_side}"
                    )
        else:
            if padding_side == 'right':
                # Pads too short samples
                token_dict["input_ids"][i].extend(
                    [pad_token_id for _ in range(pad_length)]
                )
                token_dict["attention_mask"][i].extend(
                    [0 for _ in range(pad_length)]
                )
                token_dict["labels"][i].extend(
                    [-100 for _ in range(pad_length)]
                )
            elif padding_side == 'left':
                # Pads too short samples
                token_dict["input_ids"][i] = (
                    [pad_token_id for _ in range(pad_length)] + token_dict["input_ids"][i]
                )
                token_dict["attention_mask"][i] = (
                    [0 for _ in range(pad_length)] + token_dict["attention_mask"][i]
                )
                token_dict["labels"][i] = (
                    [-100 for _ in range(pad_length)] + token_dict["labels"][i]
                )
            else:
                raise ValueError(
                    f"padding_side should be either 'right' or 'left', got {padding_side}"
                )

    if block_size_warning_num > 0:
        logger.warning(
            f"There are {block_size_warning_num} of {num_example} samples where"
            f" block_size {block_size} < model_max_length"
            f" {model_max_length}, use block_size"
            " for maximum tokenized sequence length"
        )

    return token_dict

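# Illustrative usage sketch (not part of the original module): with the toy
# values below, block_size=6 right-pads input_ids with pad_token_id, the
# attention mask with 0, and the labels with -100.
#
#   batch = {
#       "input_ids": [[1, 2, 3]],
#       "attention_mask": [[1, 1, 1]],
#       "labels": [[-100, 2, 3]],
#   }
#   batch = blocking(
#       token_dict=batch,
#       block_size=6,
#       model_max_length=512,
#       pad_token_id=0,
#       padding_side="right",
#   )
#   # batch["input_ids"][0]      -> [1, 2, 3, 0, 0, 0]
#   # batch["attention_mask"][0] -> [1, 1, 1, 0, 0, 0]
#   # batch["labels"][0]         -> [-100, 2, 3, -100, -100, -100]
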
def blocking_text_to_textlist(
    token_dict: Dict,
    block_size: int,
    model_max_length: int,
    pad_token_id: int,
    padding_side: str,
    truncation_side: str = 'right',
) -> Dict:
    block_size_warning_num = 0
    num_example = len(token_dict[list(token_dict.keys())[0]])
    max_length = min(block_size, model_max_length)

    for example_idx in range(num_example):
        for content_idx in range(len(token_dict["input_ids"][example_idx])):
            pad_length = max_length - len(token_dict["input_ids"][example_idx][content_idx])
            if block_size < model_max_length:
                block_size_warning_num += 1
            if pad_length < 0:
                # Truncates too long samples
                if truncation_side == 'right':
                    token_dict["input_ids"][example_idx][content_idx] = (
                        token_dict["input_ids"][example_idx][content_idx][:pad_length]
                    )
                elif truncation_side == 'left':
                    token_dict["input_ids"][example_idx][content_idx] = (
                        token_dict["input_ids"][example_idx][content_idx][-pad_length:]
                    )
                else:
                    raise ValueError(
                        f"truncation_side should be either 'right' or 'left', got {truncation_side}"
                    )
            else:
                if padding_side == 'right':
                    # Pads too short samples
                    token_dict["input_ids"][example_idx][content_idx].extend(
                        [pad_token_id for _ in range(pad_length)]
                    )
                elif padding_side == 'left':
                    # Pads too short samples
                    token_dict["input_ids"][example_idx][content_idx] = (
                        [pad_token_id for _ in range(pad_length)]
                        + token_dict["input_ids"][example_idx][content_idx]
                    )
                else:
                    raise ValueError(
                        f"padding_side should be either 'right' or 'left', got {padding_side}"
                    )

    if block_size_warning_num > 0:
        logger.warning(
            f"There are {block_size_warning_num} of {num_example} samples where"
            f" block_size {block_size} < model_max_length"
            f" {model_max_length}, use block_size"
            " for maximum tokenized sequence length"
        )

    return token_dict

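# Illustrative usage sketch (not part of the original module): each example
# holds several candidate sequences; with block_size=4 the long candidate is
# right-truncated and the short one is right-padded with pad_token_id.
#
#   candidates = {"input_ids": [[[1, 2, 3, 4, 5], [1, 2]]]}
#   candidates = blocking_text_to_textlist(
#       token_dict=candidates,
#       block_size=4,
#       model_max_length=512,
#       pad_token_id=0,
#       padding_side="right",
#       truncation_side="right",
#   )
#   # candidates["input_ids"][0] -> [[1, 2, 3, 4], [1, 2, 0, 0]]
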
def paired_conversation_tokenize_function(
    examples,
    data_args: DatasetArguments,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    conversation_template: ConversationTemplate,
) -> Dict:
    num_example = len(examples[column_names[0]])
    token_dict = {}
    for column_name in column_names:
        token_dict[f"input_ids_{column_name}"] = [[] for _ in range(num_example)]
        token_dict[f"attention_mask_{column_name}"] = [[] for _ in range(num_example)]

    with CaptureLogger(tok_logger) as cl:
        num_corrupted = 0
        for i in range(num_example):
            try:
                for column_name in column_names:
                    messages = examples[column_name][i]["messages"]
                    system = examples[column_name][i].get("system", None)
                    tools = examples[column_name][i].get("tools", None)
                    if len(messages) < 2 or messages[0]['role'] != CONVERSATION_ROLE_NAMES['user']:
                        tok_logger.warning(
                            "Invalid instance encountered. Either the conversation has less than "
                            "one round or the first message is not from the user."
                        )
                        continue

                    if len(messages) % 2 != 0:
                        logger.warning(
                            "The number of messages is not even, the last message will be ignored."
                        )
                        messages = messages[:-1]

                    encoded_conversation = conversation_template.encode_conversation(
                        tokenizer=tokenizer,
                        messages=messages,
                        system=system,
                        tools=tools,
                    )

                    input_ids = []
                    for turn_idx, (user_input, assistant_result) in enumerate(encoded_conversation):
                        input_ids += user_input + assistant_result

                    token_dict[f"input_ids_{column_name}"][i].extend(input_ids)
                    token_dict[f"attention_mask_{column_name}"][i].extend([1] * len(input_ids))
            except Exception:
                num_corrupted += 1
                logger.error(f"Error in encoding conversation {i}: {column_name}")
                logger.error(f"Messages: {messages}")
                continue

    if num_corrupted > 0:
        logger.error(f"Number of corrupted examples: {num_corrupted}")

    if data_args.disable_group_texts:
        token_dict = blocking_paired(
            token_dict=token_dict,
            column_names=column_names,
            block_size=data_args.block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    # clm input could be much much longer than block_size
    if "Token indices sequence length is longer than the" in cl.out:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )

    return token_dict

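# Illustrative usage sketch (not part of the original module): this function is
# typically mapped over a batched paired-conversation dataset, e.g. one with
# "chosen" and "rejected" conversation columns. The dataset object and argument
# names below are hypothetical.
#
#   tokenized = raw_dataset.map(
#       paired_conversation_tokenize_function,
#       batched=True,
#       fn_kwargs={
#           "data_args": data_args,
#           "tokenizer": tokenizer,
#           "column_names": ["chosen", "rejected"],
#           "conversation_template": conversation_template,
#       },
#       remove_columns=["chosen", "rejected"],
#   )
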
def conversation_tokenize_function(
    examples,
    data_args: DatasetArguments,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    conversation_template: ConversationTemplate,
) -> Dict:
    """Handles conversation dataset tokenization."""
    num_example = len(examples[column_names[0]])
    token_dict = {
        "input_ids": [[] for _ in range(num_example)],
        "attention_mask": [[] for _ in range(num_example)],
        "labels": [[] for _ in range(num_example)],
    }

    with CaptureLogger(tok_logger) as cl:
        for i in range(len(examples["messages"])):
            messages = examples["messages"][i]
            system = examples.get("system", [None] * num_example)[i]
            tools = examples.get("tools", [None] * num_example)[i]
            if len(messages) < 2 or messages[0]['role'] != CONVERSATION_ROLE_NAMES['user']:
                tok_logger.warning(
                    "Invalid instance encountered. Either the conversation has less than "
                    "one round or the first message is not from the user."
                )
                continue

            if len(messages) % 2 != 0:
                logger.warning(
                    "The number of messages is not even, the last message will be ignored."
                )
                messages = messages[:-1]

            encoded_conversation = conversation_template.encode_conversation(
                tokenizer=tokenizer,
                messages=messages,
                system=system,
                tools=tools,
            )

            input_ids, labels = [], []
            for turn_idx, (user_input, assistant_result) in enumerate(encoded_conversation):
                input_ids += user_input + assistant_result

                if data_args.train_on_prompt:
                    labels += user_input + assistant_result
                else:
                    labels += [-100] * len(user_input) + assistant_result

            token_dict["input_ids"][i].extend(input_ids)
            token_dict["attention_mask"][i].extend([1] * len(input_ids))
            token_dict["labels"][i].extend(labels)

    if data_args.disable_group_texts:
        token_dict = blocking(
            token_dict=token_dict,
            block_size=data_args.block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    # clm input could be much much longer than block_size
    if "Token indices sequence length is longer than the" in cl.out:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )

    return token_dict

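# Illustrative usage sketch (not part of the original module): each batched
# example is expected to look roughly like the hypothetical record below; user
# tokens are masked to -100 in "labels" unless data_args.train_on_prompt is set.
#
#   examples = {
#       "messages": [[
#           {"role": "user", "content": "Hi!"},
#           {"role": "assistant", "content": "Hello, how can I help?"},
#       ]],
#   }
#   token_dict = conversation_tokenize_function(
#       examples,
#       data_args=data_args,
#       tokenizer=tokenizer,
#       column_names=["messages"],
#       conversation_template=conversation_template,
#   )
#   # token_dict keys: "input_ids", "attention_mask", "labels"
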
def tokenize_function(
    examples,
    data_args: DatasetArguments,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    label_columns,
    tokenized_column_order,
    add_special_tokens,
    use_truncation,
) -> Dict:
    """Handles text_only and text2text dataset tokenization."""
    num_example = len(examples[column_names[0]])
    token_dict = {
        "input_ids": [[] for _ in range(num_example)],
        "attention_mask": [[] for _ in range(num_example)],
        "labels": [[] for _ in range(num_example)],
    }

    with CaptureLogger(tok_logger) as cl:
        for column_name in tokenized_column_order:
            encoding = tokenizer(
                examples[column_name],
                add_special_tokens=add_special_tokens,
                truncation=use_truncation,
            )

            if column_name in label_columns:
                labels = encoding["input_ids"].copy()
            else:
                labels = [
                    [-100] * len(encoding["input_ids"][i])
                    for i in range(num_example)
                ]

            for i in range(num_example):
                token_dict["input_ids"][i].extend(
                    encoding["input_ids"][i]
                )
                token_dict["attention_mask"][i].extend(
                    encoding["attention_mask"][i]
                )
                token_dict["labels"][i].extend(labels[i])

    if data_args.disable_group_texts:
        token_dict = blocking(
            token_dict=token_dict,
            block_size=data_args.block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    # clm input could be much much longer than block_size
    if "Token indices sequence length is longer than the" in cl.out:
        tok_logger.warning(
            "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
            " before being passed to the model."
        )

    return token_dict

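# Illustrative usage sketch (not part of the original module): for a text2text
# batch, the "input" column contributes -100 labels while the "output" column
# keeps its own token ids as labels. Column names and values are hypothetical.
#
#   examples = {"input": ["Translate: bonjour"], "output": ["hello"]}
#   token_dict = tokenize_function(
#       examples,
#       data_args=data_args,
#       tokenizer=tokenizer,
#       column_names=["input", "output"],
#       label_columns=["output"],
#       tokenized_column_order=["input", "output"],
#       add_special_tokens=True,
#       use_truncation=True,
#   )
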
def text_to_textlist_tokenize_function(
    examples,
    data_args: DatasetArguments,
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    column_names,
    add_special_tokens,
    use_truncation,
) -> Dict:
    """For reward model inference; attention masks and labels are not needed.

    NOTE: ``input_ids`` here refers to the tokenized ids of the input **and** the output.
    """
    num_example = len(examples[column_names[0]])
    output_dict = {column_name: examples[column_name] for column_name in column_names}
    output_dict["input_ids"] = [[] for _ in range(num_example)]

    for example_idx in range(num_example):
        encoded = tokenizer(
            [
                examples["input"][example_idx] + examples["output"][example_idx][i]
                for i in range(len(examples["output"][example_idx]))
            ],
            add_special_tokens=add_special_tokens,
            truncation=use_truncation,
        )
        output_dict["input_ids"][example_idx] = encoded["input_ids"]

    if data_args.disable_group_texts:
        output_dict = blocking_text_to_textlist(
            token_dict=output_dict,
            block_size=data_args.block_size,
            model_max_length=tokenizer.model_max_length,
            pad_token_id=tokenizer.pad_token_id,
            padding_side=tokenizer.padding_side,
            truncation_side=tokenizer.truncation_side,
        )

    return output_dict

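# Illustrative usage sketch (not part of the original module): each example
# pairs one "input" string with several candidate "output" strings; every
# input+output concatenation is tokenized into its own id sequence. Values
# below are hypothetical.
#
#   examples = {
#       "input": ["Question: 1+1=?\nAnswer:"],
#       "output": [[" 2", " 3"]],
#   }
#   output_dict = text_to_textlist_tokenize_function(
#       examples,
#       data_args=data_args,
#       tokenizer=tokenizer,
#       column_names=["input", "output"],
#       add_special_tokens=True,
#       use_truncation=True,
#   )
#   # output_dict["input_ids"][0] holds two token-id lists, one per candidate.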