Source code for lmflow.pipeline.sglang_inferencer

#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import json
import logging
from typing import Optional, Union

from transformers import AutoTokenizer

from lmflow.args import (
    DatasetArguments,
    InferencerArguments,
    ModelArguments,
)
from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.versioning import is_sglang_available

logger = logging.getLogger(__name__)

if not is_sglang_available():
    raise ImportError("SGLang is not available, please install sglang using `pip install -e .[sglang]`.")


class SGLangInferencer(BasePipeline):
    """Inference pipeline backed by the SGLang engine."""

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.inference_engine == "sglang"

        self.model_args = model_args
        self.data_args = data_args
        self.inferencer_args = inferencer_args
        self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id
        self.sampling_params = self._parse_args_to_sampling_params(inferencer_args)

    def _parse_args_to_sampling_params(
        self,
        inference_args: InferencerArguments,
    ) -> dict:
        if inference_args.use_beam_search:
            logger.warning("`use_beam_search` is ignored, as SGLang does not currently support beam search.")

        sampling_params = {
            "n": inference_args.num_output_sequences,
            "temperature": inference_args.temperature + 1e-6,  # nudge away from exactly 0
            "max_new_tokens": inference_args.max_new_tokens,
            "sampling_seed": inference_args.random_seed,
            "top_p": inference_args.top_p,
            "top_k": inference_args.top_k,
            "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids,
        }

        return sampling_params
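
    # Example of the resulting dictionary; the argument values below are illustrative,
    # not defaults guaranteed by InferencerArguments. With temperature=0.7,
    # num_output_sequences=2, max_new_tokens=128, random_seed=42, top_p=0.9, top_k=50,
    # no additional stop tokens, and an EOS token id of 2, the method returns:
    #
    #   {
    #       "n": 2,
    #       "temperature": 0.700001,
    #       "max_new_tokens": 128,
    #       "sampling_seed": 42,
    #       "top_p": 0.9,
    #       "top_k": 50,
    #       "stop_token_ids": [2],
    #   }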

    def inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        release_gpu: bool = False,
        inference_args: Optional[InferencerArguments] = None,
    ):
        if inference_args:
            logger.warning("Overriding the default inference arguments with the arguments provided to .inference()")
            sampling_params = self._parse_args_to_sampling_params(inference_args)
        else:
            sampling_params = self.sampling_params

        # TODO: we need an lmflow data sample protocol for a better programming experience, data tracking, etc.
        model_input = model.prepare_inputs_for_inference(
            dataset=dataset,
            apply_chat_template=self.inferencer_args.apply_chat_template,
            inference_engine="sglang",
        )

        # Handle n > 1: duplicate each sample n times so that every generated sequence
        # maps back to exactly one input (no one-to-many mapping).
        model_input = [sample for sample in model_input for _ in range(sampling_params["n"])]

        outputs = model.inference(
            inputs=model_input,
            sampling_params={**sampling_params, "n": 1},
            return_logprob=self.inferencer_args.return_logprob,
            release_gpu=release_gpu,
            inference_engine="sglang",
            gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization,
            tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size,
            enable_deterministic_inference=self.inferencer_args.enable_deterministic_inference,
            attention_backend=self.inferencer_args.attention_backend,
        )

        if self.inferencer_args.save_results:
            self.save_inference_results(outputs, self.inferencer_args.results_path)

        return outputs
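
    # Sketch of the n > 1 handling in `inference` (prompt values are hypothetical):
    #
    #   >>> prompts = ["p0", "p1"]
    #   >>> n = 3
    #   >>> [p for p in prompts for _ in range(n)]
    #   ['p0', 'p0', 'p0', 'p1', 'p1', 'p1']
    #
    # Each duplicated entry is then generated with n=1, so outputs[0:3] belong to the
    # first prompt and outputs[3:6] to the second, instead of one input mapping to a
    # list of n outputs.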

    def save_inference_results(
        self,
        outputs: Union[list[list[str]], list[list[list[int]]]],
        save_file_path: str,
    ):
        with open(save_file_path, "w", encoding="utf-8") as f:
            json.dump(outputs, f, ensure_ascii=False, indent=4)

        logger.info(f"Inference results are saved to {save_file_path}.")

    def load_inference_results(
        self,
        results_path: str,
    ) -> Union[list[list[str]], list[list[list[int]]]]:
        with open(results_path) as f:
            results = json.load(f)

        return results
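

# Usage sketch (not part of the original module): a minimal, hypothetical example of
# driving SGLangInferencer end to end. The model name, dataset path, tune_strategy, and
# the sampling values below are illustrative assumptions, not values defined in this file.
if __name__ == "__main__":
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model
    data_args = DatasetArguments(dataset_path="data/alpaca/test")  # hypothetical dataset path
    inferencer_args = InferencerArguments(
        inference_engine="sglang",
        temperature=0.7,
        max_new_tokens=128,
    )

    model = HFDecoderModel(model_args, tune_strategy="none")  # assumed: load weights without a tuning strategy
    dataset = Dataset(data_args)

    inferencer = SGLangInferencer(model_args, data_args, inferencer_args)
    results = inferencer.inference(model=model, dataset=dataset)
    print(results[:1])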