#!/usr/bin/env python
# coding=utf-8
"""Multi-modal (image + text) dataset and preprocessing utilities for
supervised fine-tuning in the LLaVA conversation format.
"""
import copy
from dataclasses import dataclass, field
import json
from PIL import Image
import os.path as osp
import transformers
import torch
from torch.utils.data import Dataset
from lmflow.args import DatasetArguments
from lmflow.utils import llava_conversation_lib as conversation_lib
from lmflow.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
class CustomMultiModalDataset(Dataset):
"""Dataset for Multi Modal data"""
def __init__(self, dataset_path: str,
data_args: DatasetArguments):
        super(CustomMultiModalDataset, self).__init__()
        # Use a context manager so the file handle is closed after loading.
        with open(dataset_path, "r") as fin:
            data_dict = json.load(fin)
self.data_dict = data_dict
print("Finish loading json file in dataset.")
self.data_args = data_args
self.image_folder = data_args.image_folder
def __len__(self):
return len(self.data_dict)
def register_tokenizer(self, tokenizer, image_processor=None):
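        """Attach the tokenizer after the dataset is constructed.

        The image processor is taken from ``tokenizer.image_processor`` when
        present, otherwise from the ``image_processor`` argument.
        """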
self.tokenizer = tokenizer
self.image_processor = getattr(
tokenizer, "image_processor", image_processor)
def __getitem__(self, i):
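        # For an integer index, returns a dict with tokenized "input_ids" and
        # "labels", plus an "image" tensor (all zeros when the sample has no image).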
data = self.data_dict[i]
if isinstance(i, int):
data = [data]
assert len(data) == 1
processor = self.image_processor
if 'image' in data[0]:
image_file = data[0]['image']
image = Image.open(
osp.join(self.image_folder, image_file)).convert("RGB")
if self.data_args.image_aspect_ratio == 'pad':
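                # Pad the image to a square canvas filled with the processor's mean
                # pixel value so the original aspect ratio is preserved.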
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
data = preprocess_multimodal_llava(
copy.deepcopy([e["conversations"] for e in data]),
self.data_args)
else:
data = copy.deepcopy([e["conversations"] for e in data])
if self.data_args.sep_style == "plain":
data_dict = preprocess_llama_from_llava_plain(
data,
self.tokenizer,
has_image=('image' in self.data_dict[i])
)
else:
data_dict = preprocess_llama_from_llava_v1(
data,
self.tokenizer,
has_image=('image' in self.data_dict[i])
)
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
# image exist in the data
if 'image' in self.data_dict[i]:
data_dict['image'] = image
else:
# image does not exist in the data, but the model is multimodal
crop_size = self.image_processor.crop_size
data_dict['image'] = torch.zeros(
3, crop_size['height'], crop_size['width'])
return data_dict
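
# Typical wiring (a minimal sketch; the JSON path, checkpoint name, and
# `DatasetArguments` values below are illustrative assumptions, not values
# required by this module):
#
#     data_args = DatasetArguments(...)  # needs image_folder, image_aspect_ratio,
#                                        # is_multimodal, sep_style, use_image_start_end
#     dataset = CustomMultiModalDataset("train.json", data_args)
#     tokenizer = transformers.AutoTokenizer.from_pretrained("some/llava-checkpoint")
#     dataset.register_tokenizer(tokenizer, image_processor=image_processor)
#     sample = dataset[0]  # {"input_ids": ..., "labels": ..., "image": ...}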
def preprocess_multimodal_llava(sources, data_args):
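    """Normalize the "<image>" placeholder in every conversation turn.

    The placeholder is moved to the front of the turn (followed by a newline),
    optionally wrapped in <Image></Image> tags for "mmtag" conversation
    versions, and replaced by DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
    DEFAULT_IM_END_TOKEN when ``data_args.use_image_start_end`` is set.
    Returns ``sources`` unchanged when ``data_args.is_multimodal`` is False.
    """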
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence['value']:
sentence['value'] = sentence['value'].replace(
DEFAULT_IMAGE_TOKEN, '').strip()
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
sentence['value'] = sentence['value'].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence['value'] = sentence['value'].replace(
DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
replace_token = DEFAULT_IMAGE_TOKEN
if data_args.use_image_start_end:
replace_token = DEFAULT_IM_START_TOKEN + \
replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN, replace_token)
return sources
def tokenizer_image_token(prompt,
tokenizer,
image_token_index=IMAGE_TOKEN_INDEX,
return_tensors=None):
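    """Tokenize a prompt that may contain "<image>" placeholders.

    The prompt is split on "<image>", each text chunk is tokenized separately,
    and ``image_token_index`` is spliced in between the chunks so every image
    is represented by a single placeholder id.

    Example (a sketch; exact ids depend on the tokenizer, which is assumed to
    prepend a BOS token)::

        tokenizer_image_token("<image> Describe the picture.", tokenizer)
        # -> [bos_id, IMAGE_TOKEN_INDEX, *ids of " Describe the picture."]
    """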
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def preprocess_llama_from_llava_plain(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False):
"""
This function just add the image in the front of text.
And don't add any prompt.
Args:
sources: The input data with text and image.
tokenizer: The tokenizer to process text.
has_image: Whether the input data has image.
Returns:
The input_ids and labels for the model.
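
    Example (a sketch; exact token ids depend on the tokenizer in use)::

        source = [
            {"from": "human", "value": "<image>"},
            {"from": "gpt", "value": "A cat sitting on a sofa."},
        ]
        out = preprocess_llama_from_llava_plain([source], tokenizer)
        # out["input_ids"][0]: ids of "<image>" + caption + separator
        # out["labels"][0]:    the same ids with the image prefix masked by IGNORE_INDEX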
"""
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = DEFAULT_IMAGE_TOKEN
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
def preprocess_llama_from_llava_v1(
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False):
"""
This function add the prompt and then put the image after the prompt.
So it needs additional code to generate the target label.
Args:
sources: The input data with text and image.
tokenizer: The tokenizer to process text.
has_image: Whether the input data has image.
Returns:
The input_ids and labels for the model.
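
    Example (a sketch; assumes the default conversation template uses
    SeparatorStyle.TWO, e.g. a USER/ASSISTANT "v1" template)::

        source = [
            {"from": "human", "value": "<image> What is in the photo?"},
            {"from": "gpt", "value": "A dog playing in the park."},
        ]
        out = preprocess_llama_from_llava_v1([source], tokenizer, has_image=True)
        # out["labels"] equals out["input_ids"] except that the system prompt and
        # human turns are replaced by IGNORE_INDEX, so the loss only covers the
        # assistant responses.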
"""
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
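    # Each round ends with an assistant reply; everything before the reply
    # (system prompt, user turn, and the assistant role tag) is masked with
    # IGNORE_INDEX so the loss is computed only on the replies.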
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances):
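        # Pad input_ids with the tokenizer's pad id and labels with IGNORE_INDEX,
        # truncate both to model_max_length, and stack images when their shapes match.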
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
labels = labels[:, :self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch['images'] = torch.stack(images)
else:
batch['images'] = images
return batch
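
# Collation sketch (illustrative; `dataset` and `tokenizer` are assumed to come
# from the setup above, and the batch size is arbitrary):
#
#     from torch.utils.data import DataLoader
#
#     collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
#     loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
#     batch = next(iter(loader))  # input_ids, labels, attention_mask, images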