lmflow.utils.protocol#

Reference: volcengine/verl. Implements the base data transfer protocol between any two functions or modules. Protocol can be subclassed to define more detailed batch info with specific keys.

Attributes#

logger

Classes#

DataProtoItem

A single item of a DataProto, holding a batch (TensorDict), a non_tensor_batch (dict), and a meta_info (dict).

DataProto

A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.

Functions#

union_python_dict(dict1, dict2)

Union two dicts. Raises an error if the two dicts share a key whose values are not the same object.

union_tensor_dict(→ tensordict.TensorDict)

Union two tensordicts.

_array_equal(→ bool)

Recursively compares two NumPy arrays for strict equality, with special handling for object-dtype arrays, NaN values, and circular references.

_deep_equal(→ bool)

Recursively performs a deep comparison between two Python objects.

union_numpy_dict(→ dict[str, numpy.ndarray])

list_of_dict_to_dict_of_list(list_of_dict)

collate_fn(x)

get_tensordict(→ tensordict.TensorDict)

Create a TensorDict from tensors and non-tensor data.

Module Contents#

lmflow.utils.protocol.logger[source]#
lmflow.utils.protocol.union_python_dict(dict1: dict, dict2: dict)[source]#

Union two dicts. Raises an error if the two dicts share a key whose values are not the same object.

Args:

dict1: The first dict.
dict2: The second dict.

Returns:

The union of the two dicts.
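
For illustration, a minimal usage sketch (assuming the merged dict is returned; the exact exception type raised on conflict is unspecified):

>>> from lmflow.utils.protocol import union_python_dict
>>> d1 = {"optimizer": "adam", "lr": 1e-3}
>>> d2 = {"optimizer": "adam", "epochs": 3}
>>> merged = union_python_dict(d1, d2)  # ok: shared key "optimizer" matches
>>> # union_python_dict({"lr": 1e-3}, {"lr": 1e-4}) would raise on "lr"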

lmflow.utils.protocol.union_tensor_dict(tensor_dict1: tensordict.TensorDict, tensor_dict2: tensordict.TensorDict) tensordict.TensorDict[source]#

Union two tensordicts.
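
For illustration, a sketch of the expected behavior (keys and shapes are illustrative):

>>> import torch
>>> from tensordict import TensorDict
>>> from lmflow.utils.protocol import union_tensor_dict
>>> td1 = TensorDict({"obs": torch.zeros(4, 8)}, batch_size=[4])
>>> td2 = TensorDict({"act": torch.zeros(4, 2)}, batch_size=[4])
>>> merged = union_tensor_dict(td1, td2)  # merged holds both "obs" and "act"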

lmflow.utils.protocol._array_equal(array1: numpy.ndarray, array2: numpy.ndarray, visited: set[int]) bool[source]#

Recursively compares two NumPy arrays for strict equality, with special handling for object-dtype arrays, NaN values, and circular references. This function assumes that the two arguments provided are NumPy arrays.

Args:

array1: The first NumPy array.
array2: The second NumPy array.
visited: A set of object ids already visited, used to detect circular references.

Returns:

True if the arrays’ dtypes, shapes, and all elements are equal.

lmflow.utils.protocol._deep_equal(a: Any, b: Any, visited: set[int]) bool[source]#

Recursively performs a deep comparison between two Python objects.

  • Handles NaN values correctly (NaN == NaN evaluates to True).

  • Handles circular references.

  • Dispatches to _array_equal if both objects are NumPy arrays.

  • Otherwise, uses standard '==' comparison.
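
Although these comparison helpers are private, a short sketch clarifies the NaN semantics (outputs assume NaN values compare equal, as documented above; visited is the caller-supplied set used for cycle detection, empty at the top level):

>>> import numpy as np
>>> from lmflow.utils.protocol import _deep_equal
>>> _deep_equal(float("nan"), float("nan"), set())
True
>>> _deep_equal(np.array([1.0, np.nan]), np.array([1.0, np.nan]), set())
True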

lmflow.utils.protocol.union_numpy_dict(tensor_dict1: dict[str, numpy.ndarray], tensor_dict2: dict[str, numpy.ndarray]) dict[str, numpy.ndarray][source]#
lmflow.utils.protocol.list_of_dict_to_dict_of_list(list_of_dict: list[dict])[source]#
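
A sketch of the transposition this performs (assuming all dicts share the same keys):

>>> from lmflow.utils.protocol import list_of_dict_to_dict_of_list
>>> list_of_dict_to_dict_of_list([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
{'a': [1, 3], 'b': [2, 4]}
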
lmflow.utils.protocol.collate_fn(x: list[DataProtoItem])[source]#
lmflow.utils.protocol.get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) tensordict.TensorDict[source]#

Create a TensorDict from tensors and non-tensor data.

Automatically handles nested structures in lists by converting them to NonTensorStack. This enables support for:

  • Lists of lists: [[], [0.5, 0.8], [0.9]]

  • Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]

  • Lists of lists of dicts: [[{"content": "…", "role": "user"}]]

Args:

tensor_dict: Dictionary of tensors and lists to include in the TensorDict.
non_tensor_dict: Dictionary of metadata to store as NonTensorData.

Returns:

TensorDict with proper handling of nested structures

Example:
>>> td = get_tensordict(
...     tensor_dict={
...         "obs": torch.randn(3, 4),
...         "turn_scores": [[], [0.5, 0.8], [0.9]]  # Nested list
...     },
...     non_tensor_dict={"experiment": "test"}
... )
class lmflow.utils.protocol.DataProtoItem[source]#
batch: tensordict.TensorDict = None[source]#
non_tensor_batch: dict[source]#
meta_info: dict[source]#
class lmflow.utils.protocol.DataProto[source]#

A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the same batch size should be put inside batch.
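
For orientation, a minimal usage sketch (field names and values are illustrative; see from_single_dict below):

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_single_dict(
...     {"input_ids": torch.randint(0, 100, (4, 16))},
...     meta_info={"temperature": 1.0},
... )
>>> len(data)   # batch size along dim 0
4
>>> item = data[0]    # single integer index -> DataProtoItem
>>> sub = data[0:2]   # slice -> DataProto with batch size 2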

batch: tensordict.TensorDict = None[source]#
non_tensor_batch: dict[source]#
meta_info: dict[source]#
__post_init__()[source]#
__len__()[source]#
__getitem__(item)[source]#

Enhanced indexing for DataProto objects.

Args:
item: Can be one of:
  • int: A single index

  • slice: A slice object (start:stop:step)

  • list: A list of indices

  • numpy.ndarray: An array of indices

  • torch.Tensor: A tensor of indices

Returns:

DataProto: For all indexing types except single integers.
DataProtoItem: Only for single integer indices.

__getstate__()[source]#
__setstate__(data)[source]#
save_to_disk(filepath)[source]#
static load_from_disk(filepath) DataProto[source]#
print_size(prefix='')[source]#
check_consistency()[source]#

Check the consistency of the DataProto, mainly for batch and non_tensor_batch. This function is exposed publicly so that users can call it directly.

classmethod from_single_dict(data: dict[str, torch.Tensor | numpy.ndarray], meta_info=None)[source]#

Create a DataProto from a dict of tensors and non_tensors

classmethod from_dict(tensors: dict[str, torch.Tensor] | None = None, non_tensors=None, meta_info=None, num_batch_dims=1)[source]#

Create a DataProto from a dict of tensors. This assumes that:

  1. All the tensors in tensors have the same dim0.

  2. Only dim0 is the batch dim.
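
A sketch distinguishing tensor and non-tensor inputs (keys and shapes are illustrative):

>>> import numpy as np
>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_dict(
...     tensors={"obs": torch.zeros(8, 4)},
...     non_tensors={"uid": np.array([str(i) for i in range(8)], dtype=object)},
...     meta_info={"split": "train"},
... )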

classmethod from_tensordict(tensor_dict: tensordict.TensorDict = None, meta_info=None, num_batch_dims=1)[source]#

Create a DataProto from a TensorDict. This assumes that:

  1. All the tensors in tensor_dict have the same dim0.

  2. Only dim0 is the batch dim.

to(device) DataProto[source]#

Move the batch to the given device.

Args:

device (torch.device, str): torch device

Returns:

DataProto: the current DataProto

select(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) DataProto[source]#

Select a subset of the DataProto via batch_keys and meta_info_keys

Args:

batch_keys (list, optional): a list of strings indicating the keys in batch to select.
non_tensor_batch_keys (list, optional): a list of strings indicating the keys in non_tensor_batch to select.
meta_info_keys (list, optional): a list of keys indicating the meta info to select.

Returns:

DataProto: the DataProto with the selected batch_keys and meta_info_keys

select_idxs(idxs)[source]#

Select specific indices from the DataProto.

Args:

idxs (torch.Tensor or numpy.ndarray or list): Indices to select

Returns:

DataProto: A new DataProto containing only the selected indices

slice(start=None, end=None, step=None)[source]#

Slice the DataProto and return a new DataProto object. This is an improved version of direct slicing which returns a DataProtoItem.

Args:

start (int, optional): Start index. Defaults to None (start from beginning).
end (int, optional): End index (exclusive). Defaults to None (go to end).
step (int, optional): Step size. Defaults to None (step=1).

Returns:

DataProto: A new DataProto containing the sliced data

Examples:

# Using the slice method directly
sliced_data = data_proto.slice(10, 20)

# Using enhanced indexing (returns DataProto)
sliced_data = data_proto[10:20]
sliced_data = data_proto[::2]  # Every other element

# Using list indexing (returns DataProto)
indices = [1, 5, 10]
selected_data = data_proto[indices]

# Single index still returns DataProtoItem
single_item = data_proto[5]

pop(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) DataProto[source]#

Pop a subset of the DataProto via batch_keys and meta_info_keys

Args:

batch_keys (list, optional): a list of strings indicating the keys in batch to pop.
non_tensor_batch_keys (list, optional): a list of strings indicating the keys in non_tensor_batch to pop.
meta_info_keys (list, optional): a list of keys indicating the meta info to pop.

Returns:

DataProto: the DataProto with the popped batch_keys and meta_info_keys

rename(old_keys=None, new_keys=None) DataProto[source]#

Note that this function only renames keys in the batch.

union(other: DataProto) DataProto[source]#

Union with another DataProto. Union batch and meta_info separately. Throws an error if:

  • there are conflicting keys in batch and their values are not equal

  • the batch sizes of the two DataProtos are not the same

  • there are conflicting keys in meta_info and their values are not the same

Args:

other (DataProto): another DataProto to union

Returns:

DataProto: the DataProto after union
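
For illustration (keys are hypothetical):

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> left = DataProto.from_single_dict({"obs": torch.zeros(4, 8)})
>>> right = DataProto.from_single_dict({"act": torch.zeros(4, 2)})
>>> both = left.union(right)  # batch now holds both "obs" and "act"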

make_iterator(mini_batch_size, epochs, seed=None, dataloader_kwargs=None)[source]#

Make an iterator from the DataProto. This builds on the fact that a TensorDict can be used as a normal PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

Args:

mini_batch_size (int): mini-batch size when iterating the dataset. We require that batch.batch_size[0] % mini_batch_size == 0.
epochs (int): number of epochs when iterating the dataset.
seed (int, optional): random seed used when iterating.
dataloader_kwargs (Any): internally, this returns a DataLoader over the batch; dataloader_kwargs are the kwargs passed to that DataLoader.

Returns:

Iterator: an iterator that yields one mini-batch at a time. The total number of iteration steps is self.batch.batch_size[0] * epochs // mini_batch_size.
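
A sketch (batch size, mini_batch_size, and epochs are illustrative; mini_batch_size must divide the batch size, and each yielded mini-batch is assumed to be a DataProto):

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_single_dict({"obs": torch.zeros(8, 4)})
>>> for mini_batch in data.make_iterator(mini_batch_size=2, epochs=1):
...     assert len(mini_batch) == 2   # 8 // 2 = 4 iteration steps in total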

padding(padding_size, padding_candidate='')[source]#

Pad the DataProto by concatenating with padding_candidate.repeat(padding_size).

Args:

padding_size (int): the number of times padding_candidate is repeated.
padding_candidate: the item to be repeated and appended to the DataProto; only "first" and "last" are supported.

chunk(chunks: int) list[DataProto][source]#

Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after the split.

Args:

chunks (int): the number of chunks to split on dim=0

Returns:

List[DataProto]: a list of DataProto after splitting

split(split_size: int) list[DataProto][source]#

Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after the split.

Args:

split_size (int): the size of each split

Returns:

List[DataProto]: a list of DataProto after splitting
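
The two methods differ only in how the split is specified (sizes illustrative):

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_single_dict({"obs": torch.zeros(8, 4)})
>>> parts = data.chunk(chunks=4)       # 4 DataProtos, each of batch size 2
>>> parts = data.split(split_size=2)   # same partition, specified by size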

static concat(data: list[DataProto]) DataProto[source]#

Concat a list of DataProto. The batch is concatenated along dim=0. The meta_info is merged, with special handling for metrics from different workers.

Args:

data (List[DataProto]): list of DataProto

Returns:

DataProto: concatenated DataProto

reorder(indices)[source]#

Note that this operation is in-place

repeat(repeat_times=2, interleave=True)[source]#

Repeat the batch data a specified number of times.

Args:

repeat_times (int): Number of times to repeat the data.
interleave (bool): Whether to interleave the repeated data.

Returns:

DataProto: A new DataProto with repeated data.
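
The interleave flag controls the ordering of the repeats. A sketch, assuming the usual convention where interleave=True repeats each row consecutively:

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_single_dict({"obs": torch.arange(2).unsqueeze(-1)})
>>> doubled = data.repeat(repeat_times=2, interleave=True)
>>> # interleave=True  -> rows ordered [0, 0, 1, 1]
>>> # interleave=False -> rows ordered [0, 1, 0, 1]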

unfold_column_chunks(n_split: int, split_keys: list[str] | None = None)[source]#

Split along the second dim into n_split chunks and unfold them into the first (batch) dim. Useful for passing grouped tensors that should not be shuffled in the dataset. Keys not in split_keys are repeated to match the shape. Note that if split_keys is not provided, all keys are repeated in the second dim.

sample_level_repeat(repeat_times)[source]#

Repeat each row of the batch data a specified number of times.

Args:

repeat_times (torch.Tensor, list, tuple, or ndarray): Number of times to repeat each row of the data.

Returns:

DataProto: A new DataProto with repeated data.
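
Unlike repeat, each row gets its own repeat count. A sketch (the row ordering shown is an assumption):

>>> import torch
>>> from lmflow.utils.protocol import DataProto
>>> data = DataProto.from_single_dict({"obs": torch.arange(3).unsqueeze(-1)})
>>> out = data.sample_level_repeat([1, 2, 3])  # rows ordered [0, 1, 1, 2, 2, 2]
>>> len(out)
6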

to_tensordict() tensordict.TensorDict[source]#

Convert this DataProto to a TensorDict. Note that this requires tensordict version at least 0.10.

Returns:

tensordict.TensorDict: the converted TensorDict.

get_data_info() str[source]#

Return formatted information about stored data with nested type details.

Returns:

str: Formatted string showing tensor details and recursive metadata types

_get_type_info(value)[source]#

Recursively get type information for nested structures