#!/usr/bin/env python
# coding=utf-8
"""This Python code defines a class Dataset with methods for initializing, loading,
and manipulating datasets from different backends such as Hugging Face and JSON.
The `Dataset` class includes methods for loading datasets from a dictionary and a Hugging
Face dataset, mapping datasets, and retrieving the backend dataset and arguments.
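
A minimal usage sketch (the dataset path is hypothetical and must point to a
directory of .json files in the {"type": ..., "instances": [...]} format; the
import path assumes the usual lmflow package layout):

    from lmflow.args import DatasetArguments
    from lmflow.datasets.dataset import Dataset

    data_args = DatasetArguments(dataset_path="path/to/dataset_dir")  # hypothetical path
    dataset = Dataset(data_args, backend="huggingface")
    train_dataset, test_dataset = dataset.train_test_split(test_size=0.2)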
"""
# Importing necessary libraries and modules
import copy
import json
import logging
from pathlib import Path
from typing import Optional
from datasets import load_dataset
from datasets import Dataset as HFDataset
from lmflow.args import DatasetArguments
from lmflow.utils.constants import (
DATASET_DESCRIPTION_MAP,
TEXT_ONLY_DATASET_DESCRIPTION,
TEXT2TEXT_DATASET_DESCRIPTION,
FLOAT_ONLY_DATASET_DESCRIPTION,
INSTANCE_FIELDS_MAP,
)
from lmflow.utils.versioning import is_multimodal_available
if is_multimodal_available():
from .multi_modal_dataset import CustomMultiModalDataset
logger = logging.getLogger(__name__)
DATASET_TYPES = [
"text_only",
"text2text",
"float_only",
"image_text",
"conversation",
"paired_conversation",
"paired_text_to_text",
"text_to_textlist",
"text_to_scored_textlist"
]
KEY_TYPE = "type"
KEY_INSTANCES = "instances"
class Dataset:
r"""
Initializes the Dataset object with the given parameters.
Parameters
------------
data_args : DatasetArguments object.
Contains the arguments required to load the dataset.
backend : str, default="huggingface"
A string representing the dataset backend. Defaults to "huggingface".
args : Optional.
Positional arguments.
kwargs : Optional.
Keyword arguments.
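
Examples
---------
An illustrative sketch; passing dataset_path=None creates an empty shell that
can later be filled via `from_dict`:

>>> from lmflow.args import DatasetArguments
>>> dataset = Dataset(DatasetArguments(dataset_path=None), backend="huggingface")
>>> dataset.get_backend()
'huggingface'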
"""
def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface", *args, **kwargs):
self.data_args = data_args
self.backend = backend
self.backend_dataset = None
self.type = None  # Original type of the dataset
self.dataset_path = data_args.dataset_path
if data_args.dataset_path is None:
return
if backend == "huggingface":
data_files = [
x.absolute().as_posix()
for x in Path(self.dataset_path).glob("*.json")
]
# Iterate through all the files and ensure they have the same data type
for single_file in data_files:
with open(single_file) as fin:
json_data = json.load(fin)
if KEY_TYPE not in json_data.keys():
raise ValueError(
f'"{KEY_TYPE}" field must be specified for data, e.g.'
'{\n'
f' "{KEY_TYPE}: "text_only",\n'
f' "{KEY_INSTANCES}": [\n'
' { "text": "Sentence 1: This is a sentence." }\n'
' { "text": "Sentence 2: This is another sentence." }\n'
f' ]\n'
'}'
)
if self.type is None:
self.type = json_data[KEY_TYPE]
elif self.type != json_data[KEY_TYPE]:
raise ValueError(
'All task files must have the same data type. Previous'
f' files have type "{self.type}", but file'
f' {single_file} has type "{json_data[KEY_TYPE]}".'
)
# Load the dataset using the HuggingFace dataset library
extensions = "json"
raw_dataset = load_dataset(
extensions,
data_files=data_files,
field=KEY_INSTANCES,
split="train",
)
self.backend_dataset = raw_dataset
self._check_data_format()
elif backend == "json":
# TODO (@Jiachun)
pass
elif backend == "custom_multi_modal":
# FIXME refactor the backend name
if not is_multimodal_available():
raise ValueError(
'Multimodal not available. Please install via `pip install -e ".[multimodal]"`'
)
raw_dataset = CustomMultiModalDataset(self.dataset_path, data_args)
self.backend_dataset = raw_dataset
else:
raise NotImplementedError(f'Unsupported dataset backend "{backend}"')
def __len__(self):
return len(self.backend_dataset)
def from_dict(self, dict_obj: dict, *args, **kwargs):
r"""
Create a Dataset object from a dictionary.
Return a Dataset given a dict with format:
{
"type": TYPE,
"instances": [
{
"key_1": VALUE_1.1,
"key_2": VALUE_1.2,
...
},
{
"key_1": VALUE_2.1,
"key_2": VALUE_2.2,
...
},
...
]
}
Parameters
-----------
dict_obj : dict.
A dictionary containing the dataset information.
args : Optional.
Positional arguments.
kwargs : Optional.
Keyword arguments.
Returns
---------
self : Dataset object.
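
Examples
---------
An illustrative sketch with a hand-made "text_only" dict (the instances are made up):

>>> from lmflow.args import DatasetArguments
>>> dataset = Dataset(DatasetArguments(dataset_path=None))
>>> dataset = dataset.from_dict({
...     "type": "text_only",
...     "instances": [{"text": "Hello."}, {"text": "World."}],
... })
>>> len(dataset)
2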
"""
if self.backend == "huggingface":
if KEY_TYPE not in dict_obj:
raise ValueError(
f'"{KEY_TYPE}" must be provided to initialize a dataset,'
f' e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
if KEY_INSTANCES not in dict_obj:
raise ValueError(
f'"{KEY_INSTANCES}" must be provided to initialize a'
f' dataset, e.g.\n'
f' {TEXT_ONLY_DATASET_DESCRIPTION}'
)
self.type = dict_obj[KEY_TYPE]
if self.type not in INSTANCE_FIELDS_MAP:
raise ValueError(f'type "{self.type}" is not supported')
correct_fields = INSTANCE_FIELDS_MAP[self.type]
for i, instance in enumerate(dict_obj[KEY_INSTANCES]):
fields = instance.keys()
if not set(correct_fields).issubset(set(fields)):
raise ValueError(
f'data instance {i} has incorrect fields:'
f' {list(correct_fields)} are required, but got {list(fields)}.'
)
try:
hf_dict = {}
if len(dict_obj[KEY_INSTANCES]) > 0:
for key in dict_obj[KEY_INSTANCES][0].keys():
hf_dict[key] = [
instance[key] for instance in dict_obj[KEY_INSTANCES]
]
self.backend_dataset = HFDataset.from_dict(hf_dict, *args, **kwargs)
except AttributeError as ex:
raise ValueError(
f'Error occurred: {ex}. Failed to convert dict to'
f' "{self.type}" dataset; the standard format is as follows:\n'
f' {DATASET_DESCRIPTION_MAP[self.type]}'
) from ex
self._check_data_format()
return self
elif self.backend == "dict":
self.backend_dataset = dict_obj
self.type = dict_obj[KEY_TYPE]
return self
else:
raise NotImplementedError(
f'Currently .from_dict is not supported for backend "{self.backend}"'
)
@classmethod
def create_from_dict(cls, dict_obj, *args, **kwargs):
r"""
Create a Dataset object from a dict, without manually constructing dataset arguments.
Returns
--------
dataset : Dataset object.
A Dataset object built from the given dict.
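
Examples
---------
Illustrative only; the instance below is made up and follows the "text2text" format:

>>> dataset = Dataset.create_from_dict({
...     "type": "text2text",
...     "instances": [{"input": "1 + 1 =", "output": "2"}],
... })
>>> dataset.get_type()
'text2text'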
"""
empty_data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(empty_data_args)
return dataset.from_dict(dict_obj)
def to_dict(self):
r"""
Returns
---------
Return a dict that represents the dataset:
{
"type": TYPE,
"instances": [
{
"key_1": VALUE_1.1,
"key_2": VALUE_1.2,
...
},
{
"key_1": VALUE_2.1,
"key_2": VALUE_2.2,
...
},
...
]
}
A Python dict object representing the content of this dataset.
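
Examples
---------
A sketch of the round trip (the instance is made up):

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": "Hello."}],
... })
>>> dataset.to_dict()
{'type': 'text_only', 'instances': [{'text': 'Hello.'}]}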
"""
if self.backend == "huggingface":
dict_obj = {}
dict_obj[KEY_TYPE] = self.get_type()
hf_dict = self.backend_dataset.to_dict()
dict_obj[KEY_INSTANCES] = []
first_key = next(iter(hf_dict), None)
if first_key is not None:
num_instances = len(hf_dict[first_key])
dict_obj[KEY_INSTANCES] = [
{
key: hf_dict[key][i] for key in hf_dict.keys()
}
for i in range(num_instances)
]
return dict_obj
elif self.backend == "dict":
dict_obj = self.backend_dataset
return dict_obj
else:
raise NotImplementedError(
f'Currently .to_dict is not supported for backend "{self.backend}"'
)
def to_list(self):
"""Returns a list of instances."""
if self.backend == "huggingface":
instance_list = [
self.backend_dataset[idx] for idx in range(len(self.backend_dataset))
]
return instance_list
elif self.backend == "dict":
instance_list = copy.deepcopy(self.backend_dataset[KEY_INSTANCES])
# TODO: should be a list of instances, instance should be huggingface datasets row format
return instance_list
else:
raise NotImplementedError(
f'Currently .to_list is not supported for backend "{self.backend}"'
)
def map(self, *args, **kwargs):
r"""
Apply a mapping function to the backend dataset, a thin wrapper around
`datasets.Dataset.map` (Hugging Face backend only).
Parameters
------------
args : Optional.
Positional arguments.
kwargs : Optional.
Keyword arguments.
Returns
---------
self : Dataset object.
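
Examples
---------
A sketch using a simple per-instance function (the "text" column assumes a
"text_only" dataset):

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": "Hello."}],
... })
>>> dataset = dataset.map(lambda instance: {"text": instance["text"].lower()})
>>> dataset.to_dict()["instances"]
[{'text': 'hello.'}]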
"""
# If the dataset uses Hugging Face as the backend,
# call the `map()` function of the Hugging Face backend dataset
if self.backend == "huggingface":
# Set the mapped dataset as the backend dataset of the current dataset
mapped_backend_dataset = self.backend_dataset.map(*args, **kwargs)
self.backend_dataset = mapped_backend_dataset
return self
else:
# If the backend is not Hugging Face, raise a NotImplementedError
raise NotImplementedError(
f'Currently .map is not supported for backend "{self.backend}"'
)
def get_backend(self) -> Optional[str]:
r"""
Returns
---------
self.backend
"""
return self.backend
def get_backend_dataset(self):
r"""
Returns
---------
self.backend_dataset
"""
return self.backend_dataset
def get_fingerprint(self):
r"""
Returns
---------
Fingerprint of the backend_dataset which controls the cache
"""
return self.backend_dataset._fingerprint
def get_data_args(self):
r"""
Returns
---------
self.data_args
"""
return self.data_args
def get_type(self) -> str:
r"""
Returns
---------
self.type
"""
return self.type
def save(
self,
file_path: str,
format: str="json"
):
r"""
Save the dataset to a json file.
Parameters
------------
file_path : str.
The path to the file where the dataset will be saved.
format : str, default="json".
The format in which to save the dataset; only "json" is currently supported.
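
Examples
---------
A sketch; "my_dataset.json" is a hypothetical output path:

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": "Hello."}],
... })
>>> dataset.save("my_dataset.json")  # doctest: +SKIP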
"""
if format == "json":
assert Path(file_path).suffix == ".json", "The file path must have a .json extension."
with open(file_path, "w", encoding='utf-8') as fout:
json.dump(self.to_dict(), fout, indent=4, ensure_ascii=False)
else:
logger.error(f"Unsupported format when saving the dataset: {format}.")
def sample(self, n: int, seed: int=42):
r"""
Sample n instances from the dataset.
Parameters
------------
n : int.
The number of instances to sample from the dataset.
seed : int, default=42.
The random seed used to shuffle the dataset before sampling.
Returns
---------
sample_dataset : Dataset object.
A new dataset object containing the sampled instances.
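
Examples
---------
A sketch with made-up instances:

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": "a"}, {"text": "b"}, {"text": "c"}],
... })
>>> len(dataset.sample(n=2, seed=0))
2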
"""
if self.backend == "huggingface":
sampled_dataset = self.backend_dataset.shuffle(seed=seed).select(range(n))
output_dataset = self.create_from_dict(
{
"type": self.get_type(),
"instances": [
{
col_name: sampled_dataset[col_name][i] for col_name in sampled_dataset.column_names
} for i in range(n)
]
}
)
return output_dataset
else:
raise NotImplementedError(
f'Currently .sample is not supported for backend "{self.backend}"'
)
def train_test_split(self, test_size: float=0.2, shuffle: bool=True, seed: int=42):
r"""
Split the dataset into training and testing sets.
Parameters
------------
test_size : float, default=0.2.
The proportion of the dataset that will be used for testing.
shuffle : bool, default=True.
Whether to shuffle the dataset before splitting.
seed : int, default=42.
The random seed used for shuffling.
Returns
---------
train_dataset : Dataset object.
A new dataset object containing the training instances.
test_dataset : Dataset object.
A new dataset object containing the testing instances.
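
Examples
---------
A sketch with made-up instances:

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": str(i)} for i in range(10)],
... })
>>> train, test = dataset.train_test_split(test_size=0.2, seed=0)
>>> len(train), len(test)
(8, 2)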
"""
if self.backend == "huggingface":
splits = self.backend_dataset.train_test_split(
test_size=test_size, shuffle=shuffle, seed=seed
)
train_dataset = self.create_from_dict(
{
"type": self.get_type(),
"instances": [
{
col_name: splited["train"][col_name][i] for col_name in splited["train"].column_names
} for i in range(len(splited["train"]))
]
}
)
test_dataset = self.create_from_dict(
{
"type": self.get_type(),
"instances": [
{
col_name: splited["test"][col_name][i] for col_name in splited["test"].column_names
} for i in range(len(splited["test"]))
]
}
)
return train_dataset, test_dataset
else:
raise NotImplementedError(
f'Currently .train_test_split is not supported for backend "{self.backend}"'
)
def drop_instances(self, indices: list):
r"""
Drop instances from the dataset.
Parameters
------------
indices : list.
A list of indices of the instances to drop from the dataset.
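
Examples
---------
A sketch with made-up instances:

>>> dataset = Dataset.create_from_dict({
...     "type": "text_only",
...     "instances": [{"text": "a"}, {"text": "b"}, {"text": "c"}],
... })
>>> dataset.drop_instances([1])
>>> len(dataset)
2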
"""
if self.backend == "huggingface":
# `datasets.Dataset` has no remove_indices; keep everything except the given indices via select()
drop_set = set(indices)
keep_indices = [i for i in range(len(self.backend_dataset)) if i not in drop_set]
self.backend_dataset = self.backend_dataset.select(keep_indices)
else:
raise NotImplementedError(
f'Currently .drop_instances is not supported for backend "{self.backend}"'
)
def sanity_check(
self,
drop_invalid: bool=True,
):
r"""
Perform a sanity check on the dataset.
Parameters
------------
drop_invalid : bool, default=True.
Whether to drop invalid instances instead of raising an error.
"""
if self.backend == "huggingface":
self.hf_dataset_sanity_check(drop_invalid)
else:
raise NotImplementedError(
f'Currently .sanity_check is not supported for backend "{self.backend}"'
)
def hf_dataset_sanity_check(
self,
drop_invalid: bool=True,
):
r"""
Perform a sanity check on the HuggingFace dataset.
"""
if self.backend_dataset is None or len(self.backend_dataset) == 0:
raise ValueError("Dataset is empty.")
if self.type == 'text_to_textlist':
num_output_per_instance = len(self.backend_dataset['output'][0])
# Chain the filters so that all three checks are applied
dataset_cache = self.backend_dataset.filter(lambda x: len(x['input']) != 0)
dataset_cache = dataset_cache.filter(lambda x: len(x['output']) == num_output_per_instance)
dataset_cache = dataset_cache.filter(lambda x: not all(len(output) == 0 for output in x['output']))
if len(dataset_cache) != len(self.backend_dataset):
warning_info = (
f"Found {len(self.backend_dataset) - len(dataset_cache)} invalid instances "
"during hf_dataset_sanity_check, please check:\n"
" 1. length of input strings should not be empty\n"
" 2. length of output strings should not be all empty\n"
" 3. number of output strings should be consistent\n" # since we will use tensor reshape later
)
if drop_invalid:
self.backend_dataset = dataset_cache
logger.warning(warning_info+"Invalid instances are dropped.")
else:
raise ValueError(warning_info)
else:
logger.warning(f"No sanity check for {self.type} dataset.")