import copy
import json
import logging
from dataclasses import fields
from pathlib import Path
from tqdm import tqdm
from lmflow.args import (
DatasetArguments,
DPOv2AlignerArguments,
InferencerArguments,
IterativeDPOAlignerArguments,
ModelArguments,
)
from lmflow.datasets.dataset import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.models.hf_text_regression_model import HFTextRegressionModel
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.pipeline.dpov2_aligner import MemorySafeDPOv2Aligner
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
from lmflow.utils.common import print_banner
from lmflow.utils.protocol import DataProto
from lmflow.utils.versioning import is_sglang_available, is_vllm_available
# Module-level logger, named after this module so log records can be
# filtered/configured per module by the application.
logger = logging.getLogger(__name__)
class IterativeDPOAligner:
    """Drive iterative DPO alignment over a list of datasets.

    Each iteration consumes one dataset and runs up to three phases, each
    gated by a flag on ``aligner_args``:

    1. ``do_response_generation`` - sample responses from the target model.
    2. ``do_scoring`` - score the sampled responses with the reward model.
    3. ``do_dpo_align`` - DPO-train the target model on the scored data.

    All artifacts of an iteration are written under
    ``<aligner_args.output_dir>/<iteration_name>/``.
    """

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        aligner_args: IterativeDPOAlignerArguments,
        ref_model_args: ModelArguments,
        reward_model_args: ModelArguments,
        **kwargs,
    ):
        """Store the argument bundles and derive the workspace directory.

        Parameters
        ----------
        model_args : ModelArguments
            Arguments of the target (policy) model used in the first iteration.
        data_args : DatasetArguments
            Dataset arguments; stored on the instance as ``self.data_args``.
        aligner_args : IterativeDPOAlignerArguments
            Iterative DPO settings; ``output_dir`` becomes the workspace root.
        ref_model_args : ModelArguments
            Arguments of the reference model passed to the DPO aligner.
        reward_model_args : ModelArguments
            Arguments of the reward model used for scoring.
        **kwargs
            Accepted for call-site compatibility; not used by this class.
        """
        self.model_args = model_args
        self.data_args = data_args
        self.aligner_args = aligner_args
        self.ref_model_args = ref_model_args
        self.reward_model_args = reward_model_args
        self.workspace_path = Path(self.aligner_args.output_dir)

    def align(self, dataset_list: list[Dataset]):
        """Run one alignment iteration per dataset in ``dataset_list``.

        Iteration ``i`` (0-based) is named ``iteration_{i + 1}``. It starts
        from ``self.model_args`` when ``i == 0``; otherwise from the model
        saved at ``<workspace>/iteration_{i}/model``, i.e. the output
        location of the previous iteration.

        Parameters
        ----------
        dataset_list : list[Dataset]
            One dataset per iteration; its length is the number of iterations.

        Raises
        ------
        ValueError
            If ``aligner_args.initial_iter_idx`` is negative.
        """
        num_iterations = len(dataset_list)
        start_idx = self.aligner_args.initial_iter_idx

        # A negative start would silently index ``dataset_list`` from the end
        # and build paths such as ``iteration_-1``; reject it up front.
        if start_idx < 0:
            raise ValueError(f"initial_iter_idx must be non-negative, got {start_idx}")
        if start_idx >= num_iterations:
            logger.warning(
                "initial_iter_idx=%s is not smaller than the number of datasets (%s); "
                "no alignment iteration will run.",
                start_idx,
                num_iterations,
            )

        for iter_idx in tqdm(
            range(start_idx, num_iterations), desc="Iterative DPO Align", unit="iteration"
        ):
            if iter_idx == 0:
                target_model_args = self.model_args
            else:
                # Deep-copy so the caller-provided model args are never mutated.
                target_model_args = copy.deepcopy(self.model_args)
                target_model_args.model_name_or_path = str(
                    self.workspace_path / f"iteration_{iter_idx}" / "model"
                )

            self._align_single_iteration(
                iteration_name=f"iteration_{iter_idx + 1}",
                target_model_args=target_model_args,
                reward_model_args=self.reward_model_args,
                ref_model_args=self.ref_model_args,
                dataset=dataset_list[iter_idx],
            )

    def _align_single_iteration(
        self,
        iteration_name: str,
        target_model_args: ModelArguments,
        reward_model_args: ModelArguments,
        ref_model_args: ModelArguments,
        dataset: Dataset,
    ):
        """Run the enabled phases (generate, score, train) of one iteration.

        Phases communicate through files under
        ``<workspace>/<iteration_name>/``: generation writes
        ``target_model_inference_result``, scoring reads that directory and
        writes ``reward_model_inference_result``, and DPO training reads the
        latter and writes ``model``. A disabled phase is skipped; the next
        phase then relies on files already present on disk.
        """
        iteration_dir = self.workspace_path / iteration_name

        if self.aligner_args.do_response_generation:
            # generate responses
            print_banner(f"Iterative DPO {iteration_name}: Generate responses")
            model = HFDecoderModel(model_args=target_model_args, do_train=False)
            self._do_target_model_inference(
                model=model,
                dataset=dataset,
                output_dir=str(iteration_dir),
            )
            # Drop the local reference before the next model is loaded.
            del model

        if self.aligner_args.do_scoring:
            # reward model scoring
            print_banner(f"Iterative DPO {iteration_name}: Reward model scoring")
            reward_model = HFTextRegressionModel(
                model_args=reward_model_args,
                do_train=False,
            )
            # Re-use the dataset's settings, pointed at the generated responses.
            target_model_inference_result_data_args = copy.deepcopy(dataset.data_args)
            target_model_inference_result_data_args.dataset_path = str(
                iteration_dir / "target_model_inference_result"
            )
            target_model_inference_result_data_args.block_size = (
                self.aligner_args.reward_model_inference_block_size
            )
            target_model_inference_result_dataset = Dataset(target_model_inference_result_data_args)
            self._do_reward_model_inference(
                model=reward_model,
                dataset=target_model_inference_result_dataset,
                output_dir=str(iteration_dir),
            )
            del reward_model

        if self.aligner_args.do_dpo_align:
            # DPO training
            print_banner(f"Iterative DPO {iteration_name}: DPO training")
            dpo_train_data_args = copy.deepcopy(dataset.data_args)
            dpo_train_data_args.dataset_path = str(iteration_dir / "reward_model_inference_result")
            self._do_single_dpo_align(
                model_args=target_model_args,
                ref_model_args=ref_model_args,
                data_args=dpo_train_data_args,
                output_dir=str(iteration_dir / "model"),
                iteration_name=iteration_name,
            )

    def _do_target_model_inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        output_dir: str,
    ):
        """Generate responses with ``model`` and save them as a JSON dataset.

        The result is written to
        ``<output_dir>/target_model_inference_result/result.json`` as a
        ``text_to_textlist`` dataset (one instance per prompt, holding the
        list of sampled outputs).
        """
        inferencer_args = self._parse_target_model_inference_args(args=self.aligner_args)
        inferencer = self._build_response_generator(
            model_args=model.model_args,
            data_args=dataset.data_args,
            inferencer_args=inferencer_args,
        )
        res = inferencer.inference(model=model, dataset=dataset, release_gpu=True)
        instances = self._dataproto_to_text_to_textlist_instances(res)

        target_model_inference_result_dir = Path(output_dir) / "target_model_inference_result"
        target_model_inference_result_dir.mkdir(parents=True, exist_ok=True)
        # Context manager guarantees the file is flushed and closed, also on error.
        with open(
            str(target_model_inference_result_dir / "result.json"), "w", encoding="utf-8"
        ) as result_file:
            json.dump(
                {"type": "text_to_textlist", "instances": instances},
                result_file,
                ensure_ascii=False,
                indent=4,
            )

    @staticmethod
    def _build_response_generator(
        model_args: ModelArguments,
        data_args: DatasetArguments,
        inferencer_args: InferencerArguments,
    ) -> BasePipeline:
        """Instantiate the inferencer selected by ``inferencer_args.inference_engine``.

        The backend modules are imported lazily so that only the selected
        engine has to be installed.

        Raises
        ------
        ImportError
            If the selected engine (``"vllm"`` or ``"sglang"``) is not installed.
        ValueError
            If ``inference_engine`` is neither ``"vllm"`` nor ``"sglang"``.
        """
        engine = inferencer_args.inference_engine

        if engine == "vllm":
            if not is_vllm_available():
                raise ImportError('vllm is not installed. Install via `pip install -e ".[vllm]"`.')
            from lmflow.pipeline.vllm_inferencer import VLLMInferencer

            return VLLMInferencer(model_args, data_args, inferencer_args)

        if engine == "sglang":
            if not is_sglang_available():
                raise ImportError('sglang is not installed. Install via `pip install -e ".[sglang]"`.')
            from lmflow.pipeline.sglang_inferencer import SGLangInferencer

            return SGLangInferencer(model_args, data_args, inferencer_args)

        raise ValueError(
            f"iterative_dpo_aligner: unsupported inference_engine={engine!r}. Use 'vllm' or 'sglang'."
        )

    @staticmethod
    def _dataproto_to_text_to_textlist_instances(res: DataProto) -> list[dict]:
        """Group flattened rollouts back into one instance per prompt.

        Parameters
        ----------
        res : DataProto
            Inference result. Reads ``meta_info["actual_n_rollouts"]`` and the
            ``"inputs"`` / ``"outputs"`` entries of ``non_tensor_batch`` (each
            must provide ``.tolist()``).

        Returns
        -------
        list[dict]
            ``{"input": <prompt>, "output": [<n_rollouts outputs>]}`` per prompt.

        Raises
        ------
        ValueError
            If ``actual_n_rollouts`` is not positive, if inputs and outputs
            differ in length, or if the number of rows is not a multiple of
            ``actual_n_rollouts``.
        """
        # VLLMInferencer flattens n samples by repeat-interleaving inputs (see
        # HFDecoderModel.prepare_inputs_for_inference); each block of
        # `actual_n_rollouts` consecutive rows shares the same prompt. Group
        # them back into one instance per prompt.
        n_rollouts = res.meta_info["actual_n_rollouts"]
        inputs_flat = res.non_tensor_batch["inputs"].tolist()
        outputs_flat = res.non_tensor_batch["outputs"].tolist()

        # Guard against a zero step (ZeroDivisionError below) or a negative one
        # (which would silently yield no instances or mis-grouped slices).
        if n_rollouts <= 0:
            raise ValueError(f"n_rollouts must be a positive integer, got {n_rollouts}")
        if len(inputs_flat) != len(outputs_flat):
            raise ValueError(
                f"Inference result has {len(inputs_flat)} inputs but {len(outputs_flat)} outputs"
            )
        if len(inputs_flat) % n_rollouts != 0:
            raise ValueError(
                f"Inference result length {len(inputs_flat)} is not a multiple of n_rollouts={n_rollouts}"
            )

        return [
            {"input": inputs_flat[i], "output": outputs_flat[i : i + n_rollouts]}
            for i in range(0, len(inputs_flat), n_rollouts)
        ]

    def _do_reward_model_inference(
        self,
        model: HFTextRegressionModel,
        dataset: Dataset,
        output_dir: str,
    ):
        """Score ``dataset`` with the reward model and save the result.

        The result object returned by the inferencer is saved to
        ``<output_dir>/reward_model_inference_result/result.json``.
        """
        inferencer = RewardModelInferencer(
            model_args=model.model_args,
            data_args=dataset.data_args,
            inferencer_args=self._parse_reward_model_inference_args(self.aligner_args),
        )
        res = inferencer.inference(
            model=model,
            dataset=dataset,
            transform_dataset_in_place=True,
            use_vllm=False,
            enable_distributed_inference=self.aligner_args.enable_distributed_inference,
            distributed_inference_num_instances=self.aligner_args.distributed_inference_num_instances,
            inference_batch_size=self.aligner_args.reward_model_inference_batch_size,
        )

        reward_model_inference_result_dir = Path(output_dir) / "reward_model_inference_result"
        reward_model_inference_result_dir.mkdir(parents=True, exist_ok=True)
        res.save(str(reward_model_inference_result_dir / "result.json"))

    def _do_single_dpo_align(
        self,
        model_args: ModelArguments,
        ref_model_args: ModelArguments,
        data_args: DatasetArguments,
        output_dir: str,
        iteration_name: str,
    ):
        """Run one DPO training pass via ``MemorySafeDPOv2Aligner``.

        The DPO arguments are derived from ``self.aligner_args`` with
        ``output_dir`` and a per-iteration ``run_name`` overridden.
        """
        aligner = MemorySafeDPOv2Aligner(
            model_args=model_args,
            data_args=data_args,
            aligner_args=self._parse_dpo_aligner_args(
                args=self.aligner_args,
                output_dir=output_dir,
                iteration_name=iteration_name,
            ),
            ref_model_args=ref_model_args,
        )
        aligner.align()

    def _parse_target_model_inference_args(
        self,
        args: IterativeDPOAlignerArguments,
    ) -> InferencerArguments:
        """Project ``args`` onto ``InferencerArguments`` for response generation."""
        return self.__filter_args(
            mixed_args=args,
            target_cls=InferencerArguments,
        )

    def _parse_reward_model_inference_args(
        self,
        args: IterativeDPOAlignerArguments,
    ) -> InferencerArguments:
        """Project ``args`` onto ``InferencerArguments`` for reward-model scoring."""
        return self.__filter_args(
            mixed_args=args,
            target_cls=InferencerArguments,
        )

    def _parse_dpo_aligner_args(
        self,
        args: IterativeDPOAlignerArguments,
        output_dir: str,
        iteration_name: str,
    ) -> DPOv2AlignerArguments:
        """Project ``args`` onto ``DPOv2AlignerArguments`` for one iteration.

        ``output_dir`` is replaced by the given directory and ``run_name`` is
        suffixed with ``_<iteration_name>`` so each iteration's run is distinct.
        """
        aligner_args = self.__filter_args(
            mixed_args=args,
            target_cls=DPOv2AlignerArguments,
        )
        aligner_args.output_dir = output_dir
        aligner_args.run_name = f"{args.run_name}_{iteration_name}"
        return aligner_args

    def __filter_args(
        self,
        mixed_args,
        target_cls,
    ):
        """Build a ``target_cls`` from the attributes of ``mixed_args`` it accepts.

        Only dataclass fields of ``target_cls`` that are ``__init__`` parameters
        and exist as attributes on ``mixed_args`` are forwarded; any other
        field keeps the default defined by ``target_cls``.
        """
        target_cls_fields = {f.name for f in fields(target_cls) if f.init}
        common_fields = {f: getattr(mixed_args, f) for f in target_cls_fields if hasattr(mixed_args, f)}
        return target_cls(**common_fields)