Source code for lmflow.pipeline.vllm_inferencer

#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import copy
import importlib.resources as pkg_resources
import json
import logging
import os
# Force vllm workers to use the 'spawn' start method; 'fork' does not interact
# well with a CUDA-initialized parent process.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
import subprocess
import sys
from functools import partial
from typing import List, Union, Optional, Dict, Any

import numpy as np
from transformers import AutoTokenizer

from lmflow.datasets import Dataset
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.args import (
    InferencerArguments, 
    ModelArguments, 
    DatasetArguments,
)
from lmflow.utils.common import make_shell_args_from_dataclass
from lmflow.utils.constants import RETURN_CODE_ERROR_BUFFER, MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE
from lmflow.utils.data_utils import VLLMInferenceResultWithInput
from lmflow.utils.versioning import is_vllm_available, is_ray_available


logger = logging.getLogger(__name__)

if is_vllm_available():
    from vllm import SamplingParams, LLM
else:
    raise ImportError("VLLM is not available, please install vllm.")

if is_ray_available():
    import ray
    import ray.data
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    logger.warning("Ray is not available, distributed vllm inference will not be supported.")
class InferencerWithOffloading(BasePipeline):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.inferencer_args = inferencer_args
        self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id

    def inference(self):
        raise NotImplementedError(".inference is not implemented")

    def save_inference_results(self):
        raise NotImplementedError(".save_inference_results is not implemented")

    def load_inference_results(self):
        raise NotImplementedError(".load_inference_results is not implemented")
class VLLMInferencer(InferencerWithOffloading):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.use_vllm, "The inferencer_args.use_vllm must be True."
        super().__init__(model_args, data_args, inferencer_args)
        self.sampling_params = self.parse_to_sampling_params(inferencer_args)
    def parse_to_sampling_params(
        self,
        inference_args: InferencerArguments,
    ) -> SamplingParams:
        return SamplingParams(
            use_beam_search=inference_args.use_beam_search,
            n=inference_args.num_output_sequences,
            # small epsilon avoids an exact temperature of 0, which vllm
            # treats as greedy decoding
            temperature=inference_args.temperature + 1e-6,
            max_tokens=inference_args.max_new_tokens,
            seed=inference_args.random_seed,
            top_p=inference_args.top_p,
            top_k=inference_args.top_k,
            stop_token_ids=[self.eos_token_id] + inference_args.additional_stop_token_ids,
        )
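    # Illustrative sketch (not part of the original source) of what the mapping
    # above produces, assuming an InferencerArguments with temperature=0.7,
    # top_p=0.9, num_output_sequences=2, and max_new_tokens=128:
    #
    #   SamplingParams(n=2, temperature=0.7 + 1e-6, top_p=0.9,
    #                  max_tokens=128, stop_token_ids=[eos_token_id], ...)
    #
    # Only the fields listed in parse_to_sampling_params are forwarded; every
    # other vllm sampling knob keeps its vllm default.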
    def inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        enable_decode_inference_result: bool = True,
        release_gpu: bool = False,
        inference_args: Optional[InferencerArguments] = None,
        enable_distributed_inference: bool = False,
        **kwargs,
    ) -> List[VLLMInferenceResultWithInput]:
        """Perform inference using the provided model and dataset. Will save inference results if
        `save_results` is set to True in `inferencer_args`.

        Parameters
        ----------
        model : HFDecoderModel
            LMFlow HFDecoderModel object
        dataset : Dataset
            LMFlow Dataset object
        enable_decode_inference_result : bool, optional
            Whether to decode the generated token ids, by default True.
        release_gpu : bool, optional
            Whether to release gpu resources after inference, by default False.
        inference_args : InferencerArguments, optional
            If provided, overrides the sampling parameters built from the
            default inferencer arguments, by default None.
        enable_distributed_inference : bool, optional
            Whether to run data-parallel inference through Ray, by default False.

        Returns
        -------
        List[VLLMInferenceResultWithInput]
            Return a list of VLLMInferenceResultWithInput, where each element contains
            the input prompt and the corresponding output.

            When `enable_decode_inference_result = True`, the output would be a list of strings,
            containing sampling_params.n samples for the corresponding prompt.

            When `enable_decode_inference_result = False`, return a list of list of ints
            (token ids, no decoding after generation).
        """
        if inference_args:
            logger.warning(
                "Overriding the default inference arguments with the provided arguments in .inference()"
            )
            sampling_params = self.parse_to_sampling_params(inference_args)
        else:
            sampling_params = self.sampling_params

        sampling_params.detokenize = enable_decode_inference_result

        model_input = model.prepare_inputs_for_inference(
            dataset=dataset,
            apply_chat_template=self.inferencer_args.apply_chat_template,
            use_vllm=self.inferencer_args.use_vllm,
            enable_distributed_inference=enable_distributed_inference,
        )

        if enable_distributed_inference:
            outputs = self._distributed_inference(
                model=model,
                model_input=model_input,
                sampling_params=sampling_params,
                num_instances=kwargs.get("distributed_inference_num_instances"),
                batch_size=kwargs.get("inference_batch_size", 4),
                release_gpu=release_gpu,
            )
        else:
            outputs = self._inference(
                model=model,
                model_input=model_input,
                sampling_params=sampling_params,
                release_gpu=release_gpu,
            )

        if self.inferencer_args.save_results:
            self.save_inference_results(outputs, self.inferencer_args.results_path)

        return outputs
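    # Usage sketch (illustrative only; `model`, `dataset`, and the argument
    # dataclasses are assumed to be constructed elsewhere):
    #
    #   inferencer = VLLMInferencer(model_args, data_args, inferencer_args)
    #   results = inferencer.inference(model=model, dataset=dataset)
    #   # each element looks roughly like:
    #   # {"input": "<prompt>", "output": ["<sample 1>", "<sample 2>", ...]}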
    def _inference(
        self,
        model: HFDecoderModel,
        model_input: List[str],
        sampling_params: SamplingParams,
        release_gpu: bool = False,
    ) -> List[VLLMInferenceResultWithInput]:
        outputs = model.inference(
            inputs=model_input,
            sampling_params=sampling_params,
            release_gpu=release_gpu,
            use_vllm=True,
            vllm_gpu_memory_utilization=self.inferencer_args.vllm_gpu_memory_utilization,
            vllm_tensor_parallel_size=self.inferencer_args.vllm_tensor_parallel_size,
        )

        return outputs
    def _distributed_inference(
        self,
        model: HFDecoderModel,
        model_input: ray.data.Dataset,
        sampling_params: SamplingParams,
        num_instances: int,
        batch_size: int = 4,
        release_gpu: bool = False,
    ) -> List[VLLMInferenceResultWithInput]:
        # prepare distributed inference resources
        # from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_distributed.py
        ## strategy
        def scheduling_strategy_fn():
            # One bundle per tensor parallel worker
            pg = ray.util.placement_group(
                [{"GPU": 1, "CPU": 1}] * self.inferencer_args.vllm_tensor_parallel_size,
                strategy="STRICT_PACK",
            )
            return dict(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    pg,
                    placement_group_capture_child_tasks=True,
                )
            )

        resources_kwarg: Dict[str, Any] = {}
        if self.inferencer_args.vllm_tensor_parallel_size == 1:
            # For tensor_parallel_size == 1, we simply set num_gpus=1.
            resources_kwarg["num_gpus"] = 1
        else:
            # Otherwise, we have to set num_gpus=0 and provide
            # a function that will create a placement group for
            # each instance.
            resources_kwarg["num_gpus"] = 0
            resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

        ## predictor
        class DistributedPredictor:
            def __init__(
                self,
                model: HFDecoderModel,
                sampling_params: SamplingParams,
                vllm_gpu_memory_utilization: float,
                vllm_tensor_parallel_size: int,
                release_gpu: bool = False,
            ):
                self.model = copy.deepcopy(model)
                self.model.activate_model_for_inference(
                    use_vllm=True,
                    vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
                    vllm_tensor_parallel_size=vllm_tensor_parallel_size,
                )
                self.sampling_params = sampling_params
                self.release_gpu = release_gpu

            def __call__(self, batch: Dict[str, np.ndarray]):
                """batch: Dict[str, np.ndarray], {"item": array(['...', '...', '...', ...])}"""
                batched_inference_res = self.model.inference(
                    inputs=batch['item'],
                    sampling_params=self.sampling_params,
                    release_gpu=self.release_gpu,
                    use_vllm=True,
                )
                # this is the postprocessed output, see model.__vllm_inference
                batched_final_res = {
                    "input": [sample['input'] for sample in batched_inference_res],
                    "output": [sample['output'] for sample in batched_inference_res],
                }  # do this since we're writing to a pandas dataframe
                return batched_final_res

        # inference
        model_input_mapping = model_input.map_batches(
            DistributedPredictor,
            concurrency=num_instances,  # Set the concurrency to the number of LLM instances.
            batch_size=batch_size,
            fn_constructor_kwargs={
                "model": model,
                "sampling_params": sampling_params,
                "vllm_gpu_memory_utilization": self.inferencer_args.vllm_gpu_memory_utilization,
                "vllm_tensor_parallel_size": self.inferencer_args.vllm_tensor_parallel_size,
                "release_gpu": release_gpu,
            },
            **resources_kwarg,
        )
        df_model_output = model_input_mapping.to_pandas()  # the actual forwards are executed here
        logger.info(f"Distributed vllm inference result preview:\n{df_model_output.head(10)}")

        model_output = [
            {"input": row["input"], "output": row["output"]}
            for _, row in df_model_output.iterrows()
        ]

        return model_output
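    # Scheduling sketch (illustrative, not part of the original source): with
    # vllm_tensor_parallel_size=2 and num_instances=2, scheduling_strategy_fn
    # builds one STRICT_PACK placement group of two {"GPU": 1, "CPU": 1}
    # bundles per DistributedPredictor replica, so each vllm engine gets two
    # co-located GPUs and the map_batches job occupies four GPUs in total.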
    def save_inference_results(
        self,
        outputs: Union[List[List[str]], List[List[List[int]]]],
        save_file_path: str,
    ):
        with open(save_file_path, "w", encoding='utf-8') as f:
            json.dump(outputs, f, ensure_ascii=False, indent=4)

        logger.info(f"Inference results are saved to {save_file_path}.")
    def load_inference_results(
        self,
        results_path: str,
    ) -> Union[List[List[str]], List[List[List[int]]]]:
        with open(results_path, "r") as f:
            results = json.load(f)

        return results
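# End-to-end usage sketch (a hedged example, not from the original source; the
# argument values are placeholders and the constructor signatures are assumed
# to follow LMFlow's usual dataclass pattern):
#
#   model_args = ModelArguments(model_name_or_path="gpt2")
#   data_args = DatasetArguments(dataset_path="data/alpaca/test")
#   inferencer_args = InferencerArguments(
#       use_vllm=True,
#       save_results=True,
#       results_path="inference_results.json",
#   )
#   inferencer = VLLMInferencer(model_args, data_args, inferencer_args)
#   results = inferencer.inference(
#       model=HFDecoderModel(model_args),
#       dataset=Dataset(data_args),
#   )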
class MemorySafeVLLMInferencer(VLLMInferencer):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.save_results, "For MemorySafeVLLMInferencer, `save_results` must be True."
        super().__init__(model_args, data_args, inferencer_args)
        self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py"
    def inference(self) -> List[VLLMInferenceResultWithInput]:
        inferencer_args = make_shell_args_from_dataclass(
            dataclass_objects=[
                self.model_args,
                self.data_args,
                self.inferencer_args,
            ],
            format="shell",
        )
        cmd = "python " + str(self.inferencer_file_path) + " " + inferencer_args
        current_env = os.environ.copy()
        for var in MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE:
            current_env.pop(var, None)

        cli_res = subprocess.run(
            args=cmd,
            stdout=sys.stdout,
            stderr=sys.stdout,
            shell=True,
            preexec_fn=os.setsid,
            env=current_env,
        )
        logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}")

        if cli_res.returncode in RETURN_CODE_ERROR_BUFFER:
            # > Fatal Python error: _enter_buffered_busy: could not acquire lock
            # > for <_io.BufferedWriter name='<stdout>'> at interpreter shutdown,
            # > possibly due to daemon threads
            logger.warning(
                "^^^^^^^^^^ Please ignore the above error, as it comes from the subprocess. "
                "This may be due to a kill signal with unfinished stdout/stderr writing in the subprocess. "
            )
        elif cli_res.returncode != 0:
            raise RuntimeError(f"Error during MemorySafeVLLMInference: {cli_res}")

        outputs = self.load_inference_results(self.inferencer_args.results_path)
        logger.info("MemorySafeVLLMInference result captured.")
        return outputs
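    # Usage sketch (illustrative; argument values are placeholders). The vllm
    # engine lives entirely in the child process launched above, so its GPU
    # memory is reclaimed by the OS when the subprocess exits; that is the point
    # of the "memory-safe" variant.
    #
    #   inferencer = MemorySafeVLLMInferencer(model_args, data_args, inferencer_args)
    #   results = inferencer.inference()  # reads results back from inferencer_args.results_path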