"""
ref: https://github.com/volcengine/verl/blob/main/verl/protocol.py
Implements the base data transfer protocol between any two functions or modules.
We can subclass Protocol to define more detailed batch info with specific keys.
"""
import contextlib
import copy
import logging
import math
import pickle
from dataclasses import dataclass, field
from typing import Any, Optional
import numpy as np
import tensordict
import torch
from packaging import version
from packaging.version import parse as parse_version
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData, NonTensorStack
from torch.utils.data import DataLoader
from lmflow.utils.envs import get_torch_device
logger = logging.getLogger(__name__)
with contextlib.suppress(Exception):
tensordict.set_lazy_legacy(False).set()
if parse_version(tensordict.__version__) < parse_version("0.10.0"):
tensordict.set_list_to_stack(True).set()
def union_python_dict(dict1: dict, dict2: dict):
"""Union two dict. Will throw an error if there is an item not the same object with the same key.
Args:
dict1:
dict2:
Returns:
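dict1, updated in place with the items from dict2.
Example (illustrative):
>>> union_python_dict({"a": 1}, {"b": 2})
{'a': 1, 'b': 2}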
"""
for key, val in dict2.items():
if key in dict1:
assert dict2[key] == dict1[key], f"{key} in dict1 and dict2 are not equal"
dict1[key] = val
return dict1
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
"""Union two tensordicts."""
assert tensor_dict1.batch_size == tensor_dict2.batch_size, (
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
)
for key in tensor_dict2.keys():
if key not in tensor_dict1.keys():
tensor_dict1[key] = tensor_dict2[key]
else:
assert tensor_dict1[key].equal(tensor_dict2[key]), (
f"{key} in tensor_dict1 and tensor_dict2 are not the same object"
)
return tensor_dict1
def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool:
"""
Recursively compares two NumPy arrays for strict equality, with special
handling for object-dtype arrays, NaN values, and circular references.
This function assumes that the two arguments provided are NumPy arrays.
Args:
array1: The first NumPy array.
array2: The second NumPy array.
visited: Set of object ids already on the comparison path, used to detect circular references.
Returns:
True if the arrays' dtypes, shapes, and all elements are equal.
"""
# Check dtype and shape first, as this is the fastest failure path.
if array1.dtype != array2.dtype or array1.shape != array2.shape:
return False
# For non-object dtypes, use NumPy's implementation with equal_nan=True.
if array1.dtype != "object":
return np.array_equal(array1, array2, equal_nan=True)
# For object-dtype arrays, we must recursively compare each element.
# We delegate to _deep_equal to handle elements, as they could be any
# type, including other nested arrays or NaNs.
return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False))
def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:
"""
Recursively performs a deep comparison between two Python objects.
- Handles NaN values correctly (NaN == NaN evaluates to True).
- Handles circular references.
- Dispatches to _array_equal if both objects are NumPy arrays.
- Otherwise, uses standard '==' comparison.
"""
if type(a) is not type(b):
return False
# If we have seen this object ID before on this path, it's a cycle.
# Since we already know the types match, we can safely assume this part
# of the structure is equal.
obj_id = id(a)
if obj_id in visited:
return True
visited.add(obj_id)
# Perform the specific comparison based on type
result = False
if isinstance(a, float) and math.isnan(a) and math.isnan(b):
result = True
elif isinstance(a, np.ndarray):
# We know b is also an ndarray due to the initial type check
result = _array_equal(a, b, visited)
else:
# Standard equality for all other types
result = a == b
# Clean up the visited set on the way out of the recursion
visited.remove(obj_id)
return result
def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
for key, val in tensor_dict2.items():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
assert isinstance(tensor_dict1[key], np.ndarray)
# to properly deal with nan and object type
assert _deep_equal(tensor_dict1[key], tensor_dict2[key], visited=set()), (
f"`{key}` in tensor_dict1 and tensor_dict2 are not the same object."
)
tensor_dict1[key] = val
return tensor_dict1
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
if len(list_of_dict) == 0:
return {}
keys = list_of_dict[0].keys()
output = {key: [] for key in keys}
for data in list_of_dict:
for key, item in data.items():
assert key in output
output[key].append(item)
return output
def collate_fn(x: list["DataProtoItem"]):
batch = []
non_tensor_batch = []
for data in x:
batch.append(data.batch)
non_tensor_batch.append(data.non_tensor_batch)
batch = torch.stack(batch).contiguous()
non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch)
for key, val in non_tensor_batch.items():
non_tensor_batch[key] = np.array(val, dtype=object)
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
"""Create a TensorDict from tensors and non-tensor data.
Automatically handles nested structures in lists by converting them to NonTensorStack.
This enables support for:
- Lists of lists: [[], [0.5, 0.8], [0.9]]
- Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
- Lists of lists of dicts: [[{"content": "...", "role": "user"}]]
Args:
tensor_dict: Dictionary of tensors and lists to include in the TensorDict
non_tensor_dict: Dictionary of metadata to store as NonTensorData
Returns:
TensorDict with proper handling of nested structures
Example:
>>> td = get_tensordict(
... tensor_dict={
... "obs": torch.randn(3, 4),
... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list
... },
... non_tensor_dict={"experiment": "test"}
... )
"""
tensor_dict = tensor_dict.copy()
if non_tensor_dict is None:
non_tensor_dict = {}
batch_size = None
for key, val in tensor_dict.items():
if isinstance(val, torch.Tensor) and val.is_nested:
assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged"
assert val.layout == torch.jagged, "Nested tensors must be jagged."
# Skip validation for NonTensorStack as it's already properly formatted
if isinstance(val, NonTensorStack):
if batch_size is None:
batch_size = len(val)
else:
assert len(val) == batch_size, (
f"Batch size of NonTensorStack {key} is not consistent with other tensors. "
f"Expected {batch_size}, got {len(val)}"
)
continue
if isinstance(val, list):
for v in val:
assert not isinstance(v, torch.Tensor), (
"Passing a list makes the data NonTensorStack, "
"which doesn't support torch.Tensor. Please convert to numpy first"
)
# Convert to NonTensorStack to handle nested structures
tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
assert isinstance(val, torch.Tensor | list)
if batch_size is None:
batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
else:
val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
assert val_batch_size == batch_size, (
f"Batch size of tensor {key} is not consistent with other tensors. "
f"Expected {batch_size}, got {val_batch_size}"
)
if batch_size is None:
batch_size = []
else:
batch_size = [batch_size]
for key, val in non_tensor_dict.items():
assert key not in tensor_dict
tensor_dict[key] = NonTensorData(val)
return TensorDict(source=tensor_dict, batch_size=batch_size)
@dataclass
class DataProtoItem:
batch: TensorDict = None
non_tensor_batch: dict = field(default_factory=dict)
meta_info: dict = field(default_factory=dict)
@dataclass
class DataProto:
"""
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
same batch size should be put inside batch.
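Example (illustrative):
>>> data = DataProto.from_dict(tensors={"obs": torch.zeros(4, 8)})
>>> len(data)
4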
"""
batch: TensorDict = None
non_tensor_batch: dict = field(default_factory=dict)
meta_info: dict = field(default_factory=dict)
def __post_init__(self):
# perform necessary checking
self.check_consistency()
def __len__(self):
if self.batch is not None:
return self.batch.batch_size[0]
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
random_key = list(self.non_tensor_batch.keys())[0]
return self.non_tensor_batch[random_key].shape[0]
else:
return 0
def __getitem__(self, item):
"""
Enhanced indexing for DataProto objects.
Args:
item: Can be one of:
- int: A single index
- slice: A slice object (start:stop:step)
- list: A list of indices
- numpy.ndarray: An array of indices
- torch.Tensor: A tensor of indices
Returns:
DataProto: For all indexing types except single integers
DataProtoItem: Only for single integer indices
"""
# Case 1: Slice object - use the slice method
if isinstance(item, slice):
return self.slice(item.start, item.stop, item.step)
# Case 2: List, numpy array, or torch tensor - use sel_idxs
elif isinstance(item, list | np.ndarray | torch.Tensor):
return self.select_idxs(item)
# Case 3: Single integer - return DataProtoItem for backward compatibility
elif isinstance(item, int | np.integer):
tensor_data = self.batch[item] if self.batch is not None else None
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
# Case 4: Unsupported type
else:
raise TypeError(f"Indexing with {type(item)} is not supported")
def __getstate__(self):
import io
buffer = io.BytesIO()
if parse_version(tensordict.__version__) >= parse_version("0.5.0") and self.batch is not None:
self.batch = self.batch.contiguous()
self.batch = self.batch.consolidate()
torch.save(self.batch, buffer)
return buffer, self.non_tensor_batch, self.meta_info
def __setstate__(self, data):
batch_deserialized, non_tensor_batch, meta_info = data
batch_deserialized.seek(0)
batch = torch.load(
batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None,
)
self.batch = batch
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info
def save_to_disk(self, filepath):
with open(filepath, "wb") as f:
pickle.dump(self, f)
@staticmethod
def load_from_disk(filepath) -> "DataProto":
with open(filepath, "rb") as f:
data = pickle.load(f)
return data
def print_size(self, prefix=""):
size_of_tensordict = 0
if self.batch is not None:
for _, tensor in self.batch.items():
size_of_tensordict += tensor.element_size() * tensor.numel()
size_of_numpy_array = 0
for _, numpy_array in self.non_tensor_batch.items():
size_of_numpy_array += numpy_array.nbytes
size_of_numpy_array /= 1024**3
size_of_tensordict /= 1024**3
message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB"
if prefix:
message = f"{prefix}, " + message
print(message)
def check_consistency(self):
"""Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
We expose this as a public method so that users can call it directly.
"""
if self.batch is not None:
assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1"
if self.non_tensor_batch is not None:
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray)
if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0:
# TODO: we can actually lift this restriction if needed
assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."
batch_size = self.batch.batch_size[0]
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray), (
f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for "
f"{key=}, got {type(val)=}"
)
assert val.shape[0] == batch_size, (
f"key {key} length {len(val)} is not equal to batch size {batch_size}"
)
@classmethod
def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None):
"""Create a DataProto from a dict of tensors and non_tensors"""
tensors = {}
non_tensors = {}
for key, val in data.items():
if isinstance(val, torch.Tensor):
tensors[key] = val
elif isinstance(val, np.ndarray):
non_tensors[key] = val
else:
raise ValueError(f"Unsupported type in data {type(val)}")
return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
@classmethod
def from_dict(
cls,
tensors: Optional[dict[str, torch.Tensor]] = None,
non_tensors=None,
meta_info=None,
num_batch_dims=1,
):
"""Create a DataProto from a dict of tensors. This assumes that
1. All the tensors in ``tensors`` have the same dim0
2. Only dim0 is the batch dim
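Example (illustrative):
>>> dp = DataProto.from_dict(
...     tensors={"obs": torch.ones(2, 3)},
...     non_tensors={"uid": ["a", "b"]},
... )
>>> dp.non_tensor_batch["uid"].dtype
dtype('O')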
"""
assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
if non_tensors is not None:
assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None."
if tensors is None:
tensors = {}
if meta_info is None:
meta_info = {}
if non_tensors is None:
non_tensors = {}
assert isinstance(non_tensors, dict)
# get and check batch size
batch_size = None
pivot_key = None
for key, tensor in tensors.items():
if batch_size is None:
batch_size = tensor.shape[:num_batch_dims]
pivot_key = key
else:
current_batch = tensor.shape[:num_batch_dims]
assert batch_size == current_batch, (
f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
)
for key, val in non_tensors.items():
if not isinstance(val, np.ndarray):
non_tensors[key] = np.array(val, dtype=object)
tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
@classmethod
def from_tensordict(
cls,
tensor_dict: TensorDict = None,
meta_info=None,
num_batch_dims=1,
):
"""Create a DataProto from a TensorDict. This assumes that
1. All the tensors in ``tensor_dict`` have the same dim0
2. Only dim0 is the batch dim
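Example (illustrative; requires tensordict >= 0.10):
>>> td = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])
>>> dp = DataProto.from_tensordict(td)
>>> len(dp)
2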
"""
assert version.parse(tensordict.__version__) >= version.parse("0.10.0"), (
"Build DataProto from TensorDict at least requires tensordict version 0.10.0"
)
from tensordict import NonTensorData, NonTensorStack
assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
if not all(isinstance(val, torch.Tensor) for val in tensor_dict.values()):
assert num_batch_dims == 1, "only support num_batch_dims=1 when tensor_dict contains non tensor data."
if meta_info is None:
meta_info = {}
batch = {}
non_tensor_batch = {}
batch_size = None
for key, val in tensor_dict.items():
if isinstance(val, torch.Tensor):
batch[key] = val
if batch_size is None:
batch_size = val.shape[:num_batch_dims]
elif isinstance(val, NonTensorStack):
non_tensor_batch[key] = np.array([elem.data for elem in val], dtype=object)
elif isinstance(val, NonTensorData):
meta_info[key] = val.data
return cls(
batch=TensorDict(batch, batch_size=batch_size),
non_tensor_batch=non_tensor_batch,
meta_info=meta_info,
)
def to(self, device) -> "DataProto":
"""move the batch to device
Args:
device (torch.device, str): torch device
Returns:
DataProto: the current DataProto
"""
if self.batch is not None:
self.batch = self.batch.to(device)
return self
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto":
"""Select a subset of the DataProto via batch_keys and meta_info_keys
Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to select
non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch entries to select
meta_info_keys (list, optional): a list of keys indicating the meta info to select
deepcopy (bool): whether to deepcopy the selected non_tensor_batch and meta_info
Returns:
DataProto: the DataProto with the selected batch_keys and meta_info_keys
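Example (illustrative; the key names are hypothetical):
>>> subset = data.select(batch_keys=["input_ids"], meta_info_keys=["epoch"])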
"""
# TODO (zhangchi.usc1992) whether to copy
if batch_keys is not None:
batch_keys = tuple(batch_keys)
sub_batch = self.batch.select(*batch_keys)
else:
sub_batch = self.batch
if non_tensor_batch_keys is not None:
non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
else:
non_tensor_batch = self.non_tensor_batch
if deepcopy:
non_tensor_batch = copy.deepcopy(non_tensor_batch)
if meta_info_keys is not None:
sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
else:
sub_meta_info = self.meta_info
if deepcopy:
sub_meta_info = copy.deepcopy(sub_meta_info)
return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
def select_idxs(self, idxs):
"""
Select specific indices from the DataProto.
Args:
idxs (torch.Tensor or numpy.ndarray or list): Indices to select
Returns:
DataProto: A new DataProto containing only the selected indices
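Example (illustrative; assumes the DataProto holds at least two rows):
>>> first_two = data.select_idxs([0, 1])
>>> len(first_two)
2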
"""
if isinstance(idxs, list):
idxs = torch.tensor(idxs)
if idxs.dtype != torch.bool:
idxs = idxs.type(torch.int32)
if isinstance(idxs, np.ndarray):
idxs_np = idxs
idxs_torch = torch.from_numpy(idxs)
else: # torch.Tensor
idxs_torch = idxs
idxs_np = idxs.detach().cpu().numpy()
batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0]
if self.batch is not None:
# Use TensorDict's built-in indexing capabilities
selected_batch = TensorDict(
source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},
batch_size=(batch_size,),
device=self.batch.device,
)
else:
selected_batch = None
selected_non_tensor = {}
for key, val in self.non_tensor_batch.items():
selected_non_tensor[key] = val[idxs_np]
return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)
def slice(self, start=None, end=None, step=None):
"""
Slice the DataProto and return a new DataProto object.
This is an improved version of direct slicing which returns a DataProtoItem.
Args:
start (int, optional): Start index. Defaults to None (start from beginning).
end (int, optional): End index (exclusive). Defaults to None (go to end).
step (int, optional): Step size. Defaults to None (step=1).
Returns:
DataProto: A new DataProto containing the sliced data
Examples:
# Using the slice method directly
sliced_data = data_proto.slice(10, 20)
# Using enhanced indexing (returns DataProto)
sliced_data = data_proto[10:20]
sliced_data = data_proto[::2] # Every other element
# Using list indexing (returns DataProto)
indices = [1, 5, 10]
selected_data = data_proto[indices]
# Single index still returns DataProtoItem
single_item = data_proto[5]
"""
# Create a slice object
slice_obj = slice(start, end, step)
# Handle the batch data
if self.batch is not None:
# Use TensorDict's built-in slicing capabilities
sliced_batch = self.batch[slice_obj]
else:
sliced_batch = None
# Handle the non-tensor batch data
sliced_non_tensor = {}
for key, val in self.non_tensor_batch.items():
sliced_non_tensor[key] = val[slice_obj]
# Return a new DataProto object
return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)
def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to pop
non_tensor_batch_keys (list, optional): a list of keys indicating the non_tensor_batch entries to pop
meta_info_keys (list, optional): a list of keys indicating the meta info to pop
Returns:
DataProto: the DataProto with the popped batch_keys and meta_info_keys
"""
if batch_keys is None:
batch_keys = []
if meta_info_keys is None:
meta_info_keys = []
if non_tensor_batch_keys is None:
non_tensor_batch_keys = []
tensors = {}
# tensor batch
for key in batch_keys:
assert key in self.batch.keys()
tensors[key] = self.batch.pop(key)
non_tensors = {}
# non tensor batch
for key in non_tensor_batch_keys:
assert key in self.non_tensor_batch.keys()
non_tensors[key] = self.non_tensor_batch.pop(key)
meta_info = {}
for key in meta_info_keys:
assert key in self.meta_info.keys()
meta_info[key] = self.meta_info.pop(key)
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
def rename(self, old_keys=None, new_keys=None) -> "DataProto":
"""
Note that this function only renames keys in the batch
"""
def validate_input(keys):
if keys is not None:
if isinstance(keys, str):
keys = [keys]
elif isinstance(keys, list):
pass
else:
raise TypeError(f"keys must be a list or a string, but got {type(keys)}")
return keys
old_keys = validate_input(old_keys)
new_keys = validate_input(new_keys)
if len(new_keys) != len(old_keys):
raise ValueError(
f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}"
)
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
return self
def union(self, other: "DataProto") -> "DataProto":
"""Union with another DataProto. Union batch and meta_info separately.
Throw an error if
- there are conflicting keys in batch and they are not equal
- the batch sizes of the two DataProtos are not the same
- there are conflicting keys in meta_info and they are not the same.
Args:
other (DataProto): another DataProto to union
Returns:
DataProto: the DataProto after union
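Example (illustrative; data1 and data2 are hypothetical DataProtos with the same batch size):
>>> merged = data1.union(data2)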
"""
self.batch = union_tensor_dict(self.batch, other.batch)
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
self.meta_info = union_python_dict(self.meta_info, other.meta_info)
return self
def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
Args:
mini_batch_size (int): mini-batch size when iterating the dataset. We require that
``batch.batch_size[0] % mini_batch_size == 0``.
epochs (int): number of epochs when iterating the dataset.
seed (int, optional): random seed used to create the generator passed to the DataLoader.
dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The
dataloader_kwargs is the kwargs passed to the DataLoader.
Returns:
Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration
steps is ``self.batch.batch_size * epochs // mini_batch_size``
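Example (illustrative; assumes len(data) is divisible by 2):
>>> for mini_batch in data.make_iterator(mini_batch_size=2, epochs=1):
...     pass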
"""
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
# we can directly create a dataloader from TensorDict
if dataloader_kwargs is None:
dataloader_kwargs = {}
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = None
assert isinstance(dataloader_kwargs, dict)
train_dataloader = DataLoader(
dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
)
def get_data():
for _ in range(epochs):
for d in train_dataloader:
d.meta_info = self.meta_info
yield d
return iter(get_data())
def padding(self, padding_size, padding_candidate=""):
"""Pad the DataProto by concating with padding_candidate.repeat(padding_size)
Args:
padding_size (int): the number of repeated padding_candidate
padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"]
"""
if padding_size == 0:
return
padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1])
padding_part = padding_candidate.repeat(padding_size)
padded_dp = DataProto.concat([self, padding_part])
self.batch = padded_dp.batch
self.non_tensor_batch = padded_dp.non_tensor_batch
def chunk(self, chunks: int) -> list["DataProto"]:
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
Args:
chunks (int): the number of chunks to split on dim=0
Returns:
List[DataProto]: a list of DataProto after splitting
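Example (illustrative; assumes len(data) == 8):
>>> parts = data.chunk(4)
>>> len(parts)
4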
"""
if not self.is_padding_enabled():
assert len(self) % chunks == 0, (
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
)
bsz_in_batch = None
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst])
chunk_indices = np.cumsum(bsz_in_batch)[:-1]
else:
batch_lst = [None for _ in range(chunks)]
non_tensor_batch_lst = [{} for _ in range(chunks)]
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray)
if bsz_in_batch is not None:
non_tensor_lst = np.array_split(val, chunk_indices.tolist())
else:
non_tensor_lst = np.array_split(val, chunks)
assert len(non_tensor_lst) == chunks
for i in range(chunks):
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
output = []
for i in range(chunks):
output.append(
type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
)
return output
def split(self, split_size: int) -> list["DataProto"]:
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
Args:
split_size (int): the size of each split
Returns:
List[DataProto]: a list of DataProto after splitting
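Example (illustrative; assumes len(data) == 6):
>>> pieces = data.split(2)
>>> [len(p) for p in pieces]
[2, 2, 2]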
"""
return [self[i : i + split_size] for i in range(0, len(self), split_size)]
@staticmethod
def concat(data: list["DataProto"]) -> "DataProto":
"""Concat a list of DataProto. The batch is concatenated among dim=0.
The meta_info is merged, with special handling for metrics from different workers.
Args:
data (List[DataProto]): list of DataProto
Returns:
DataProto: concatenated DataProto
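Example (illustrative; dp1 and dp2 are hypothetical DataProtos with identical keys):
>>> merged = DataProto.concat([dp1, dp2])
>>> len(merged) == len(dp1) + len(dp2)
True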
"""
batch_lst = []
for batch in data:
batch_lst.append(batch.batch)
new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None
non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
for key, val in non_tensor_batch.items():
non_tensor_batch[key] = np.concatenate(val, axis=0)
# Merge meta_info with special handling for metrics
merged_meta_info = {}
if data:
# Merge non-metric meta_info and aggregate metrics from all workers.
all_metrics = []
for d in data:
for k, v in d.meta_info.items():
if k == "metrics":
if v is not None:
if isinstance(v, list):
all_metrics.extend(v)
else:
all_metrics.append(v)
else:
if k in merged_meta_info:
# Ensure consistency for overlapping non-metric keys
assert merged_meta_info[k] == v, f"Conflicting values for meta_info key '{k}'"
else:
merged_meta_info[k] = v
# Flatten list of dicts to dict of lists for consistent metrics structure
if all_metrics:
merged_meta_info["metrics"] = list_of_dict_to_dict_of_list(all_metrics)
cls = type(data[0]) if len(data) > 0 else DataProto
return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=merged_meta_info)
def reorder(self, indices):
"""
Note that this operation is in-place
"""
indices_np = indices.detach().numpy()
self.batch = self.batch[indices]
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
def repeat(self, repeat_times=2, interleave=True):
"""
Repeat the batch data a specified number of times.
Args:
repeat_times (int): Number of times to repeat the data.
interleave (bool): Whether to interleave the repeated data.
Returns:
DataProto: A new DataProto with repeated data.
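Example (illustrative; rows denoted by their index):
>>> # interleave=True : [0, 1] -> [0, 0, 1, 1]
>>> # interleave=False: [0, 1] -> [0, 1, 0, 1]
>>> doubled = data.repeat(repeat_times=2, interleave=True)
>>> len(doubled) == 2 * len(data)
True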
"""
if self.batch is not None:
if interleave:
# Interleave the data
repeated_tensors = {
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
}
else:
# Stack the data
repeated_tensors = {
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
for key, tensor in self.batch.items()
}
repeated_batch = TensorDict(
source=repeated_tensors,
batch_size=(self.batch.batch_size[0] * repeat_times,),
)
else:
repeated_batch = None
repeated_non_tensor_batch = {}
for key, val in self.non_tensor_batch.items():
if interleave:
repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
else:
repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))
return type(self)(
batch=repeated_batch,
non_tensor_batch=repeated_non_tensor_batch,
meta_info=self.meta_info,
)
def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):
"""Split along the second dim into `n_split`, unfold it to the first dim (batch dim)
Useful for passing grouped tensors that should not be shuffled within the dataset.
Keys not in `split_keys` are repeated along the batch dim to match the shape.
Note that if the `split_keys` is not provided, it will repeat all the keys in the second dim.
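Example (illustrative; "x" is a hypothetical key holding a tensor of shape (2, 4)):
>>> out = data.unfold_column_chunks(n_split=2, split_keys=["x"])
>>> out.batch["x"].shape
torch.Size([4, 2])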
"""
if self.batch is not None:
unfolded_batch = {}
for key in self.batch.keys():
if split_keys is not None and key in split_keys:
shape = list(self.batch[key].shape)
shape[0] = self.batch[key].shape[0] * n_split
shape[1] = self.batch[key].shape[1] // n_split
unfolded_batch[key] = self.batch[key].reshape(*shape)
else:
unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)
# locate the `unfolded_batch` as a TensorDict on the same device as the original batch
unfolded_batch = TensorDict(
source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device
)
else:
unfolded_batch = None
repeated_non_tensor_batch = {}
for key, val in self.non_tensor_batch.items():
if split_keys is not None and key in split_keys:
shape = list(val.shape)
shape[0] = val.shape[0] * n_split
shape[1] = val.shape[1] // n_split
repeated_non_tensor_batch[key] = val.reshape(*shape)
else:
repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)
return type(self)(
batch=unfolded_batch,
non_tensor_batch=repeated_non_tensor_batch,
meta_info=self.meta_info,
)
def sample_level_repeat(self, repeat_times):
"""
Repeat each row of the batch data a specified number of times.
Args:
repeat_times (torch.tensor, list, tuple, ndarray): Number of times to repeat the data.
Returns:
DataProto: A new DataProto with repeated data.
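Example (illustrative; rows denoted by their index, assumes len(data) == 2):
>>> # repeat_times=[1, 2] : [0, 1] -> [0, 1, 1]
>>> tripled = data.sample_level_repeat([1, 2])
>>> len(tripled)
3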
"""
if isinstance(repeat_times, tuple):
repeat_times = list(repeat_times)
elif isinstance(repeat_times, torch.Tensor):
assert len(repeat_times.shape) == 1
repeat_times = repeat_times.tolist()
elif isinstance(repeat_times, np.ndarray):
assert len(repeat_times.shape) == 1
repeat_times = repeat_times.tolist()
else:
assert isinstance(repeat_times, list), (
f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}"
)
repeat_times = torch.tensor(repeat_times)
if self.batch is not None:
# Interleave the data
repeated_tensors = {
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
}
repeated_batch = TensorDict(
source=repeated_tensors,
batch_size=(repeat_times.sum().item(),),
device=self.batch.device,
)
else:
repeated_batch = None
repeated_non_tensor_batch = {}
for key, val in self.non_tensor_batch.items():
repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
return type(self)(
batch=repeated_batch,
non_tensor_batch=repeated_non_tensor_batch,
meta_info=self.meta_info,
)
def to_tensordict(self) -> TensorDict:
"""Convert this DataProto to TensorDict. Note that this requires tensordict version at least 0.10
Returns:
TensorDict: the batch tensors plus non_tensor_batch entries (as NonTensorStack) and meta_info (as NonTensorData).
"""
assert parse_version(tensordict.__version__) >= parse_version("0.10"), (
"Convert DataProto to TensorDict at least requires tensordict version 0.10"
)
tensor_batch = self.batch.to_dict()
non_tensor_batch = self.non_tensor_batch
from tensordict.tensorclass import NonTensorData, NonTensorStack
common_keys = set(tensor_batch.keys()) & set(non_tensor_batch.keys())
assert len(common_keys) == 0, f"tensor_batch and non_tensor_batch have common keys {common_keys}"
for key, val in non_tensor_batch.items():
assert isinstance(val, np.ndarray)
# Convert to NonTensorStack instead of plain list to handle nested structures
tensor_batch[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])
output = get_tensordict(tensor_dict=tensor_batch, non_tensor_dict=self.meta_info)
return output
def get_data_info(self) -> str:
"""Return formatted information about stored data with nested type details.
Returns:
str: Formatted string showing tensor details and recursive metadata types
"""
info = ["batch"]
for key, tensor in self.batch.items():
if hasattr(tensor, "shape") and hasattr(tensor, "dtype") and hasattr(tensor, "device"):
info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype}) {tensor.device}")
elif hasattr(tensor, "shape") and hasattr(tensor, "dtype"):
info.append(f" {key}: {tuple(tensor.shape)} ({tensor.dtype})")
else:
info.append(f" {key}: {type(tensor).__name__}")
info.append("non_tensor_batch")
for key, array in self.non_tensor_batch.items():
info.append(f" {key}: ndarray{array.shape} ({array.dtype})")
info.append("meta_info")
for k, v in self.meta_info.items():
type_info = self._get_type_info(v)
info.append(f" {k}: {type_info}")
return "\n".join(info)
def _get_type_info(self, value):
"""Recursively get type information for nested structures"""
if isinstance(value, list):
elem_types = {self._get_type_info(v) for v in value[:3]}
return f"list[{'|'.join(elem_types) if elem_types else '...'}]"
if isinstance(value, tuple):
elem_types = [self._get_type_info(v) for v in value]
return f"tuple({', '.join(elem_types)})"
if isinstance(value, dict):
if not value:
return "dict"
k, v = next(iter(value.items()))
return f"dict[{self._get_type_info(k)}: {self._get_type_info(v)}]"
if isinstance(value, np.ndarray):
return f"ndarray{value.shape} ({value.dtype})"
return type(value).__name__