Source code for lmflow.pipeline.vllm_inferencer

#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import importlib.resources as pkg_resources
import logging
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
import subprocess
import sys
from typing import Optional

from transformers import AutoTokenizer

from lmflow.args import (
    DatasetArguments,
    InferencerArguments,
    ModelArguments,
)
from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.common import make_shell_args_from_dataclass
from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE, RETURN_CODE_ERROR_BUFFER
from lmflow.utils.protocol import DataProto
from lmflow.utils.versioning import is_vllm_available

logger = logging.getLogger(__name__)

if not is_vllm_available():
    raise ImportError("vllm is not available, please install vllm.")


class VLLMInferencer(BasePipeline):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.inference_engine == "vllm", "VLLMInferencer requires `inference_engine='vllm'`."
        self.model_args = model_args
        self.data_args = data_args
        self.inferencer_args = inferencer_args
        self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id
        self.sampling_params = self._parse_args_to_sampling_params(inferencer_args)

    def _parse_args_to_sampling_params(
        self,
        inference_args: InferencerArguments,
    ) -> dict:
        if inference_args.use_beam_search:
            logger.warning("`use_beam_search` is ignored, as the vLLM V1 engine no longer supports beam search.")
        sampling_params = {
            "n": inference_args.num_output_sequences,
            "temperature": inference_args.temperature + 1e-6,  # small offset keeps temperature strictly positive
            "max_new_tokens": inference_args.max_new_tokens,
            "seed": inference_args.random_seed,
            "top_p": inference_args.top_p,
            "top_k": inference_args.top_k,
            "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids,
        }
        return sampling_params
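
    # Illustrative sketch (not in the original source): with hypothetical
    # arguments num_output_sequences=4, temperature=0.0, max_new_tokens=128,
    # random_seed=42, top_p=1.0, top_k=-1, additional_stop_token_ids=[] and an
    # eos_token_id of 2, the dict above would be:
    #   {"n": 4, "temperature": 1e-06, "max_new_tokens": 128, "seed": 42,
    #    "top_p": 1.0, "top_k": -1, "stop_token_ids": [2]}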

    def inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        release_gpu: bool = False,
        inference_args: Optional[InferencerArguments] = None,
    ) -> DataProto:
        if inference_args:
            logger.warning("Overriding the default inference arguments with the arguments provided to .inference().")
            sampling_params = self._parse_args_to_sampling_params(inference_args)
        else:
            sampling_params = self.sampling_params

        model_input = model.prepare_inputs_for_inference(
            dataset=dataset,
            apply_chat_template=self.inferencer_args.apply_chat_template,
            inference_engine="vllm",
            sampling_params=sampling_params,
        )

        outputs = model.inference(
            inputs=model_input,
            release_gpu=release_gpu,
            inference_engine="vllm",
            gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization,
            tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size,
            data_parallel_size=self.inferencer_args.inference_data_parallel_size,
            max_model_len=self.inferencer_args.inference_max_model_len,
        )

        if self.inferencer_args.save_inference_results:
            self.save_inference_results(outputs, self.inferencer_args.inference_results_path)

        return outputs

    def save_inference_results(
        self,
        outputs: DataProto,
        inference_results_path: str,
    ):
        save_path = os.path.join(inference_results_path, "inference_results.pkl")
        outputs.save_to_disk(save_path)
        logger.info(f"Inference results are saved to {save_path}.")

    def load_inference_results(
        self,
        inference_results_path: str,
    ) -> DataProto:
        load_path = os.path.join(inference_results_path, "inference_results.pkl")
        return DataProto.load_from_disk(load_path)
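
    # Round-trip sketch (illustration, not in the original source): both
    # methods key on the same directory, so
    #   inferencer.save_inference_results(outputs, "results/")
    #   inferencer.load_inference_results("results/")
    # reads back results/inference_results.pkl.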


class MemorySafeVLLMInferencer(VLLMInferencer):
    """Run vLLM inference in a subprocess for memory safety.

    This is a workaround, since vLLM cannot release GPU memory properly
    in-process. See: https://github.com/vllm-project/vllm/issues/1908
    """

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.save_inference_results or inferencer_args.save_results, (
            "For MemorySafeVLLMInferencer, either `save_inference_results` or `save_results` must be True."
        )
        super().__init__(model_args, data_args, inferencer_args)
        self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py"
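
    # Illustrative sketch (hypothetical values, not in the original source):
    # the shell command assembled in `inference()` below looks roughly like
    #   python /.../lmflow/pipeline/utils/memory_safe_vllm_inference.py \
    #       --model_name_or_path <model> --dataset_path <data> ...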

    def inference(self) -> DataProto:
        inferencer_args = make_shell_args_from_dataclass(
            dataclass_objects=[
                self.model_args,
                self.data_args,
                self.inferencer_args,
            ],
            format="shell",
        )
        cmd = "python " + str(self.inferencer_file_path) + " " + inferencer_args

        # Drop env vars that would interfere with the fresh vLLM subprocess.
        current_env = os.environ.copy()
        for var in MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE:
            current_env.pop(var, None)

        cli_res = subprocess.run(
            args=cmd,
            stdout=sys.stdout,
            stderr=sys.stdout,
            shell=True,
            preexec_fn=os.setsid,
            env=current_env,
        )
        logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}")

        if cli_res.returncode in RETURN_CODE_ERROR_BUFFER:
            logger.warning(
                "^^^^^^^^^^ Please ignore the above error, as it comes from the subprocess. "
                "This may be due to a kill signal with unfinished stdout/stderr writing in the subprocess."
            )
        elif cli_res.returncode != 0:
            raise RuntimeError(f"Error during MemorySafeVLLMInference: {cli_res}")

        inference_results_path = self.inferencer_args.inference_results_path or self.inferencer_args.results_path
        outputs = self.load_inference_results(inference_results_path)
        logger.info("MemorySafeVLLMInference result captured.")
        return outputs
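

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module). It assumes the
    # usual lmflow construction pattern for HFDecoderModel and Dataset; treat
    # the argument wiring below as an illustration, not the official CLI.
    from transformers import HfArgumentParser

    parser = HfArgumentParser((ModelArguments, DatasetArguments, InferencerArguments))
    model_args, data_args, inferencer_args = parser.parse_args_into_dataclasses()

    inferencer = VLLMInferencer(model_args, data_args, inferencer_args)
    model = HFDecoderModel(model_args)  # assumed constructor signature
    dataset = Dataset(data_args)  # assumed constructor signature
    outputs = inferencer.inference(model=model, dataset=dataset, release_gpu=True)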