Source code for lmflow.pipeline.auto_pipeline
#!/usr/bin/env python
"""Return a pipeline automatically based on its name."""
from lmflow.pipeline.evaluator import Evaluator
from lmflow.pipeline.finetuner import Finetuner
from lmflow.pipeline.inferencer import Inferencer
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
from lmflow.pipeline.rm_tuner import RewardModelTuner
from lmflow.utils.versioning import is_package_version_at_least, is_ray_available, is_trl_available, is_vllm_available
# Registry of pipelines that are always available: pipeline name -> class.
PIPELINE_MAPPING = {
    "evaluator": Evaluator,
    "finetuner": Finetuner,
    "inferencer": Inferencer,
    "rm_inferencer": RewardModelInferencer,
    "rm_tuner": RewardModelTuner,
}

# Names of known pipelines that were NOT registered because one of the
# availability checks below failed. It must be defined before the first
# ``else`` branch runs, otherwise importing this module raises NameError.
PIPELINE_NEEDS_EXTRAS = []

# Each optional pipeline is imported lazily inside its guard so that a missing
# optional package does not break importing this module.

# Registered only when the ``transformers >= 4.35.0`` check is false.
if not is_package_version_at_least("transformers", "4.35.0"):
    from lmflow.pipeline.raft_aligner import RaftAligner

    PIPELINE_MAPPING["raft_aligner"] = RaftAligner
else:
    PIPELINE_NEEDS_EXTRAS.append("raft_aligner")

if is_vllm_available():
    from lmflow.pipeline.vllm_inferencer import VLLMInferencer

    PIPELINE_MAPPING["vllm_inferencer"] = VLLMInferencer
else:
    PIPELINE_NEEDS_EXTRAS.append("vllm_inferencer")

if is_trl_available():
    from lmflow.pipeline.dpo_aligner import DPOAligner
    from lmflow.pipeline.dpov2_aligner import DPOv2Aligner

    PIPELINE_MAPPING["dpo_aligner"] = DPOAligner
    PIPELINE_MAPPING["dpov2_aligner"] = DPOv2Aligner
else:
    PIPELINE_NEEDS_EXTRAS.extend(["dpo_aligner", "dpov2_aligner"])

# Iterative DPO is registered only when all three availability checks pass.
if is_vllm_available() and is_trl_available() and is_ray_available():
    from lmflow.pipeline.iterative_dpo_aligner import IterativeDPOAligner

    PIPELINE_MAPPING["iterative_dpo_aligner"] = IterativeDPOAligner
else:
    PIPELINE_NEEDS_EXTRAS.append("iterative_dpo_aligner")
class AutoPipeline:
    """
    The class designed to return a pipeline automatically based on its name.
    """

    @classmethod
    def get_pipeline(cls, pipeline_name, model_args, data_args, pipeline_args, *args, **kwargs):
        """Instantiate the pipeline registered under ``pipeline_name``.

        Parameters
        ----------
        pipeline_name : str
            Key looked up in ``PIPELINE_MAPPING``.
        model_args, data_args, pipeline_args
            Forwarded unchanged, in this order, as the first three positional
            arguments of the pipeline constructor.
        *args, **kwargs
            Extra arguments forwarded unchanged to the pipeline constructor.

        Returns
        -------
        object
            Whatever calling the registered pipeline class returns.

        Raises
        ------
        NotImplementedError
            If ``pipeline_name`` is not in ``PIPELINE_MAPPING``. The message
            asks the user to install dependencies when the name is listed in
            ``PIPELINE_NEEDS_EXTRAS``; otherwise it reports the pipeline as
            unsupported.
        """
        if pipeline_name not in PIPELINE_MAPPING:
            # Distinguish "known, but not registered in this environment"
            # from "unknown name" so the user gets an actionable message.
            if pipeline_name in PIPELINE_NEEDS_EXTRAS:
                raise NotImplementedError(
                    f'Please install the necessary dependencies to use pipeline "{pipeline_name}"'
                )
            raise NotImplementedError(f'Pipeline "{pipeline_name}" is not supported')

        pipeline_class = PIPELINE_MAPPING[pipeline_name]
        return pipeline_class(model_args, data_args, pipeline_args, *args, **kwargs)