lmflow.utils.protocol
=====================

.. py:module:: lmflow.utils.protocol

.. autoapi-nested-parse::

   ref: https://github.com/volcengine/verl/blob/main/verl/protocol.py

   Implements a base data transfer protocol between any two functions or modules.
   Protocol can be subclassed to define more detailed batch info with specific keys.


Attributes
----------

.. autoapisummary::

   lmflow.utils.protocol.logger


Classes
-------

.. autoapisummary::

   lmflow.utils.protocol.DataProtoItem
   lmflow.utils.protocol.DataProto


Functions
---------

.. autoapisummary::

   lmflow.utils.protocol.union_python_dict
   lmflow.utils.protocol.union_tensor_dict
   lmflow.utils.protocol._array_equal
   lmflow.utils.protocol._deep_equal
   lmflow.utils.protocol.union_numpy_dict
   lmflow.utils.protocol.list_of_dict_to_dict_of_list
   lmflow.utils.protocol.collate_fn
   lmflow.utils.protocol.get_tensordict


Module Contents
---------------

.. py:data:: logger

.. py:function:: union_python_dict(dict1: dict, dict2: dict)

   Union two dicts. Raises an error if a key is present in both dicts but its two values differ.

   Args:
       dict1: The first dict.
       dict2: The second dict.

   Returns:
       The merged dict.


.. py:function:: union_tensor_dict(tensor_dict1: tensordict.TensorDict, tensor_dict2: tensordict.TensorDict) -> tensordict.TensorDict

   Union two TensorDicts.


.. py:function:: _array_equal(array1: numpy.ndarray, array2: numpy.ndarray, visited: set[int]) -> bool

   Recursively compare two NumPy arrays for strict equality, with special handling for
   object-dtype arrays, NaN values, and circular references.

   This function assumes that both arguments are NumPy arrays.

   Args:
       array1: The first NumPy array.
       array2: The second NumPy array.
       visited: Set of object ids already being compared, used to detect circular references.

   Returns:
       True if the arrays' dtypes, shapes, and all elements are equal.


.. py:function:: _deep_equal(a: Any, b: Any, visited: set[int]) -> bool

   Recursively perform a deep comparison between two Python objects.

   - Treats NaN values as equal (NaN == NaN evaluates to True).
   - Handles circular references.
   - Dispatches to ``_array_equal`` if both objects are NumPy arrays.
   - Otherwise, uses the standard ``==`` comparison.


.. py:function:: union_numpy_dict(tensor_dict1: dict[str, numpy.ndarray], tensor_dict2: dict[str, numpy.ndarray]) -> dict[str, numpy.ndarray]

.. py:function:: list_of_dict_to_dict_of_list(list_of_dict: list[dict])

.. py:function:: collate_fn(x: list[DataProtoItem])

.. py:function:: get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> tensordict.TensorDict

   Create a TensorDict from tensors and non-tensor data.

   Nested structures in lists are handled automatically by converting them to NonTensorStack.
   This enables support for:

   - Lists of lists: ``[[], [0.5, 0.8], [0.9]]``
   - Lists of dicts: ``[{"acc": 1.0}, {"acc": 0.0}]``
   - Lists of lists of dicts: ``[[{"content": "...", "role": "user"}]]``

   Args:
       tensor_dict: Dictionary of tensors and lists to include in the TensorDict.
       non_tensor_dict: Dictionary of metadata to store as NonTensorData.

   Returns:
       TensorDict with proper handling of nested structures.

   Example:
       >>> td = get_tensordict(
       ...     tensor_dict={
       ...         "obs": torch.randn(3, 4),
       ...         "turn_scores": [[], [0.5, 0.8], [0.9]]  # Nested list
       ...     },
       ...     non_tensor_dict={"experiment": "test"}
       ... )


.. py:class:: DataProtoItem

   .. py:attribute:: batch
      :type: tensordict.TensorDict
      :value: None

   .. py:attribute:: non_tensor_batch
      :type: dict

   .. py:attribute:: meta_info
      :type: dict


.. py:class:: DataProto

   A DataProto is a data structure that provides a standard protocol for data exchange
   between functions. It contains a ``batch`` (TensorDict), a ``non_tensor_batch`` (dict),
   and a ``meta_info`` (dict). The batch is a TensorDict (https://pytorch.org/tensordict/),
   which lets you manipulate a dictionary of tensors like a single tensor. Ideally, tensors
   that share the same batch size should be put inside ``batch``.

   .. py:attribute:: batch
      :type: tensordict.TensorDict
      :value: None

   .. py:attribute:: non_tensor_batch
      :type: dict

   .. py:attribute:: meta_info
      :type: dict

   .. py:method:: __post_init__()

   .. py:method:: __len__()

   .. py:method:: __getitem__(item)

      Enhanced indexing for DataProto objects.

      Args:
          item: Can be one of:

              - int: A single index
              - slice: A slice object (start:stop:step)
              - list: A list of indices
              - numpy.ndarray: An array of indices
              - torch.Tensor: A tensor of indices

      Returns:
          - DataProto: for all indexing types except single integers.
          - DataProtoItem: only for single integer indices.

   .. py:method:: __getstate__()

   .. py:method:: __setstate__(data)

   .. py:method:: save_to_disk(filepath)

   .. py:method:: load_from_disk(filepath) -> DataProto
      :staticmethod:

   .. py:method:: print_size(prefix='')

   .. py:method:: check_consistency()

      Check the consistency of the DataProto, mainly between ``batch`` and ``non_tensor_batch``.
      This function is public so that users can call it directly.

   .. py:method:: from_single_dict(data: dict[str, torch.Tensor | numpy.ndarray], meta_info=None)
      :classmethod:

      Create a DataProto from a dict of tensors and non-tensors.

   .. py:method:: from_dict(tensors: Optional[dict[str, torch.Tensor]] = None, non_tensors=None, meta_info=None, num_batch_dims=1)
      :classmethod:

      Create a DataProto from a dict of tensors. This assumes that:

      1. All the tensors in ``tensors`` have the same dim0.
      2. Only dim0 is the batch dim.

   .. py:method:: from_tensordict(tensor_dict: tensordict.TensorDict = None, meta_info=None, num_batch_dims=1)
      :classmethod:

      Create a DataProto from a TensorDict. This assumes that:

      1. All the tensors in ``tensor_dict`` have the same dim0.
      2. Only dim0 is the batch dim.

   .. py:method:: to(device) -> DataProto

      Move the batch to ``device``.

      Args:
          device (torch.device, str): torch device.

      Returns:
          DataProto: the current DataProto.

   .. py:method:: select(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> DataProto

      Select a subset of the DataProto via ``batch_keys``, ``non_tensor_batch_keys``, and ``meta_info_keys``.

      Args:
          batch_keys (list, optional): a list of strings indicating the keys in ``batch`` to select.
          non_tensor_batch_keys (list, optional): a list of keys in ``non_tensor_batch`` to select.
          meta_info_keys (list, optional): a list of keys indicating the meta info to select.

      Returns:
          DataProto: the DataProto with the selected keys.

   .. py:method:: select_idxs(idxs)

      Select specific indices from the DataProto.

      Args:
          idxs (torch.Tensor or numpy.ndarray or list): Indices to select.

      Returns:
          DataProto: A new DataProto containing only the selected indices.

   .. py:method:: slice(start=None, end=None, step=None)

      Slice the DataProto and return a new DataProto object.

      Slicing with ``data_proto[start:end:step]`` also returns a DataProto; only a single
      integer index returns a DataProtoItem.

      Args:
          start (int, optional): Start index. Defaults to None (start from the beginning).
          end (int, optional): End index (exclusive). Defaults to None (go to the end).
          step (int, optional): Step size. Defaults to None (step=1).

      Returns:
          DataProto: A new DataProto containing the sliced data.

      Examples::

          # Using the slice method directly
          sliced_data = data_proto.slice(10, 20)

          # Using enhanced indexing (returns DataProto)
          sliced_data = data_proto[10:20]
          sliced_data = data_proto[::2]  # Every other element

          # Using list indexing (returns DataProto)
          indices = [1, 5, 10]
          selected_data = data_proto[indices]

          # Single index still returns DataProtoItem
          single_item = data_proto[5]

   .. py:method:: pop(batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> DataProto

      Pop a subset of the DataProto via ``batch_keys``, ``non_tensor_batch_keys``, and ``meta_info_keys``.

      Args:
          batch_keys (list, optional): a list of strings indicating the keys in ``batch`` to pop.
          non_tensor_batch_keys (list, optional): a list of keys in ``non_tensor_batch`` to pop.
          meta_info_keys (list, optional): a list of keys indicating the meta info to pop.

      Returns:
          DataProto: a DataProto containing the popped keys.

   .. py:method:: rename(old_keys=None, new_keys=None) -> DataProto

      Rename keys in the DataProto. Note that this function only renames keys in ``batch``.

   .. py:method:: union(other: DataProto) -> DataProto

      Union with another DataProto. ``batch`` and ``meta_info`` are unioned separately.
      An error is raised if:

      - there are conflicting keys in ``batch`` whose values are not equal;
      - the batch sizes of the two DataProtos differ;
      - there are conflicting keys in ``meta_info`` whose values are not the same.

      Args:
          other (DataProto): another DataProto to union with.

      Returns:
          DataProto: the DataProto after the union.

   .. py:method:: make_iterator(mini_batch_size, epochs, seed=None, dataloader_kwargs=None)

      Make an iterator from the DataProto. This relies on the fact that a TensorDict can be
      used as a regular PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion
      for more details.

      Args:
          mini_batch_size (int): mini-batch size when iterating the dataset. We require that
              ``batch.batch_size[0] % mini_batch_size == 0``.
          epochs (int): number of epochs when iterating the dataset.
          seed (int, optional): random seed for the underlying DataLoader. Defaults to None.
          dataloader_kwargs (Any): internally, this method builds a DataLoader over the batch;
              ``dataloader_kwargs`` is passed through to that DataLoader.

      Returns:
          Iterator: an iterator that yields one mini-batch at a time. The total number of
          iteration steps is ``self.batch.batch_size[0] * epochs // mini_batch_size``.

   .. py:method:: padding(padding_size, padding_candidate='')

      Pad the DataProto by concatenating it with ``padding_candidate.repeat(padding_size)``.

      Args:
          padding_size (int): the number of times ``padding_candidate`` is repeated.
          padding_candidate (str): which item is repeated and appended to the DataProto;
              only ``"first"`` and ``"last"`` are supported.

   .. py:method:: chunk(chunks: int) -> list[DataProto]

      Split the batch along dim=0 into ``chunks`` pieces. The meta_info is passed to each
      DataProto after the split.

      Args:
          chunks (int): the number of chunks to split into along dim=0.

      Returns:
          List[DataProto]: a list of DataProto after splitting.

   .. py:method:: split(split_size: int) -> list[DataProto]

      Split the batch along dim=0 into pieces of ``split_size`` rows. The meta_info is
      passed to each DataProto after the split.

      Args:
          split_size (int): the size of each split.

      Returns:
          List[DataProto]: a list of DataProto after splitting.

   .. py:method:: concat(data: list[DataProto]) -> DataProto
      :staticmethod:

      Concatenate a list of DataProto. The batch is concatenated along dim=0.
      The meta_info is merged, with special handling for metrics from different workers.

      Args:
          data (List[DataProto]): list of DataProto.

      Returns:
          DataProto: the concatenated DataProto.

   .. py:method:: reorder(indices)

      Reorder the rows of the DataProto according to ``indices``. Note that this operation
      is in-place.

   .. py:method:: repeat(repeat_times=2, interleave=True)

      Repeat the batch data a specified number of times.

      Args:
          repeat_times (int): Number of times to repeat the data.
          interleave (bool): Whether to interleave the repeated data, i.e. repeat each row
              consecutively (``[a, a, b, b]``) rather than tile the whole batch (``[a, b, a, b]``).

      Returns:
          DataProto: A new DataProto with repeated data.

   .. py:method:: unfold_column_chunks(n_split: int, split_keys: Optional[list[str]] = None)

      Split along the second dimension into ``n_split`` chunks and unfold them into the first
      (batch) dimension. This is useful for passing grouped tensors that should not be
      shuffled in the dataset.

      Keys not in ``split_keys`` are repeated to match the new shape. Note that if
      ``split_keys`` is not provided, all keys are repeated rather than split.

   .. py:method:: sample_level_repeat(repeat_times)

      Repeat each row of the batch data a specified number of times.

      Args:
          repeat_times (torch.Tensor, list, tuple, or numpy.ndarray): Number of times to repeat
              each row, one entry per row.

      Returns:
          DataProto: A new DataProto with repeated data.

   .. py:method:: to_tensordict() -> tensordict.TensorDict

      Convert this DataProto to a TensorDict. Note that this requires tensordict version 0.10
      or later.

      Returns:
          tensordict.TensorDict: the converted TensorDict.

   .. py:method:: get_data_info() -> str

      Return formatted information about the stored data, with nested type details.

      Returns:
          str: Formatted string showing tensor details and recursive metadata types.

   .. py:method:: _get_type_info(value)

      Recursively get type information for nested structures.
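The conflict rule of ``union_python_dict`` (the same rule ``DataProto.union`` documents for ``meta_info``) can be illustrated with plain dicts. This is a minimal sketch, not the library's implementation: ``union_python_dict_sketch`` is a hypothetical name, and it returns a new dict instead of assuming whether the real function modifies ``dict1`` in place.

```python
from typing import Any, Dict


def union_python_dict_sketch(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch: merge two dicts, rejecting keys whose values disagree."""
    merged = dict(dict1)  # copy so neither input is modified
    for key, value in dict2.items():
        # A key present in both dicts must map to equal values.
        if key in merged and merged[key] != value:
            raise ValueError(f"conflicting values for key {key!r}")
        merged[key] = value
    return merged


meta = union_python_dict_sketch({"temperature": 1.0}, {"temperature": 1.0, "step": 3})
print(meta)  # {'temperature': 1.0, 'step': 3}

try:
    union_python_dict_sketch({"step": 3}, {"step": 4})
except ValueError as err:
    print(err)  # conflicting values for key 'step'
```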
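The comparison rules listed for ``_deep_equal`` treat NaN as equal to NaN and detect cycles through ``visited``. Both rules can be sketched with standard Python containers. This hedged sketch substitutes lists and dicts for the NumPy arrays that the real helper hands off to ``_array_equal``. ``deep_equal_sketch`` is a hypothetical name, and treating a revisited container as equal is an assumption of the sketch.

```python
import math
from typing import Any, Optional, Set


def deep_equal_sketch(a: Any, b: Any, visited: Optional[Set[int]] = None) -> bool:
    """Hypothetical NaN-aware, cycle-safe deep comparison over lists and dicts."""
    if visited is None:
        visited = set()
    if type(a) is not type(b):
        return False
    # Python's float NaN is never == to itself; treat NaN and NaN as equal here.
    if isinstance(a, float) and math.isnan(a) and math.isnan(b):
        return True
    if isinstance(a, (list, dict)):
        # Seeing the same container again means we followed a cycle;
        # assume equality instead of recursing forever.
        if id(a) in visited:
            return True
        visited.add(id(a))
        try:
            if len(a) != len(b):
                return False
            if isinstance(a, dict):
                return a.keys() == b.keys() and all(
                    deep_equal_sketch(a[k], b[k], visited) for k in a
                )
            return all(deep_equal_sketch(x, y, visited) for x, y in zip(a, b))
        finally:
            visited.discard(id(a))
    return a == b


x = [float("nan"), {"k": [1, 2]}]
y = [float("nan"), {"k": [1, 2]}]
print(x == y)                   # False: the two NaN objects compare unequal
print(deep_equal_sketch(x, y))  # True

a = [1]
a.append(a)  # a contains itself
b = [1]
b.append(b)
print(deep_equal_sketch(a, b))  # True, where a == b would raise RecursionError
```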
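``list_of_dict_to_dict_of_list`` has no docstring. Its name suggests the per-sample to per-key transposition sketched below. This reading rests only on the name. ``list_of_dict_to_dict_of_list_sketch`` is hypothetical, and the rule that every dict must carry the same keys as the first one is an assumption.

```python
from typing import Any, Dict, List


def list_of_dict_to_dict_of_list_sketch(list_of_dict: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical sketch: turn per-sample dicts into per-key lists."""
    if not list_of_dict:
        return {}
    # Use the first dict's keys as the schema, preserving their order.
    output: Dict[str, List[Any]] = {key: [] for key in list_of_dict[0]}
    for row in list_of_dict:
        if row.keys() != output.keys():
            raise KeyError(f"row keys {sorted(row)} differ from {sorted(output)}")
        for key, value in row.items():
            output[key].append(value)
    return output


rows = [{"acc": 1.0, "length": 5}, {"acc": 0.0, "length": 7}]
print(list_of_dict_to_dict_of_list_sketch(rows))  # {'acc': [1.0, 0.0], 'length': [5, 7]}
```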
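A toy column store can model the return-type rule documented for ``DataProto.__getitem__`` and the ``slice`` method. ``ToyProto`` and ``ToyItem`` below are hypothetical stand-ins that hold plain Python lists instead of a TensorDict and NumPy arrays. They only show which index types yield a single item and which yield a new batch.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union


@dataclass
class ToyItem:
    """Hypothetical stand-in for DataProtoItem: one row, keyed by column."""
    row: Dict[str, Any]


@dataclass
class ToyProto:
    """Hypothetical stand-in for DataProto: equal-length columns of Python lists."""
    columns: Dict[str, List[Any]] = field(default_factory=dict)

    def __len__(self) -> int:
        return len(next(iter(self.columns.values()), []))

    def __getitem__(self, item: Union[int, slice, List[int]]) -> Union["ToyItem", "ToyProto"]:
        if isinstance(item, int):
            # A single integer index yields one item, not a batch.
            return ToyItem({key: col[item] for key, col in self.columns.items()})
        # Slices and index lists yield a new batch containing the selected rows.
        idxs = range(len(self))[item] if isinstance(item, slice) else item
        return ToyProto({key: [col[i] for i in idxs] for key, col in self.columns.items()})


data = ToyProto({"uid": [0, 1, 2, 3, 4, 5], "reward": [0.1, 0.9, 0.4, 0.7, 0.2, 0.5]})
print(data[2])               # ToyItem(row={'uid': 2, 'reward': 0.4})
print(len(data[1:5:2]))      # 2
print(data[[0, 4]].columns)  # {'uid': [0, 4], 'reward': [0.1, 0.2]}
```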
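The step count stated for ``make_iterator`` follows from making ``epochs`` passes over the batch in chunks of ``mini_batch_size``. The generator below is a hypothetical, stdlib-only sketch of that loop, not the DataLoader-based implementation. ``make_iterator_sketch`` and its explicit ``shuffle`` flag are assumptions made for illustration.

```python
import random
from typing import Any, Iterator, List, Optional


def make_iterator_sketch(
    rows: List[Any],
    mini_batch_size: int,
    epochs: int,
    seed: Optional[int] = None,
    shuffle: bool = False,
) -> Iterator[List[Any]]:
    """Hypothetical sketch: yield mini-batches for `epochs` passes over `rows`."""
    if len(rows) % mini_batch_size != 0:
        raise ValueError("len(rows) must be divisible by mini_batch_size")
    rng = random.Random(seed)  # seeded so shuffled runs are reproducible
    for _ in range(epochs):
        order = list(range(len(rows)))
        if shuffle:
            rng.shuffle(order)
        for start in range(0, len(order), mini_batch_size):
            yield [rows[i] for i in order[start:start + mini_batch_size]]


batches = list(make_iterator_sketch(list(range(8)), mini_batch_size=4, epochs=3))
print(len(batches))  # 6 == 8 * 3 // 4
print(batches[0])    # [0, 1, 2, 3]
```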
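``padding`` appends ``padding_size`` copies of either the first or the last item. One plausible reason to pad is to make the batch size divisible by a number of workers. The sketch below shows that arithmetic on a plain list. ``padding_sketch`` and ``num_workers`` are hypothetical. The sketch defaults to ``"last"`` and returns a new list; it does not assume whether the real method pads in place.

```python
from typing import Any, List


def padding_sketch(rows: List[Any], padding_size: int, padding_candidate: str = "last") -> List[Any]:
    """Hypothetical sketch: append `padding_size` copies of the first or last row."""
    if padding_candidate not in ("first", "last"):
        raise ValueError('padding_candidate must be "first" or "last"')
    if padding_size == 0:
        return list(rows)
    candidate = rows[0] if padding_candidate == "first" else rows[-1]
    return list(rows) + [candidate] * padding_size


rows = ["s0", "s1", "s2", "s3", "s4"]
num_workers = 4
# Smallest padding that makes len(rows) a multiple of num_workers.
padding_size = -len(rows) % num_workers
padded = padding_sketch(rows, padding_size, padding_candidate="last")
print(padding_size, padded)  # 3 ['s0', 's1', 's2', 's3', 's4', 's4', 's4', 's4']
```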
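``chunk`` and ``split`` differ in what their integer argument counts: the number of pieces versus the size of each piece. Both pass ``meta_info`` to every piece. The sketch below models a DataProto as a ``(rows, meta_info)`` pair, and the function names are hypothetical. Two behaviors are assumptions of the sketch: ``chunk`` requires an even division, and ``split`` lets the last piece be shorter, since the docstrings do not say how a remainder is handled.

```python
from typing import Any, Dict, List, Tuple

Batch = Tuple[List[Any], Dict[str, Any]]  # (rows, meta_info)


def chunk_sketch(batch: Batch, chunks: int) -> List[Batch]:
    """Hypothetical sketch: split into `chunks` equal pieces along dim 0."""
    rows, meta_info = batch
    if len(rows) % chunks != 0:
        raise ValueError("this sketch requires len(rows) % chunks == 0")
    size = len(rows) // chunks
    # Every piece carries its own copy of meta_info.
    return [(rows[i:i + size], dict(meta_info)) for i in range(0, len(rows), size)]


def split_sketch(batch: Batch, split_size: int) -> List[Batch]:
    """Hypothetical sketch: pieces of `split_size` rows; the last may be shorter."""
    rows, meta_info = batch
    return [(rows[i:i + split_size], dict(meta_info)) for i in range(0, len(rows), split_size)]


batch = (list(range(6)), {"temperature": 1.0})
print([rows for rows, _ in chunk_sketch(batch, 3)])  # [[0, 1], [2, 3], [4, 5]]
print([rows for rows, _ in split_sketch(batch, 4)])  # [[0, 1, 2, 3], [4, 5]]
```

Concatenating the pieces' rows in order, as ``concat`` does along dim 0, restores the original batch.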
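The difference between ``repeat(interleave=True)``, ``repeat(interleave=False)``, and ``sample_level_repeat`` is easiest to see on a list of row labels. This is a hedged sketch with hypothetical function names, using plain lists in place of tensor rows.

```python
from typing import Any, List, Sequence


def repeat_sketch(rows: List[Any], repeat_times: int = 2, interleave: bool = True) -> List[Any]:
    """Hypothetical sketch of batch-level repetition."""
    if interleave:
        # Each row is repeated back to back: a, a, b, b
        return [row for row in rows for _ in range(repeat_times)]
    # The whole batch is tiled: a, b, a, b
    return [row for _ in range(repeat_times) for row in rows]


def sample_level_repeat_sketch(rows: List[Any], repeat_times: Sequence[int]) -> List[Any]:
    """Hypothetical sketch: repeat row i exactly repeat_times[i] times."""
    if len(repeat_times) != len(rows):
        raise ValueError("need one repeat count per row")
    return [row for row, n in zip(rows, repeat_times) for _ in range(n)]


rows = ["a", "b", "c"]
print(repeat_sketch(rows, 2, interleave=True))      # ['a', 'a', 'b', 'b', 'c', 'c']
print(repeat_sketch(rows, 2, interleave=False))     # ['a', 'b', 'c', 'a', 'b', 'c']
print(sample_level_repeat_sketch(rows, [2, 0, 1]))  # ['a', 'a', 'c']
```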
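For ``unfold_column_chunks``, a key listed in ``split_keys`` with shape ``(batch, n_split * k)`` becomes ``(batch * n_split, k)``. Every other key is repeated so that rows stay aligned. The sketch below shows this reshaping on nested lists. ``unfold_column_chunks_sketch`` is hypothetical. The row order, with chunk ``j`` of row ``i`` landing at new row ``i * n_split + j``, is an assumption consistent with a row-major reshape.

```python
from typing import Any, Dict, List, Optional


def unfold_column_chunks_sketch(
    columns: Dict[str, List[Any]], n_split: int, split_keys: Optional[List[str]] = None
) -> Dict[str, List[Any]]:
    """Hypothetical sketch on nested lists: columns[key][row] is one sample."""
    to_split = set(split_keys or [])
    out: Dict[str, List[Any]] = {}
    for key, rows in columns.items():
        if key in to_split:
            new_rows = []
            for row in rows:
                if len(row) % n_split != 0:
                    raise ValueError(f"second dim of {key!r} is not divisible by n_split")
                width = len(row) // n_split
                # Chunk j of row i becomes new row i * n_split + j (row-major order).
                new_rows.extend(row[j * width:(j + 1) * width] for j in range(n_split))
            out[key] = new_rows
        else:
            # Unsplit keys are repeated so each new row keeps its source row's value.
            out[key] = [row for row in rows for _ in range(n_split)]
    return out


columns = {"responses": [[1, 2, 3, 4], [5, 6, 7, 8]], "uid": ["p0", "p1"]}
print(unfold_column_chunks_sketch(columns, n_split=2, split_keys=["responses"]))
# {'responses': [[1, 2], [3, 4], [5, 6], [7, 8]], 'uid': ['p0', 'p0', 'p1', 'p1']}
```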