lmflow.pipeline.rm_inferencer#

Attributes#

Classes#

RewardModelInferencer

Initializes the Inferencer class with given arguments.

Module Contents#

lmflow.pipeline.rm_inferencer.logger[source]#
class lmflow.pipeline.rm_inferencer.RewardModelInferencer(model_args: lmflow.args.ModelArguments, data_args: lmflow.args.DatasetArguments, inferencer_args: lmflow.args.InferencerArguments, **kwargs)[source]#

Bases: lmflow.pipeline.base_pipeline.BasePipeline

Initializes the Inferencer class with given arguments.

Parameters:
model_args : ModelArguments object.

Contains the arguments required to load the model.

data_args : DatasetArguments object.

Contains the arguments required to load the dataset.

inferencer_args : InferencerArguments object.

Contains the arguments required to perform inference.

data_args[source]#
inferencer_args[source]#
model_args[source]#
local_rank[source]#
world_size[source]#
inference(model: lmflow.models.hf_text_regression_model.HFTextRegressionModel, dataset: lmflow.datasets.dataset.Dataset, transform_dataset_in_place: bool = True, use_vllm: bool = False, enable_distributed_inference: bool = False, **kwargs) -> lmflow.datasets.dataset.Dataset[source]#
_inference(model: lmflow.models.hf_text_regression_model.HFTextRegressionModel, model_input: lmflow.datasets.dataset.Dataset | ray.data.Dataset, enable_distributed_inference: bool = False, **kwargs)[source]#
__inference(model: lmflow.models.hf_text_regression_model.HFTextRegressionModel, model_input: lmflow.datasets.dataset.Dataset) -> list[float] | list[list[float]][source]#
__distributed_inference(model: lmflow.models.hf_text_regression_model.HFTextRegressionModel, model_input: ray.data.Dataset, num_instances: int, batch_size: int) -> list[lmflow.utils.data_utils.RewardModelInferenceResultWithInput][source]#
abstractmethod __vllm_inference(model: lmflow.models.hf_text_regression_model.HFTextRegressionModel, model_input: list[str], enable_distributed_inference: bool = False) -> list[float][source]#
__post_process_model_output(model_output: transformers.modeling_outputs.SequenceClassifierOutputWithPast) -> list[float][source]#
flatten_list(list_of_list: list[list]) -> tuple[list, list[int]][source]#
compress_list(list_to_compress: list, sublist_lengths: list[int]) -> list[list][source]#