Source code for lmflow.pipeline.auto_pipeline

#!/usr/bin/env python
# coding=utf-8
"""Return a pipeline automatically based on its name.
"""
from lmflow.utils.versioning import (
    is_package_version_at_least, 
    is_vllm_available, 
    is_trl_available, 
    is_ray_available
)

from lmflow.pipeline.evaluator import Evaluator
from lmflow.pipeline.finetuner import Finetuner
from lmflow.pipeline.inferencer import Inferencer
from lmflow.pipeline.rm_tuner import RewardModelTuner
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
# Registry of pipeline name -> pipeline class, consulted by
# ``AutoPipeline.get_pipeline``. These entries are always registered.
PIPELINE_MAPPING = dict(
    evaluator=Evaluator,
    finetuner=Finetuner,
    inferencer=Inferencer,
    rm_inferencer=RewardModelInferencer,
    rm_tuner=RewardModelTuner,
)

# Names of pipelines that could not be registered in the current environment
# because an optional package is unavailable or a version check failed.
PIPELINE_NEEDS_EXTRAS = []

# The RAFT aligner is registered only when transformers is older than 4.35.0;
# on 4.35.0 or newer it is reported as needing extras instead.
if is_package_version_at_least('transformers', '4.35.0'):
    PIPELINE_NEEDS_EXTRAS.append('raft_aligner')
else:
    from lmflow.pipeline.raft_aligner import RaftAligner
    PIPELINE_MAPPING['raft_aligner'] = RaftAligner

# vLLM-backed inference requires the vllm package.
if is_vllm_available():
    from lmflow.pipeline.vllm_inferencer import VLLMInferencer
    PIPELINE_MAPPING['vllm_inferencer'] = VLLMInferencer
else:
    PIPELINE_NEEDS_EXTRAS.append('vllm_inferencer')

# Both DPO aligners require trl; they are registered (or reported) together.
if is_trl_available():
    from lmflow.pipeline.dpo_aligner import DPOAligner
    from lmflow.pipeline.dpov2_aligner import DPOv2Aligner
    PIPELINE_MAPPING.update(
        dpo_aligner=DPOAligner,
        dpov2_aligner=DPOv2Aligner,
    )
else:
    PIPELINE_NEEDS_EXTRAS += ['dpo_aligner', 'dpov2_aligner']

# Iterative DPO needs vllm, trl and ray all at once.
if (
    is_vllm_available()
    and is_trl_available()
    and is_ray_available()
):
    from lmflow.pipeline.iterative_dpo_aligner import IterativeDPOAligner
    PIPELINE_MAPPING['iterative_dpo_aligner'] = IterativeDPOAligner
else:
    PIPELINE_NEEDS_EXTRAS.append('iterative_dpo_aligner')
class AutoPipeline:
    """
    The class designed to return a pipeline automatically based on its name.
    """

    @classmethod
    def get_pipeline(
        cls,
        pipeline_name: str,
        model_args,
        data_args,
        pipeline_args,
        *args,
        **kwargs,
    ):
        """Instantiate the pipeline registered under ``pipeline_name``.

        Parameters
        ----------
        pipeline_name : str
            Key into the module-level ``PIPELINE_MAPPING`` registry.
        model_args, data_args, pipeline_args
            Passed positionally, in this order, to the pipeline class's
            constructor.
        *args, **kwargs
            Any extra arguments, forwarded unchanged to the constructor.

        Returns
        -------
        An instance of the class registered under ``pipeline_name``.

        Raises
        ------
        NotImplementedError
            If ``pipeline_name`` is known but could not be registered in this
            environment (it is listed in ``PIPELINE_NEEDS_EXTRAS``), or if it
            is not a known pipeline name at all.
        """
        # Only the registry lookup is guarded, so a KeyError raised inside a
        # pipeline's constructor is never misreported as "not supported".
        try:
            pipeline_cls = PIPELINE_MAPPING[pipeline_name]
        except KeyError:
            if pipeline_name in PIPELINE_NEEDS_EXTRAS:
                # A pipeline can land here either because an optional package
                # is missing or because an installed package has an
                # incompatible version (e.g. the transformers version gate on
                # 'raft_aligner'), so the advice mentions both.
                raise NotImplementedError(
                    f'Please install the necessary dependencies '
                    f'(at compatible versions) to use pipeline '
                    f'"{pipeline_name}"'
                ) from None
            raise NotImplementedError(
                f'Pipeline "{pipeline_name}" is not supported'
            ) from None

        return pipeline_cls(
            model_args, data_args, pipeline_args, *args, **kwargs
        )