Source code for lmflow.datasets.multi_modal_dataset

#!/usr/bin/env python
# coding=utf-8
"""This module defines CustomMultiModalDataset, a multi-modal (image + text)
dataset for LLaVA-style supervised fine-tuning, together with the
preprocessing and collation utilities it relies on.
"""
import copy
from dataclasses import dataclass, field
import json
from PIL import Image
import os.path as osp
import transformers
import torch
from torch.utils.data import Dataset

from lmflow.args import DatasetArguments
from lmflow.utils import llava_conversation_lib as conversation_lib

from lmflow.utils.constants import (
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)

class CustomMultiModalDataset(Dataset):
    """Dataset for multi-modal (image + text) data."""

    def __init__(self, dataset_path: str, data_args: DatasetArguments):
        super(CustomMultiModalDataset, self).__init__()
        data_dict = json.load(open(dataset_path, "r"))
        self.data_dict = data_dict
        print("Finished loading JSON file for dataset.")
        self.data_args = data_args
        self.image_folder = data_args.image_folder
    def __len__(self):
        return len(self.data_dict)

    def register_tokenizer(self, tokenizer, image_processor=None):
        self.tokenizer = tokenizer
        self.image_processor = getattr(
            tokenizer, "image_processor", image_processor)
    def __getitem__(self, i):
        data = self.data_dict[i]
        if isinstance(i, int):
            data = [data]
        assert len(data) == 1
        processor = self.image_processor
        if 'image' in data[0]:
            image_file = data[0]['image']
            image = Image.open(
                osp.join(self.image_folder, image_file)).convert("RGB")
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result
                image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            data = preprocess_multimodal_llava(
                copy.deepcopy([e["conversations"] for e in data]),
                self.data_args)
        else:
            data = copy.deepcopy([e["conversations"] for e in data])

        if self.data_args.sep_style == "plain":
            data_dict = preprocess_llama_from_llava_plain(
                data,
                self.tokenizer,
                has_image=('image' in self.data_dict[i]))
        else:
            data_dict = preprocess_llama_from_llava_v1(
                data,
                self.tokenizer,
                has_image=('image' in self.data_dict[i]))

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])
        # image exists in the data
        if 'image' in self.data_dict[i]:
            data_dict['image'] = image
        else:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.image_processor.crop_size
            data_dict['image'] = torch.zeros(
                3, crop_size['height'], crop_size['width'])
        return data_dict
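
# A minimal sketch of the JSON layout the dataset above expects: a list of
# LLaVA-style records with an optional "image" path (relative to
# data_args.image_folder) and alternating "human"/"gpt" conversation turns.
# The concrete id, file name, and strings are assumptions for illustration only.
_EXAMPLE_RECORD = {
    "id": "000000001",                               # assumed sample id
    "image": "coco/train2017/000000001.jpg",         # assumed relative path
    "conversations": [
        {"from": "human", "value": DEFAULT_IMAGE_TOKEN + "\nWhat is in the picture?"},
        {"from": "gpt", "value": "A dog playing in the snow."},
    ],
}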
def preprocess_multimodal_llava(sources, data_args):
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                sentence['value'] = sentence['value'].replace(
                    DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence['value'] = sentence['value'].replace(
                        DEFAULT_IMAGE_TOKEN,
                        '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.use_image_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + \
                    replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(
                DEFAULT_IMAGE_TOKEN, replace_token)

    return sources
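
# A minimal sketch of how the image placeholder is normalised by the function
# above: the <image> token is moved to the front of the turn and, when
# use_image_start_end is set, wrapped in the image start/end tokens.
# SimpleNamespace stands in for DatasetArguments here; only the two attributes
# read above are provided, and their values are assumptions for illustration.
def _example_preprocess_multimodal():
    from types import SimpleNamespace

    data_args = SimpleNamespace(is_multimodal=True, use_image_start_end=False)
    sources = [[
        {"from": "human", "value": "What is in the picture? " + DEFAULT_IMAGE_TOKEN},
        {"from": "gpt", "value": "A dog playing in the snow."},
    ]]
    # After the call, the human turn starts with the image placeholder
    # followed by a newline and the original question text.
    return preprocess_multimodal_llava(sources, data_args)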
def tokenizer_image_token(prompt, tokenizer,
                          image_token_index=IMAGE_TOKEN_INDEX,
                          return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
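
# A minimal usage sketch for the function above: every '<image>' occurrence in
# the prompt becomes a single image_token_index entry spliced between the
# tokenized text chunks. Any Hugging Face tokenizer works; the prompt string
# below is an assumption for illustration only.
def _example_tokenize_with_image(tokenizer):
    prompt = "USER: " + DEFAULT_IMAGE_TOKEN + "\nDescribe the picture. ASSISTANT:"
    # Returns a 1-D LongTensor of text token ids with IMAGE_TOKEN_INDEX at the
    # position where '<image>' appeared in the prompt.
    return tokenizer_image_token(prompt, tokenizer, return_tensors='pt')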
def preprocess_llama_from_llava_plain(
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        has_image: bool = False):
    """
    This function simply places the image token in front of the text and does
    not add any prompt template.

    Args:
        sources: The input data with text and image.
        tokenizer: The tokenizer to process text.
        has_image: Whether the input data has an image.

    Returns:
        The input_ids and labels for the model.
    """
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_IMAGE_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=targets)
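
# A minimal usage sketch for the "plain" recipe above: it expects exactly two
# turns per sample with the image placeholder in the first turn, and it masks
# the image placeholder tokens so the loss is computed only on the caption.
# The sample strings are assumptions for illustration only.
def _example_plain_preprocess(tokenizer):
    sources = [[
        {"from": "human", "value": DEFAULT_IMAGE_TOKEN},
        {"from": "gpt", "value": "A dog playing in the snow."},
    ]]
    out = preprocess_llama_from_llava_plain(sources, tokenizer, has_image=True)
    # Labels covering the image placeholder are set to IGNORE_INDEX.
    return out["input_ids"][0], out["labels"][0]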
def preprocess_llama_from_llava_v1(
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        has_image: bool = False):
    """
    This function applies the conversation prompt template and places the
    image after the prompt, so it needs additional code to generate the
    target labels.

    Args:
        sources: The input data with text and image.
        tokenizer: The tokenizer to process text.
        has_image: Whether the input data has an image.

    Returns:
        The input_ids and labels for the model.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
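
# A minimal usage sketch for the "v1" recipe above. It assumes that
# conversation_lib.default_conversation uses SeparatorStyle.TWO (enforced by
# the assert in the function) and that the tokenizer has a pad token; the
# sample strings are assumptions for illustration only.
def _example_v1_preprocess(tokenizer):
    sources = [[
        {"from": "human", "value": DEFAULT_IMAGE_TOKEN + "\nWhat is in the picture?"},
        {"from": "gpt", "value": "A dog playing in the snow."},
    ]]
    out = preprocess_llama_from_llava_v1(sources, tokenizer, has_image=True)
    # The system prompt and human turns are masked with IGNORE_INDEX; only the
    # assistant replies contribute to the loss.
    return out["input_ids"], out["labels"]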
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
        return batch
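
# A minimal end-to-end sketch: wiring the dataset and collator above into a
# standard DataLoader. The tokenizer and image-processor checkpoint names are
# assumptions for illustration only, as are the batch size and the pad-token
# fallback; data_args is expected to be a DatasetArguments instance.
def _example_build_dataloader(dataset_path, data_args):
    from torch.utils.data import DataLoader

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "lmsys/vicuna-7b-v1.5")                          # assumed checkpoint
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token        # assumed fallback
    image_processor = transformers.CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14")                 # assumed vision tower

    dataset = CustomMultiModalDataset(dataset_path, data_args)
    dataset.register_tokenizer(tokenizer, image_processor)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)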