import copy
from dataclasses import fields
import gc
import json
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
from tqdm import tqdm
from lmflow.models.hf_text_regression_model import HFTextRegressionModel
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.dpov2_aligner import MemorySafeDPOv2Aligner
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
from lmflow.pipeline.vllm_inferencer import MemorySafeVLLMInferencer
from lmflow.args import (
    ModelArguments,
    DatasetArguments,
    InferencerArguments,
    IterativeDPOAlignerArguments,
    DPOv2AlignerArguments,
)
from lmflow.utils.common import print_banner
logger = logging.getLogger(__name__)
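# Per-iteration workspace layout produced under ``aligner_args.output_dir``:
#
#   iteration_{k}/
#       cache/target_model_inference_result.json    raw vLLM generation cache
#       target_model_inference_result/result.json   responses in text_to_textlist format
#       reward_model_inference_result/result.json   scored responses used as DPO training data
#       model/                                      policy checkpoint after DPOv2 training
#
# Each iteration after the first reads its starting policy from the previous
# iteration's model/ directory.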
class IterativeDPOAligner:
    """Runs multiple rounds of iterative DPO alignment. Each round samples
    responses from the current policy, scores them with a reward model, and
    trains the policy with DPOv2; the resulting checkpoint seeds the next round.
    """

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DatasetArguments,
        aligner_args: IterativeDPOAlignerArguments,
        ref_model_args: ModelArguments,
        reward_model_args: ModelArguments,
        **kwargs,
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.aligner_args = aligner_args
        self.ref_model_args = ref_model_args
        self.reward_model_args = reward_model_args
        self.workspace_path = Path(self.aligner_args.output_dir)
    def align(
        self,
        dataset_list: List[Dataset]
    ):
        num_iterations = len(dataset_list)
        for iter_idx in tqdm(
            range(self.aligner_args.initial_iter_idx, num_iterations),
            desc="Iterative DPO Align",
            unit="iteration"
        ):
            if iter_idx == 0:
                # The first iteration starts from the user-provided policy model.
                target_model_args = self.model_args
            else:
                # Later iterations start from the checkpoint produced by the
                # previous iteration.
                target_model_args = copy.deepcopy(self.model_args)
                target_model_args.model_name_or_path = str(self.workspace_path/f"iteration_{iter_idx}"/"model")
            self._align_single_iteration(
                iteration_name=f"iteration_{iter_idx+1}",
                target_model_args=target_model_args,
                reward_model_args=self.reward_model_args,
                ref_model_args=self.ref_model_args,
                dataset=dataset_list[iter_idx],
            )
    def _align_single_iteration(
        self,
        iteration_name: str,
        target_model_args: ModelArguments,
        reward_model_args: ModelArguments,
        ref_model_args: ModelArguments,
        dataset: Dataset,
    ):
        if self.aligner_args.do_response_generation:
            # Stage 1: sample responses from the current policy model.
            print_banner(f'Iterative DPO {iteration_name}: Generate responses')
            model = HFDecoderModel(
                model_args=target_model_args,
                tune_strategy='none'
            )
            self._do_target_model_inference(
                model=model,
                dataset=dataset,
                output_dir=str(self.workspace_path/iteration_name),
            )
            del model
        if self.aligner_args.do_scoring:
            # Stage 2: score the sampled responses with the reward model.
            print_banner(f'Iterative DPO {iteration_name}: Reward model scoring')
            reward_model = HFTextRegressionModel(
                model_args=reward_model_args,
                tune_strategy='none',
                use_accelerator=self.aligner_args.use_accelerator,
            )
            target_model_inference_result_data_args = copy.deepcopy(dataset.data_args)
            target_model_inference_result_data_args.dataset_path = str(self.workspace_path/iteration_name/"target_model_inference_result")
            target_model_inference_result_data_args.block_size = self.aligner_args.reward_model_inference_block_size
            target_model_inference_result_dataset = Dataset(target_model_inference_result_data_args)
            self._do_reward_model_inference(
                model=reward_model,
                dataset=target_model_inference_result_dataset,
                output_dir=str(self.workspace_path/iteration_name),
            )
            del reward_model
        if self.aligner_args.do_dpo_align:
            # Stage 3: run DPOv2 training on the scored responses.
            print_banner(f'Iterative DPO {iteration_name}: DPO training')
            dpo_train_data_args = copy.deepcopy(dataset.data_args)
            dpo_train_data_args.dataset_path = str(self.workspace_path/iteration_name/"reward_model_inference_result")
            self._do_single_dpo_align(
                model_args=target_model_args,
                ref_model_args=ref_model_args,
                data_args=dpo_train_data_args,
                output_dir=str(self.workspace_path/iteration_name/"model"),
                iteration_name=iteration_name,
            )
    def _do_target_model_inference(
        self,
        model: HFDecoderModel,
        dataset: Dataset,
        output_dir: str,
    ):
        result_cache_path = str(Path(output_dir)/"cache"/"target_model_inference_result.json")
        inferencer = MemorySafeVLLMInferencer(
            model_args=model.model_args,
            data_args=dataset.data_args,
            inferencer_args=self._parse_target_model_inference_args(
                args=self.aligner_args,
                result_cache_path=result_cache_path,
            ),
        )
        res = inferencer.inference()
        # Wrap the raw generations in lmflow's text_to_textlist dataset format
        # before saving, so the reward-model stage can load them as a Dataset.
        dataset_out = {"type": "text_to_textlist", "instances": res}
        target_model_inference_result_dir = Path(output_dir)/"target_model_inference_result"
        target_model_inference_result_dir.mkdir(parents=True, exist_ok=True)
        with open(str(target_model_inference_result_dir/"result.json"), "w", encoding='utf-8') as f:
            json.dump(dataset_out, f, ensure_ascii=False, indent=4)
    def _do_reward_model_inference(
        self,
        model: HFTextRegressionModel,
        dataset: Dataset,
        output_dir: str,
    ):
        inferencer = RewardModelInferencer(
            model_args=model.model_args,
            data_args=dataset.data_args,
            inferencer_args=self._parse_reward_model_inference_args(self.aligner_args),
        )
        res = inferencer.inference(
            model=model,
            dataset=dataset,
            transform_dataset_in_place=True,
            use_vllm=False,
            enable_distributed_inference=self.aligner_args.enable_distributed_inference,
            distributed_inference_num_instances=self.aligner_args.distributed_inference_num_instances,
            inference_batch_size=self.aligner_args.reward_model_inference_batch_size,
        )
        reward_model_inference_result_dir = Path(output_dir)/"reward_model_inference_result"
        reward_model_inference_result_dir.mkdir(parents=True, exist_ok=True)
        res.save(str(reward_model_inference_result_dir/"result.json"))
    def _do_single_dpo_align(
        self,
        model_args: ModelArguments,
        ref_model_args: ModelArguments,
        data_args: DatasetArguments,
        output_dir: str,
        iteration_name: str,
    ):
        aligner = MemorySafeDPOv2Aligner(
            model_args=model_args,
            data_args=data_args,
            aligner_args=self._parse_dpo_aligner_args(
                args=self.aligner_args,
                output_dir=output_dir,
                iteration_name=iteration_name,
            ),
            ref_model_args=ref_model_args,
        )
        aligner.align()
    def _parse_target_model_inference_args(
        self,
        args: IterativeDPOAlignerArguments,
        result_cache_path: str,
    ) -> InferencerArguments:
        inferencer_args = self.__filter_args(
            mixed_args=args,
            target_cls=InferencerArguments,
        )
        inferencer_args.save_results = True
        inferencer_args.results_path = result_cache_path
        return inferencer_args
    def _parse_reward_model_inference_args(
        self,
        args: IterativeDPOAlignerArguments,
    ) -> InferencerArguments:
        inferencer_args = self.__filter_args(
            mixed_args=args,
            target_cls=InferencerArguments,
        )
        return inferencer_args
    def _parse_dpo_aligner_args(
        self,
        args: IterativeDPOAlignerArguments,
        output_dir: str,
        iteration_name: str,
    ) -> DPOv2AlignerArguments:
        aligner_args = self.__filter_args(
            mixed_args=args,
            target_cls=DPOv2AlignerArguments,
        )
        aligner_args.output_dir = output_dir
        aligner_args.run_name = f"{args.run_name}_{iteration_name}"
        return aligner_args
    def __filter_args(
        self,
        mixed_args,
        target_cls,
    ):
        # Project the mixed argument object onto ``target_cls``: keep only the
        # init fields of ``target_cls`` that are also present on ``mixed_args``.
        target_cls_fields = {f.name for f in fields(target_cls) if f.init}
        common_fields = {f: getattr(mixed_args, f) for f in target_cls_fields if hasattr(mixed_args, f)}
        return target_cls(**common_fields)
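

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one plausible way to wire up the aligner.
# The paths, the number of rounds, and reusing the first dataset's data_args
# are hypothetical choices, not lmflow defaults; the argument dataclasses
# generally accept (and may require) more fields than shown here.
#
#     model_args = ModelArguments(model_name_or_path="path/to/policy-model")
#     ref_model_args = ModelArguments(model_name_or_path="path/to/reference-model")
#     reward_model_args = ModelArguments(model_name_or_path="path/to/reward-model")
#     aligner_args = IterativeDPOAlignerArguments(output_dir="output/iterative_dpo")
#
#     # One prompt dataset per iteration; each drives one round of
#     # generate -> score -> DPO train.
#     dataset_list = [
#         Dataset(DatasetArguments(dataset_path=f"data/iter_dpo/round_{i}"))
#         for i in range(3)
#     ]
#
#     aligner = IterativeDPOAligner(
#         model_args=model_args,
#         data_args=dataset_list[0].data_args,
#         aligner_args=aligner_args,
#         ref_model_args=ref_model_args,
#         reward_model_args=reward_model_args,
#     )
#     aligner.align(dataset_list)
# ---------------------------------------------------------------------------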