lmflow.utils.protocol
Reference: volcengine/verl. Implements the base data transfer protocol between any two functions or modules. Protocol can be subclassed to define more detailed batch info with specific keys.
Attributes
Classes
DataProto | A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
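Functions
union_python_dict(dict1, dict2) | Union two dicts. Raises an error if a key appears in both dicts and the two values are not the same object.
union_tensor_dict(tensor_dict1, tensor_dict2) | Union two TensorDicts.
_array_equal(array1, array2, visited) | Recursively compares two NumPy arrays for strict equality, with special handling for object-dtype arrays, NaN values, and circular references.
_deep_equal(a, b, visited) | Recursively performs a deep comparison between two Python objects.
union_numpy_dict(tensor_dict1, tensor_dict2) |
collate_fn(x) |
get_tensordict(tensor_dict[, non_tensor_dict]) | Create a TensorDict from tensors and non-tensor data.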
Module Contents
- lmflow.utils.protocol.union_python_dict(dict1: dict, dict2: dict)
Union two dicts. Raises an error if a key appears in both dicts and the two values are not the same object.
- Args:
dict1: the first dict.
dict2: the second dict.
- Returns:
the union of dict1 and dict2.
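The conflict rule can be sketched in a few lines of Python. This is a hypothetical stand-in (`union_python_dict_sketch` is not part of LMFlow): it treats two values as conflicting when they compare unequal with `==`, and it returns a new dict instead of modifying dict1. Both choices are assumptions and may differ from the real function.

```python
def union_python_dict_sketch(dict1: dict, dict2: dict) -> dict:
    """Hypothetical sketch: union two dicts, rejecting keys whose values conflict."""
    merged = dict(dict1)  # shallow copy, so dict1 itself is not modified
    for key, val in dict2.items():
        # A key present in both dicts must map to equal values (assumed check: ==).
        if key in merged and merged[key] != val:
            raise ValueError(f"{key!r} has conflicting values: {merged[key]!r} vs {val!r}")
        merged[key] = val
    return merged


meta_a = {"temperature": 1.0, "global_steps": 3}  # illustrative meta_info-style keys
meta_b = {"global_steps": 3, "eos_token_id": 2}
print(union_python_dict_sketch(meta_a, meta_b))
# {'temperature': 1.0, 'global_steps': 3, 'eos_token_id': 2}
```

The shared key global_steps is accepted because both sides agree; passing {"k": 1} and {"k": 2} raises instead.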
- lmflow.utils.protocol.union_tensor_dict(tensor_dict1: tensordict.TensorDict, tensor_dict2: tensordict.TensorDict) -> tensordict.TensorDict
Union two TensorDicts.
- lmflow.utils.protocol._array_equal(array1: numpy.ndarray, array2: numpy.ndarray, visited: set[int]) -> bool
Recursively compares two NumPy arrays for strict equality, with special handling for object-dtype arrays, NaN values, and circular references. This function assumes that both arguments are NumPy arrays.
- Args:
array1: The first NumPy array.
array2: The second NumPy array.
visited: The set of object ids already visited during the recursion, used to detect circular references.
- Returns:
True if the arrays' dtypes, shapes, and all elements are equal.
- lmflow.utils.protocol._deep_equal(a: Any, b: Any, visited: set[int]) -> bool
Recursively performs a deep comparison between two Python objects:
- Handles NaN values correctly (NaN == NaN evaluates to True).
- Handles circular references.
- Dispatches to _array_equal if both objects are NumPy arrays.
- Otherwise, uses the standard '==' comparison.
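The comparison rules of _array_equal and _deep_equal can be sketched as follows. This is a hypothetical reimplementation written from the descriptions above, not LMFlow's code: the names `array_equal_sketch` and `deep_equal_sketch` are made up, and treating an object that is already on the recursion path as equal (to break cycles) is an assumption about how circular references are handled.

```python
import math
from typing import Any, Optional

import numpy as np


def array_equal_sketch(a1: np.ndarray, a2: np.ndarray, visited: set) -> bool:
    # Dtypes and shapes must match exactly.
    if a1.dtype != a2.dtype or a1.shape != a2.shape:
        return False
    if a1.dtype != object:
        # Non-object arrays: vectorized comparison; NaN == NaN only for float/complex.
        return bool(np.array_equal(a1, a2, equal_nan=a1.dtype.kind in "fc"))
    # Object arrays: compare element by element with the deep comparison.
    return all(deep_equal_sketch(x, y, visited) for x, y in zip(a1.flat, a2.flat))


def deep_equal_sketch(a: Any, b: Any, visited: Optional[set] = None) -> bool:
    if visited is None:
        visited = set()
    if type(a) is not type(b):
        return False
    obj_id = id(a)
    if obj_id in visited:
        # Already on the current recursion path: a cycle, so assume this part is equal.
        return True
    visited.add(obj_id)
    try:
        if isinstance(a, float) and math.isnan(a) and math.isnan(b):
            return True  # NaN == NaN
        if isinstance(a, np.ndarray):
            return array_equal_sketch(a, b, visited)
        return bool(a == b)  # everything else: plain ==
    finally:
        visited.discard(obj_id)  # leave the path on the way back up


print(deep_equal_sketch(float("nan"), float("nan")))  # True
print(deep_equal_sketch(np.array([1.0, np.nan]), np.array([1.0, np.nan])))  # True
print(deep_equal_sketch(np.array([1, 2]), np.array([1.0, 2.0])))  # False (dtypes differ)

cyclic = np.empty(1, dtype=object)
cyclic[0] = cyclic  # an object array that contains itself
print(deep_equal_sketch(cyclic, cyclic))  # True, without infinite recursion
```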
- lmflow.utils.protocol.union_numpy_dict(tensor_dict1: dict[str, numpy.ndarray], tensor_dict2: dict[str, numpy.ndarray]) -> dict[str, numpy.ndarray]
- lmflow.utils.protocol.collate_fn(x: list[DataProtoItem])
- lmflow.utils.protocol.get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> tensordict.TensorDict
Create a TensorDict from tensors and non-tensor data.
Automatically handles nested structures in lists by converting them to NonTensorStack. This enables support for:
- Lists of lists: [[], [0.5, 0.8], [0.9]]
- Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}]
- Lists of lists of dicts: [[{"content": "…", "role": "user"}]]
- Args:
tensor_dict: Dictionary of tensors and lists to include in the TensorDict.
non_tensor_dict: Dictionary of metadata to store as NonTensorData.
- Returns:
TensorDict with proper handling of nested structures.
- Example:
>>> import torch
>>> from lmflow.utils.protocol import get_tensordict
>>> td = get_tensordict(
...     tensor_dict={
...         "obs": torch.randn(3, 4),
...         "turn_scores": [[], [0.5, 0.8], [0.9]],  # Nested list
...     },
...     non_tensor_dict={"experiment": "test"},
... )
- class lmflow.utils.protocol.DataProto
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict (https://pytorch.org/tensordict/), which allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, tensors that share the same batch size should be put inside batch.
- __getitem__(item)
Enhanced indexing for DataProto objects.
- Args:
- item: Can be one of:
int: A single index
slice: A slice object (start:stop:step)
list: A list of indices
numpy.ndarray: An array of indices
torch.Tensor: A tensor of indices
- Returns:
DataProto: For all indexing types except single integers.
DataProtoItem: Only for single integer indices.
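The dispatch rule above (a single integer yields an item, every other index type yields a new batch) can be illustrated with a toy container. `ToyBatch` is hypothetical: it stores equally long NumPy arrays in a dict, whereas a real DataProto holds a TensorDict plus non-tensor data and meta_info, and it returns a plain dict where DataProto returns a DataProtoItem.

```python
import numpy as np


class ToyBatch:
    """Hypothetical stand-in for DataProto: a dict of NumPy arrays sharing dim 0."""

    def __init__(self, columns: dict):
        self.columns = columns

    def __len__(self) -> int:
        return len(next(iter(self.columns.values())))

    def __getitem__(self, item):
        if isinstance(item, (int, np.integer)):
            # Single integer index: return one row (DataProto returns a DataProtoItem here).
            return {key: value[item] for key, value in self.columns.items()}
        # Slices keep their start:stop:step; lists/arrays become integer index arrays.
        index = item if isinstance(item, slice) else np.asarray(item)
        return ToyBatch({key: value[index] for key, value in self.columns.items()})


batch = ToyBatch({"input_ids": np.arange(12).reshape(6, 2), "score": np.linspace(0.0, 1.0, 6)})
print(type(batch[1:4]).__name__, len(batch[1:4]))  # ToyBatch 3
print(len(batch[[0, 2, 5]]))  # 3
print(batch[2]["input_ids"])  # [4 5]
```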
- check_consistency()
Check the consistency of the DataProto, mainly of batch and non_tensor_batch. This function is public so that users can call it directly.
- classmethod from_single_dict(data: dict[str, torch.Tensor | numpy.ndarray], meta_info=None)
Create a DataProto from a single dict containing both tensors and non-tensor (NumPy) data.
- classmethod from_dict(tensors: dict[str, torch.Tensor] | None = None, non_tensors=None, meta_info=None, num_batch_dims=1)
Create a DataProto from a dict of tensors. This assumes that:
1. All the tensors in tensors have the same dim0.
2. Only dim0 is the batch dim.
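The first assumption (every entry shares the same dim0) amounts to a simple shape check. The helper below is hypothetical and uses NumPy arrays as stand-ins for torch tensors; generalizing it to compare the first num_batch_dims dimensions is an assumption based on the parameter name.

```python
import numpy as np


def leading_batch_shape_sketch(tensors: dict, num_batch_dims: int = 1) -> tuple:
    """Hypothetical check: all arrays must share the same leading (batch) dimensions."""
    if not tensors:
        raise ValueError("at least one array is required")
    batch_shape = None
    for key, value in tensors.items():
        leading = tuple(value.shape[:num_batch_dims])
        if batch_shape is None:
            batch_shape = leading  # the first entry defines the expected batch shape
        elif leading != batch_shape:
            raise ValueError(f"{key!r} has batch shape {leading}, expected {batch_shape}")
    return batch_shape


ok = {"input_ids": np.zeros((4, 16), dtype=np.int64), "attention_mask": np.ones((4, 16), dtype=np.int64)}
print(leading_batch_shape_sketch(ok))  # (4,)
```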
- classmethod from_tensordict(tensor_dict: tensordict.TensorDict = None, meta_info=None, num_batch_dims=1)
Create a DataProto from a TensorDict. This assumes that:
1. All the tensors in tensor_dict have the same dim0.
2. Only dim0 is the batch dim.
- to(device) -> DataProto
Move the batch to device.
- Args:
device (torch.device, str): torch device
- Returns:
DataProto: the current DataProto
- select(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> DataProto
Select a subset of the DataProto via batch_keys, non_tensor_batch_keys, and meta_info_keys.
- Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to select.
non_tensor_batch_keys (list, optional): a list of strings indicating the keys in non_tensor_batch to select.
meta_info_keys (list, optional): a list of keys indicating the meta info to select.
deepcopy (bool, optional): whether to deep-copy the selected data. Defaults to False.
- Returns:
DataProto: the DataProto with the selected batch_keys and meta_info_keys.
- select_idxs(idxs)
Select specific indices from the DataProto.
- Args:
idxs (torch.Tensor or numpy.ndarray or list): Indices to select
- Returns:
DataProto: A new DataProto containing only the selected indices
- slice(start=None, end=None, step=None)
Slice the DataProto and return a new DataProto object. This improves on the original direct slicing, which returned a DataProtoItem; with the enhanced indexing of __getitem__, data_proto[start:end] also returns a DataProto.
- Args:
start (int, optional): Start index. Defaults to None (start from the beginning).
end (int, optional): End index (exclusive). Defaults to None (go to the end).
step (int, optional): Step size. Defaults to None (step=1).
- Returns:
DataProto: A new DataProto containing the sliced data.
- Examples:
# Using the slice method directly
sliced_data = data_proto.slice(10, 20)
# Using enhanced indexing (returns DataProto)
sliced_data = data_proto[10:20]
sliced_data = data_proto[::2]  # Every other element
# Using list indexing (returns DataProto)
indices = [1, 5, 10]
selected_data = data_proto[indices]
# Single index still returns DataProtoItem
single_item = data_proto[5]
- pop(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> DataProto
Pop a subset of the DataProto via batch_keys, non_tensor_batch_keys, and meta_info_keys.
- Args:
batch_keys (list, optional): a list of strings indicating the keys in batch to pop.
non_tensor_batch_keys (list, optional): a list of strings indicating the keys in non_tensor_batch to pop.
meta_info_keys (list, optional): a list of keys indicating the meta info to pop.
- Returns:
DataProto: the DataProto with the popped batch_keys and meta_info_keys.
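The practical difference between select and pop is whether the source is modified. The two helpers below are hypothetical and operate on a plain dict standing in for batch. They assume that select leaves the original untouched (copying values only when deepcopy=True) while pop removes the requested keys from the original and returns them.

```python
import copy


def select_sketch(batch: dict, keys: list, deepcopy: bool = False) -> dict:
    """Return a new dict with only `keys`; `batch` itself is left unchanged."""
    subset = {key: batch[key] for key in keys}
    # Without deepcopy, the subset shares its values with `batch`.
    return copy.deepcopy(subset) if deepcopy else subset


def pop_sketch(batch: dict, keys: list) -> dict:
    """Remove `keys` from `batch` and return them in a new dict."""
    return {key: batch.pop(key) for key in keys}


batch = {"input_ids": [1, 2], "rewards": [0.5, 0.1], "values": [0.0, 0.3]}
picked = select_sketch(batch, ["rewards"])
print(sorted(batch))  # ['input_ids', 'rewards', 'values']
popped = pop_sketch(batch, ["values"])
print(sorted(batch))  # ['input_ids', 'rewards']
print(popped)  # {'values': [0.0, 0.3]}
```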
- rename(old_keys=None, new_keys=None) -> DataProto
Rename keys in the DataProto. Note that this function only renames keys in batch.
- union(other: DataProto) -> DataProto
Union with another DataProto. batch and meta_info are unioned separately. Raises an error if:
- there are conflicting keys in batch and their values are not equal;
- the batch sizes of the two DataProtos are not the same;
- there are conflicting keys in meta_info and their values are not the same.
- Args:
other (DataProto): another DataProto to union
- Returns:
DataProto: the DataProto after union
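The first two error conditions can be sketched on dicts of NumPy arrays standing in for batch. `union_batches_sketch` is hypothetical, assumes both inputs are non-empty, and uses `np.array_equal` as the equality test for shared keys; the meta_info rule would follow the same pattern as union_python_dict.

```python
import numpy as np


def union_batches_sketch(batch1: dict, batch2: dict) -> dict:
    """Hypothetical sketch of the batch rules of DataProto.union, on dicts of NumPy arrays."""
    size1 = len(next(iter(batch1.values())))
    size2 = len(next(iter(batch2.values())))
    if size1 != size2:
        raise ValueError(f"batch sizes differ: {size1} vs {size2}")
    merged = dict(batch1)
    for key, value in batch2.items():
        # A key present on both sides must hold equal data.
        if key in merged and not np.array_equal(merged[key], value):
            raise ValueError(f"conflicting values for key {key!r}")
        merged[key] = value
    return merged


prompts = {"input_ids": np.array([[1, 2], [3, 4]])}
scored = {"input_ids": np.array([[1, 2], [3, 4]]), "reward": np.array([0.2, 0.9])}
print(sorted(union_batches_sketch(prompts, scored)))  # ['input_ids', 'reward']
```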
- make_iterator(mini_batch_size, epochs, seed=None, dataloader_kwargs=None)
Make an iterator from the DataProto. This relies on the fact that a TensorDict can be used as a normal PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
- Args:
mini_batch_size (int): mini-batch size when iterating the dataset. We require that batch.batch_size[0] % mini_batch_size == 0.
epochs (int): number of epochs when iterating the dataset.
seed (int, optional): random seed.
dataloader_kwargs (Any): internally, the iterator is built on a DataLoader over the batch; dataloader_kwargs are the kwargs passed to that DataLoader.
- Returns:
Iterator: an iterator that yields one mini-batch at a time. The total number of iteration steps is self.batch.batch_size[0] * epochs // mini_batch_size.
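The step count follows from making epochs passes over batch_size[0] rows in groups of mini_batch_size. The generator below is a hypothetical NumPy sketch rather than the DataLoader-based implementation. It assumes the rows are reshuffled every epoch (with a real DataLoader that depends on what dataloader_kwargs requests) and that seed controls that shuffle.

```python
from typing import Iterator, Optional

import numpy as np


def make_iterator_sketch(
    data: np.ndarray, mini_batch_size: int, epochs: int, seed: Optional[int] = None
) -> Iterator[np.ndarray]:
    """Hypothetical generator yielding shuffled mini-batches of `data` for `epochs` epochs."""
    n = data.shape[0]
    if n % mini_batch_size != 0:
        raise ValueError("batch size must be divisible by mini_batch_size")
    rng = np.random.default_rng(seed)
    for _ in range(epochs):
        order = rng.permutation(n)  # a fresh shuffle for every epoch
        for start in range(0, n, mini_batch_size):
            yield data[order[start:start + mini_batch_size]]


data = np.arange(8)
batches = list(make_iterator_sketch(data, mini_batch_size=2, epochs=3, seed=0))
print(len(batches))  # 12, i.e. 8 * 3 // 2
```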
- padding(padding_size, padding_candidate='')
Pad the DataProto by concatenating it with padding_candidate.repeat(padding_size).
- Args:
padding_size (int): the number of times padding_candidate is repeated.
padding_candidate: the item to be repeated and appended to the DataProto; only "first" and "last" are supported.
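Padding with "first" or "last" can be read as appending padding_size copies of the first or last row along dim 0. The sketch below follows that reading on a single NumPy array; `pad_rows_sketch` is hypothetical, and applying the same rule to every key of a DataProto is assumed.

```python
import numpy as np


def pad_rows_sketch(batch: np.ndarray, padding_size: int, padding_candidate: str = "last") -> np.ndarray:
    """Hypothetical: append `padding_size` copies of the first or last row along dim 0."""
    if padding_candidate == "first":
        row = batch[:1]
    elif padding_candidate == "last":
        row = batch[-1:]
    else:
        raise ValueError('padding_candidate must be "first" or "last"')
    padding = np.repeat(row, padding_size, axis=0)  # shape (padding_size, *batch.shape[1:])
    return np.concatenate([batch, padding], axis=0)


batch = np.array([[1, 1], [2, 2], [3, 3]])
print(pad_rows_sketch(batch, 2).tolist())  # [[1, 1], [2, 2], [3, 3], [3, 3], [3, 3]]
print(pad_rows_sketch(batch, 1, "first").tolist())  # [[1, 1], [2, 2], [3, 3], [1, 1]]
```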
- chunk(chunks: int) -> list[DataProto]
Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after the split.
- Args:
chunks (int): the number of chunks to split on dim=0
- Returns:
List[DataProto]: a list of DataProto after splitting
- split(split_size: int) -> list[DataProto]
Split the batch along dim=0 into chunks of size split_size. The meta_info is passed to each DataProto after the split.
- Args:
split_size (int): the size of each split
- Returns:
List[DataProto]: a list of DataProto after splitting
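chunk and split differ in what the caller fixes: the number of pieces versus the size of each piece. The NumPy sketch below illustrates this on a batch of 12 rows. `np.split` requires an exact division, and whether DataProto.chunk imposes the same requirement is not stated here, so the example uses a divisible size; the shorter final piece from split assumes torch.split-style semantics.

```python
import numpy as np

data = np.arange(12)  # stands in for a batch of 12 rows

# chunk(3): fix the NUMBER of pieces; each piece gets 12 / 3 rows.
chunks = np.split(data, 3)

# split(5): fix the SIZE of each piece; the last piece holds the remainder.
split_size = 5
splits = [data[i:i + split_size] for i in range(0, len(data), split_size)]

print([len(c) for c in chunks])  # [4, 4, 4]
print([len(s) for s in splits])  # [5, 5, 2]
```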
- static concat(data: list[DataProto]) -> DataProto
Concatenate a list of DataProto. The batch is concatenated along dim=0. The meta_info is merged, with special handling for metrics from different workers.
- Args:
data (List[DataProto]): list of DataProto
- Returns:
DataProto: concatenated DataProto
- repeat(repeat_times=2, interleave=True)
Repeat the batch data a specified number of times.
- Args:
repeat_times (int): Number of times to repeat the data.
interleave (bool): Whether to interleave the repeated data.
- Returns:
DataProto: A new DataProto with repeated data.
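The interleave flag decides the order of the copies. In the NumPy sketch below, np.repeat reproduces each row in place and np.tile repeats the whole batch back to back; mapping these to interleave=True and interleave=False respectively is an assumption based on the parameter's name.

```python
import numpy as np

batch = np.array(["a", "b", "c"])  # three rows of a toy batch
repeat_times = 2

# interleave=True (assumed): every row is repeated in place.
interleaved = np.repeat(batch, repeat_times, axis=0)
# interleave=False (assumed): the whole batch is repeated back to back.
tiled = np.tile(batch, repeat_times)

print(interleaved.tolist())  # ['a', 'a', 'b', 'b', 'c', 'c']
print(tiled.tolist())  # ['a', 'b', 'c', 'a', 'b', 'c']
```

For a 2-D batch, the back-to-back variant would be np.tile(batch, (repeat_times, 1)) so that only dim 0 grows.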
- unfold_column_chunks(n_split: int, split_keys: list[str] | None = None)
Split along the second dim into n_split chunks and unfold them into the first (batch) dim. This is useful for passing grouped tensors that should not be shuffled within the dataset. Keys not in split_keys are repeated to match the new shape. Note that if split_keys is not provided, all keys are repeated rather than split along the second dim.
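For a split key of shape (B, T), unfolding into n_split chunks gives shape (B * n_split, T // n_split), with row b becoming n_split consecutive rows; other keys then repeat each row n_split times to stay aligned. The NumPy sketch below is a hypothetical reading of that description (including the assumption that T is divisible by n_split), not LMFlow's implementation.

```python
from typing import Optional

import numpy as np


def unfold_column_chunks_sketch(batch: dict, n_split: int, split_keys: Optional[list] = None) -> dict:
    """Hypothetical: split keys in `split_keys` along dim 1 and move the chunks into dim 0."""
    split_keys = set(split_keys or [])
    unfolded = {}
    for key, value in batch.items():
        if key in split_keys:
            b, t = value.shape[0], value.shape[1]
            if t % n_split != 0:
                raise ValueError(f"{key!r}: dim 1 ({t}) is not divisible by {n_split}")
            # Row-major reshape: row b becomes rows b*n_split ... b*n_split + n_split - 1.
            unfolded[key] = value.reshape(b * n_split, t // n_split, *value.shape[2:])
        else:
            # Keys that are not split repeat each row so they stay aligned.
            unfolded[key] = np.repeat(value, n_split, axis=0)
    return unfolded


batch = {"tokens": np.arange(8).reshape(2, 4), "uid": np.array([10, 20])}
out = unfold_column_chunks_sketch(batch, n_split=2, split_keys=["tokens"])
print(out["tokens"].tolist())  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(out["uid"].tolist())  # [10, 10, 20, 20]
```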
- sample_level_repeat(repeat_times)
Repeat each row of the batch data a specified number of times.
- Args:
repeat_times (torch.Tensor, list, tuple, ndarray): Number of times to repeat each row.
- Returns:
DataProto: A new DataProto with repeated data.
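Unlike repeat, which applies one count to the whole batch, sample_level_repeat takes one count per row. With NumPy, passing a sequence of counts to np.repeat gives the same per-row behavior; the toy batch below is illustrative only.

```python
import numpy as np

batch = np.array(["p0", "p1", "p2"])  # three rows of a toy batch
repeat_times = [2, 1, 3]  # one repeat count per row

repeated = np.repeat(batch, repeat_times, axis=0)
print(repeated.tolist())  # ['p0', 'p0', 'p1', 'p2', 'p2', 'p2']
print(len(repeated) == sum(repeat_times))  # True
```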
- to_tensordict() -> tensordict.TensorDict
Convert this DataProto to a TensorDict. Note that this requires tensordict version 0.10 or later.
Returns: