#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import copy
import importlib.resources as pkg_resources
import json
import logging
import os
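# Use the "spawn" multiprocessing start method for vLLM worker processes (set here, before vLLM is imported).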
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
import subprocess
import sys
from typing import Any, Optional, Union
import numpy as np
from transformers import AutoTokenizer
from lmflow.args import (
    DatasetArguments,
    InferencerArguments,
    ModelArguments,
)
from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.common import make_shell_args_from_dataclass
from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE, RETURN_CODE_ERROR_BUFFER
from lmflow.utils.data_utils import VLLMInferenceResultWithInput
from lmflow.utils.versioning import is_ray_available, is_vllm_available
logger = logging.getLogger(__name__) 
if is_vllm_available():
    from vllm import SamplingParams
else:
    raise ImportError("VLLM is not available, please install vllm.")
if is_ray_available():
    import ray
    import ray.data
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
    logger.warning("Ray is not available, distributed vllm inference will not be supported.")
class InferencerWithOffloading(BasePipeline):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        self.model_args = model_args 
        self.data_args = data_args 
        self.inferencer_args = inferencer_args 
        self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id 
    def inference(self):
        raise NotImplementedError(".inference is not implemented") 
    def save_inference_results(self):
        raise NotImplementedError(".save_inference_results is not implemented") 
    def load_inference_results(self):
        raise NotImplementedError(".load_inference_results is not implemented") 
 
class VLLMInferencer(InferencerWithOffloading):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.use_vllm, "inferencer_args.use_vllm must be True when using VLLMInferencer."
        super().__init__(model_args, data_args, inferencer_args)
        self.sampling_params = self.parse_to_sampling_params(inferencer_args) 
    def parse_to_sampling_params(
        self,
        inference_args: InferencerArguments,
    ) -> SamplingParams:
        return SamplingParams(
            use_beam_search=inference_args.use_beam_search,
            n=inference_args.num_output_sequences,
            temperature=inference_args.temperature + 1e-6,
            max_tokens=inference_args.max_new_tokens,
            seed=inference_args.random_seed,
            top_p=inference_args.top_p,
            top_k=inference_args.top_k,
            stop_token_ids=[self.eos_token_id] + inference_args.additional_stop_token_ids,
        ) 
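    # A minimal sketch of the SamplingParams built by parse_to_sampling_params above,
    # assuming hypothetical inferencer args (temperature=0.0, num_output_sequences=2,
    # max_new_tokens=64, random_seed=42, top_p=0.9, top_k=-1, no additional stop
    # tokens). The +1e-6 presumably keeps the temperature strictly positive:
    #
    #     SamplingParams(
    #         use_beam_search=False,
    #         n=2,
    #         temperature=1e-6,
    #         max_tokens=64,
    #         seed=42,
    #         top_p=0.9,
    #         top_k=-1,
    #         stop_token_ids=[eos_token_id],
    #     )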
    def inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        enable_decode_inference_result: bool = True,
        release_gpu: bool = False,
        inference_args: Optional[InferencerArguments] = None,
        enable_distributed_inference: bool = False,
        **kwargs,
    ) -> list[VLLMInferenceResultWithInput]:
        """Perform inference using the provided model and dataset. Will save inference results if
        `save_results` is set to True in `inferencer_args`.
        Parameters
        ----------
        model : HFDecoderModel
            LMFlow HFDecoderModel object
        dataset : Dataset
            LMFlow Dataset object
        apply_chat_template : bool, optional
            Whether to apply chat template to the input, by default True.
        enable_decode_inference_result : bool, optional
            Whether to decode after generation, by default False.
        release_gpu : bool, optional
            Whether to release gpu resources, by default False.
        inference_args : InferencerArguments, optional
            by default None
        Returns
        -------
        list[VLLMInferenceResultWithInput]
            Return a list of VLLMInferenceResultWithInput, where each
            element contains the input prompt and the corresponding output.
            When `enable_decode_inference_result = True`, the output would be a list of strings,
            contains sampling_params.n samples for the corresponding prompt.
            When `enable_decode_inference_result = False`, return a list of list of ints
            (token ids, no decoding after generation).
        """
        if inference_args:
            logger.warning("Overriding the default inference arguments with the provided arguments in .inference()")
            sampling_params = self.parse_to_sampling_params(inference_args)
        else:
            sampling_params = self.sampling_params
        sampling_params.detokenize = enable_decode_inference_result
        model_input = model.prepare_inputs_for_inference(
            dataset=dataset,
            apply_chat_template=self.inferencer_args.apply_chat_template,
            use_vllm=self.inferencer_args.use_vllm,
            enable_distributed_inference=enable_distributed_inference,
        )
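        # model_input is a list[str] of prompts for single-node inference, or a
        # ray.data.Dataset when enable_distributed_inference is True (matching the
        # signatures of _inference and _distributed_inference below).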
        if enable_distributed_inference:
            outputs = self._distributed_inference(
                model=model,
                model_input=model_input,
                sampling_params=sampling_params,
                num_instances=kwargs.get("distributed_inference_num_instances"),
                batch_size=kwargs.get("inference_batch_size", 4),
                release_gpu=release_gpu,
            )
        else:
            outputs = self._inference(
                model=model,
                model_input=model_input,
                sampling_params=sampling_params,
                release_gpu=release_gpu,
            )
        if self.inferencer_args.save_results:
            self.save_inference_results(outputs, self.inferencer_args.results_path)
        return outputs 
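    # A hypothetical usage sketch of VLLMInferencer.inference (the model/dataset objects
    # and argument values are placeholders, not defined in this module):
    #
    #     inferencer = VLLMInferencer(model_args, data_args, inferencer_args)
    #     results = inferencer.inference(
    #         model=hf_decoder_model,
    #         dataset=lmflow_dataset,
    #         release_gpu=True,
    #         enable_distributed_inference=True,
    #         distributed_inference_num_instances=2,  # consumed via **kwargs
    #         inference_batch_size=8,                 # consumed via **kwargs
    #     )
    #     # results: [{"input": <prompt>, "output": [<n samples>]}, ...]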
    def _inference(
        self,
        model: HFDecoderModel,
        model_input: list[str],
        sampling_params: SamplingParams,
        release_gpu: bool = False,
    ) -> list[VLLMInferenceResultWithInput]:
        outputs = model.inference(
            inputs=model_input,
            sampling_params=sampling_params,
            release_gpu=release_gpu,
            use_vllm=True,
            vllm_gpu_memory_utilization=self.inferencer_args.vllm_gpu_memory_utilization,
            vllm_tensor_parallel_size=self.inferencer_args.vllm_tensor_parallel_size,
        )
        return outputs 
    def _distributed_inference(
        self,
        model: HFDecoderModel,
        model_input: ray.data.Dataset,
        sampling_params: SamplingParams,
        num_instances: int,
        batch_size: int = 4,
        release_gpu: bool = False,
    ) -> list[VLLMInferenceResultWithInput]:
        # prepare distributed inference resources
        # from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_distributed.py
        ## strategy
        def scheduling_strategy_fn():
            # One bundle per tensor parallel worker
            pg = ray.util.placement_group(
                [{"GPU": 1, "CPU": 1}] * self.inferencer_args.vllm_tensor_parallel_size,
                strategy="STRICT_PACK",
            )
            return dict(
                scheduling_strategy=PlacementGroupSchedulingStrategy(pg, placement_group_capture_child_tasks=True)
            )
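        # With vllm_tensor_parallel_size == 2, for example, the placement group requests
        # two bundles, [{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}], packed onto the same
        # node ("STRICT_PACK") so a single vLLM engine can span both GPUs.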
        resources_kwarg: dict[str, Any] = {}
        if self.inferencer_args.vllm_tensor_parallel_size == 1:
            # For tensor_parallel_size == 1, we simply set num_gpus=1.
            resources_kwarg["num_gpus"] = 1
        else:
            # Otherwise, we have to set num_gpus=0 and provide
            # a function that will create a placement group for
            # each instance.
            resources_kwarg["num_gpus"] = 0
            resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
        ## predictor
        class DistributedPredictor:
            def __init__(
                self,
                model: HFDecoderModel,
                sampling_params: SamplingParams,
                vllm_gpu_memory_utilization: float,
                vllm_tensor_parallel_size: int,
                release_gpu: bool = False,
            ):
                self.model = copy.deepcopy(model)
                self.model.activate_model_for_inference(
                    use_vllm=True,
                    vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
                    vllm_tensor_parallel_size=vllm_tensor_parallel_size,
                )
                self.sampling_params = sampling_params
                self.release_gpu = release_gpu
            def __call__(self, batch: dict[str, np.ndarray]):
                """batch: dict[str, np.ndarray], {"item": array(['...', '...', '...', ...])}"""
                batched_inference_res = self.model.inference(
                    inputs=batch["item"],
                    sampling_params=self.sampling_params,
                    release_gpu=self.release_gpu,
                    use_vllm=True,
                )  # this is the postprocessed output, see model.__vllm_inference
                batched_final_res = {
                    "input": [sample["input"] for sample in batched_inference_res],
                    "output": [sample["output"] for sample in batched_inference_res],
                }  # do this since we're writing to a pandas dataframe
                return batched_final_res
        # inference
        model_input_mapping = model_input.map_batches(
            DistributedPredictor,
            concurrency=num_instances,  # Set the concurrency to the number of LLM instances.
            batch_size=batch_size,
            fn_constructor_kwargs={
                "model": model,
                "sampling_params": sampling_params,
                "vllm_gpu_memory_utilization": self.inferencer_args.vllm_gpu_memory_utilization,
                "vllm_tensor_parallel_size": self.inferencer_args.vllm_tensor_parallel_size,
                "release_gpu": release_gpu,
            },
            **resources_kwarg,
        )
        df_model_output = model_input_mapping.to_pandas()  # the actual forwards are executed here
        logger.info(f"Distributed vllm inference result preview:\n{df_model_output.head(10)}")
        model_output = [{"input": row["input"], "output": row["output"]} for _, row in df_model_output.iterrows()]
        return model_output 
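    # A minimal round-trip sketch for _distributed_inference, assuming `prompts` is a list
    # of already-templated strings (in this pipeline the Ray dataset actually comes from
    # model.prepare_inputs_for_inference):
    #
    #     ds = ray.data.from_items(prompts)  # rows look like {"item": <prompt>}
    #     outputs = self._distributed_inference(
    #         model=model,
    #         model_input=ds,
    #         sampling_params=sampling_params,
    #         num_instances=2,
    #         batch_size=4,
    #     )
    #     outputs[0]  # {"input": <prompt>, "output": [<samples>]}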
    def save_inference_results(
        self,
        outputs: Union[list[list[str]], list[list[list[int]]]],
        save_file_path: str,
    ):
        with open(save_file_path, "w", encoding="utf-8") as f:
            json.dump(outputs, f, ensure_ascii=False, indent=4)
        logger.info(f"Inference results are saved to {save_file_path}.") 
    def load_inference_results(
        self,
        results_path: str,
    ) -> Union[list[list[str]], list[list[list[int]]]]:
        with open(results_path) as f:
            results = json.load(f)
        return results 
 
class MemorySafeVLLMInferencer(VLLMInferencer):
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ):
        assert inferencer_args.save_results, "For MemorySafeVLLMInferencer, `save_results` must be True."
        super().__init__(model_args, data_args, inferencer_args)
        self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py" 
    def inference(self) -> list[VLLMInferenceResultWithInput]:
        inferencer_args = make_shell_args_from_dataclass(
            dataclass_objects=[
                self.model_args,
                self.data_args,
                self.inferencer_args,
            ],
            format="shell",
        )
        cmd = "python " + str(self.inferencer_file_path) + " " + inferencer_args
        current_env = os.environ.copy()
        for var in MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE:
            current_env.pop(var, None)
        cli_res = subprocess.run(
            args=cmd,
            stdout=sys.stdout,
            stderr=sys.stdout,
            shell=True,
            preexec_fn=os.setsid,
            env=current_env,
        )
        logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}")
        if cli_res.returncode in RETURN_CODE_ERROR_BUFFER:
            # > Fatal Python error: _enter_buffered_busy: could not acquire lock for
            # > <_io.BufferedWriter name='<stdout>'> at interpreter shutdown, possibly
            # > due to daemon threads
            logger.warning(
                "^^^^^^^^^^ Please ignore the above error, as it comes from the subprocess. "
                "This may be due to a kill signal arriving while the subprocess still had "
                "unfinished stdout/stderr writes."
            )
        elif cli_res.returncode != 0:
            raise RuntimeError(f"Error during MemorySafeVLLMInference: {cli_res}")
        outputs = self.load_inference_results(self.inferencer_args.results_path)
        logger.info("MemorySafeVLLMInference result captured.")
        return outputs
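# A hypothetical end-to-end sketch (the argument objects are placeholders; in LMFlow they
# are typically built with HfArgumentParser over ModelArguments, DatasetArguments, and
# InferencerArguments):
#
#     inferencer = MemorySafeVLLMInferencer(model_args, data_args, inferencer_args)
#     results = inferencer.inference()  # runs memory_safe_vllm_inference.py in a
#                                       # subprocess, then reloads results_path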