Source code for lmflow.pipeline.sglang_inferencer
#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import json
import logging
from typing import Optional, Union
from transformers import AutoTokenizer
from lmflow.args import (
    DatasetArguments,
    InferencerArguments,
    ModelArguments,
)
from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.versioning import is_sglang_available
from lmflow.utils.protocol import DataProto
logger = logging.getLogger(__name__)
if is_sglang_available():
    pass
else:
    raise ImportError("SGLang is not available, please install sglang using `pip install -e .[sglang]`.")

class SGLangInferencer(BasePipeline):
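    """Inferencer pipeline that runs generation through the SGLang inference engine.

    Sampling parameters are derived from ``InferencerArguments``, and results are
    returned (and optionally saved to disk) as a ``DataProto``.
    """
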
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.inference_engine == "sglang"
        self.model_args = model_args
        self.data_args = data_args
        self.inferencer_args = inferencer_args
        self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id
        self.sampling_params = self._parse_args_to_sampling_params(inferencer_args)

    def _parse_args_to_sampling_params(
        self,
        inference_args: InferencerArguments,
    ) -> dict:
        if inference_args.use_beam_search:
            logger.warning("`use_beam_search` is ignored, as SGLang does not currently support it.")
        sampling_params = {
            "n": inference_args.num_output_sequences,
            # offset by a tiny epsilon so the temperature passed to SGLang is strictly positive
            "temperature": inference_args.temperature + 1e-6,
            "max_new_tokens": inference_args.max_new_tokens,
            "sampling_seed": inference_args.random_seed,
            "top_p": inference_args.top_p,
            "top_k": inference_args.top_k,
            "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids,
        }
        return sampling_params

    def inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        release_gpu: bool = False,
        inference_args: Optional[InferencerArguments] = None,
    ) -> DataProto:
        if inference_args:
            logger.warning("Overriding the default inference arguments with the arguments provided to .inference()")
            sampling_params = self._parse_args_to_sampling_params(inference_args)
        else:
            sampling_params = self.sampling_params
        # TODO: we need an lmflow data sample protocol for a better programming experience, data tracking, etc.
        model_input = model.prepare_inputs_for_inference(
            dataset=dataset,
            apply_chat_template=self.inferencer_args.apply_chat_template,
            inference_engine="sglang",
            sampling_params=sampling_params,
        )
        outputs = model.inference(
            inputs=model_input,
            return_logprob=self.inferencer_args.return_logprob,
            release_gpu=release_gpu,
            inference_engine="sglang",
            gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization,
            tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size,
            enable_deterministic_inference=self.inferencer_args.enable_deterministic_inference,
            attention_backend=self.inferencer_args.attention_backend,
        )
        if self.inferencer_args.save_inference_results:
            self.save_inference_results(outputs, self.inferencer_args.inference_results_path)
        return outputs

    def save_inference_results(
        self,
        outputs: DataProto,
        inference_results_path: str,
    ):
        if not inference_results_path.endswith(".pkl"):
            logger.warning(f"The inference results path must be a pickle file. Changing the path to {inference_results_path}.pkl")
            inference_results_path = inference_results_path + ".pkl"
        outputs.save_to_disk(inference_results_path)
        logger.info(f"Inference results are saved to {inference_results_path}.")

    def load_inference_results(
        self,
        inference_results_path: str,
    ) -> DataProto:
        return DataProto.load_from_disk(inference_results_path)
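
A minimal usage sketch of the pipeline above, not taken from the repository: the argument field names (`model_name_or_path`, `dataset_path`, `inference_engine`) and the direct `HFDecoderModel(model_args)` / `Dataset(data_args)` constructions are assumptions inferred from the attributes this class reads, so adjust them to your local lmflow version.

# Hedged usage sketch: the keyword fields and constructors marked "assumed"
# below are inferred from what SGLangInferencer accesses, not a verified example.
from lmflow.args import DatasetArguments, InferencerArguments, ModelArguments
from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.sglang_inferencer import SGLangInferencer

model_args = ModelArguments(model_name_or_path="path/to/your/model")  # assumed field
data_args = DatasetArguments(dataset_path="data/my_prompts")          # assumed field
inferencer_args = InferencerArguments(inference_engine="sglang")      # assumed field

model = HFDecoderModel(model_args)  # assumed constructor
dataset = Dataset(data_args)        # assumed constructor

inferencer = SGLangInferencer(
    model_args=model_args,
    data_args=data_args,
    inferencer_args=inferencer_args,
)
outputs = inferencer.inference(model=model, dataset=dataset)  # returns a DataProto

# Round-trip the results; save_inference_results appends ".pkl" when missing.
inferencer.save_inference_results(outputs, "results/sglang_outputs.pkl")
reloaded = inferencer.load_inference_results("results/sglang_outputs.pkl")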