Source code for lmflow.pipeline.utils.memory_safe_vllm_inference

#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.

# Note that this is only a workaround, since vllm
# inference engine cannot release GPU memory properly by now. Please see this github 
# [issue](https://github.com/vllm-project/vllm/issues/1908).

import logging
import sys
import os
from typing import Dict

from transformers import (
    HfArgumentParser
)

from lmflow.datasets import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.vllm_inferencer import VLLMInferencer
from lmflow.args import (
    ModelArguments, 
    DatasetArguments, 
    AutoArguments,
)
from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG


[docs] logger = logging.getLogger(__name__)
[docs] def main(): # Parses arguments pipeline_name = "vllm_inferencer" PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) parser = HfArgumentParser(( ModelArguments, DatasetArguments, PipelineArguments )) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() dataset = Dataset(data_args) model = AutoModel.get_model(model_args, tune_strategy='none') inferencer = VLLMInferencer(model_args, data_args, pipeline_args) res = inferencer.inference( model, dataset, release_gpu=False, enable_decode_inference_result=pipeline_args.enable_decode_inference_result, enable_distributed_inference=pipeline_args.enable_distributed_inference, distributed_inference_num_instances=pipeline_args.distributed_inference_num_instances, inference_batch_size=pipeline_args.vllm_inference_batch_size, ) # use this as a flag, stdout will be captured by the pipeline print(MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG)
if __name__ == "__main__": main()