import copy
import importlib.resources as pkg_resources
import logging
import os
import subprocess
import sys
from typing import Optional, List, Tuple, Dict, Union
import numpy as np
from tqdm import tqdm
import torch
from transformers import TrainingArguments
from lmflow.pipeline.utils.dpov2_trainer import DPOv2Trainer
from lmflow.pipeline.base_aligner import BaseAligner
from lmflow.args import (
ModelArguments,
DatasetArguments,
DPOv2AlignerArguments
)
from lmflow.utils.common import (
make_shell_args_from_dataclass,
add_dataclass_attr_prefix,
create_copied_dataclass
)
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.datasets.dataset import Dataset, KEY_SCORE, KEY_TYPE, KEY_INSTANCES
from lmflow.utils.constants import MEMORY_SAFE_DPOV2_ALIGN_ENV_VAR_TO_REMOVE
logger = logging.getLogger(__name__)
ReferenceModelArguments = create_copied_dataclass(
original_dataclass=ModelArguments,
field_prefix="reference_",
class_prefix="Reference"
)
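# Copy of ModelArguments whose fields carry a ``reference_`` prefix, so the reference
# (frozen) model can be configured independently of the policy model on the command line
# (see its use in MemorySafeDPOv2Aligner below).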
class DPOv2Aligner(BaseAligner):
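    """Align a policy model with DPO (v2) against a frozen reference model.

    A minimal usage sketch (the object names below are illustrative, not part of this module):

        aligner = DPOv2Aligner(
            model_args=model_args,
            data_args=data_args,
            aligner_args=aligner_args,
            ref_model_args=ref_model_args,
        )
        aligner.align(
            model=policy_model,           # HFDecoderModel
            ref_model=reference_model,    # HFDecoderModel
            train_dataset=train_dataset,  # "text_to_scored_textlist" or "paired_text_to_text"
            eval_dataset=eval_dataset,
        )
    """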
def __init__(
self,
model_args: ModelArguments,
data_args: DatasetArguments,
aligner_args: DPOv2AlignerArguments,
ref_model_args: ModelArguments,
):
        self.model_args = model_args
        self.ref_model_args = ref_model_args
        self.data_args = data_args
        self.aligner_args = aligner_args
def align(
self,
model: HFDecoderModel,
ref_model: HFDecoderModel,
train_dataset: Dataset,
eval_dataset: Dataset,
transform_dataset_in_place: bool=True,
):
        if (train_dataset.get_type() not in ["text_to_scored_textlist", "paired_text_to_text"]) or \
            (eval_dataset.get_type() not in ["text_to_scored_textlist", "paired_text_to_text"]):
            raise ValueError(
                f"Unsupported dataset type (train: {train_dataset.get_type()}, "
                f"eval: {eval_dataset.get_type()}) for DPOv2 aligner."
            )
# step 0. setting up
if self.aligner_args.gradient_checkpointing:
logger.warning(
"Setting backend_model.config.use_cache to False since using gradient checkpointing"
)
model.get_backend_model().config.use_cache = False
ref_model.get_backend_model().config.use_cache = False
# step 1. prepare datasets
if train_dataset.get_type() == "text_to_scored_textlist":
train_dataset = self.convert_to_paired_dataset(
source_dataset=train_dataset,
sampling_paired_method=self.aligner_args.sampling_paired_method,
length_penalty=self.aligner_args.length_penalty,
margin_scale=self.aligner_args.margin_scale,
use_fast=False,
)
if self.data_args.max_train_samples:
train_dataset.backend_dataset = train_dataset.backend_dataset.select(range(self.data_args.max_train_samples))
if eval_dataset.get_type() == "text_to_scored_textlist":
eval_dataset = self.convert_to_paired_dataset(
source_dataset=eval_dataset,
sampling_paired_method=self.aligner_args.sampling_paired_method,
margin_scale=self.aligner_args.margin_scale,
use_fast=False,
)
# step 2. prepare trainer
dpo_trainer = DPOv2Trainer(
model.get_backend_model(),
ref_model.get_backend_model(),
train_dataset=train_dataset.get_backend_dataset(), # tokenization is done in the trainer
eval_dataset=eval_dataset.get_backend_dataset(),
tokenizer=model.tokenizer,
args=self.__prepare_training_args(self.aligner_args),
beta=self.aligner_args.beta,
loss_type=self.aligner_args.loss_type,
max_prompt_length=self.aligner_args.max_prompt_length,
max_length=self.aligner_args.max_length,
mask_prompt=self.aligner_args.mask_prompt,
len_penalty=self.aligner_args.length_penalty,
# preprocessing_num_workers=self.data_args.preprocessing_num_workers, # will trigger TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object
)
        # step 3. train
        dpo_trainer.train()
        # step 4. save
        dpo_trainer.save_model(self.aligner_args.output_dir)
        output_dir = os.path.join(self.aligner_args.output_dir, "final_checkpoint")
        dpo_trainer.model.save_pretrained(output_dir)
# step 5. release resources
with torch.no_grad():
torch.cuda.empty_cache()
def __prepare_training_args(
self,
args: DPOv2AlignerArguments,
) -> TrainingArguments:
training_args = TrainingArguments(
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
num_train_epochs=args.num_train_epochs,
save_strategy=args.save_strategy,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_checkpointing=args.gradient_checkpointing,
learning_rate=args.learning_rate,
evaluation_strategy=args.evaluation_strategy,
eval_steps=args.eval_steps,
output_dir=args.output_dir,
lr_scheduler_type=args.lr_scheduler_type,
warmup_steps=args.warmup_steps,
optim=args.optim,
bf16=args.bf16,
report_to=args.report_to,
run_name=args.run_name,
remove_unused_columns=False, # DO NOT CHANGE THIS, may cause error https://discuss.huggingface.co/t/indexerror-invalid-key-16-is-out-of-bounds-for-size-0/14298/3
)
logger.debug(f"Actual training arguments for dpo trainer: {training_args}")
return training_args
def convert_to_paired_dataset(
self,
source_dataset: Dataset,
sampling_paired_method: str="random",
length_penalty: float=0.0,
margin_scale: float=1.0,
use_fast: bool=False,
) -> Dataset:
"""Convert a scored one to multiple (text_to_scored_textlist) to a paired dataset by rejection sampling.
"""
output_dict = {
KEY_INSTANCES: []
}
if source_dataset.get_type() in ["text_to_scored_textlist"]:
output_dict[KEY_TYPE] = "paired_text_to_text"
for sample in tqdm(source_dataset.get_backend_dataset(), desc="Converting to paired dataset"):
sample_output_dict = {}
lengths = self._calc_response_lengths(sample["output"], source_dataset.get_type())
penalized_rewards = self._calc_reward_with_length_penalty(
rewards=[content[KEY_SCORE] for content in sample["output"]],
lengths=lengths,
length_penalty=length_penalty
)
chosen_idx, rejected_idx = self.sampling_paired_idx_from_rewards(
rewards=penalized_rewards,
sampling_paired_method=sampling_paired_method,
use_fast=use_fast
)
sample_output_dict["prompt"] = sample["input"]
sample_output_dict["chosen"] = sample["output"][chosen_idx]["text"]
sample_output_dict["rejected"] = sample["output"][rejected_idx]["text"]
sample_output_dict["margin"] = (sample["output"][chosen_idx][KEY_SCORE] - sample["output"][rejected_idx][KEY_SCORE]) * margin_scale
output_dict[KEY_INSTANCES].append(sample_output_dict)
output_dataset_args = copy.deepcopy(source_dataset.data_args)
output_dataset_args.dataset_path = None
output_dataset_args.dataset_name = f"paired_{output_dataset_args.dataset_name}"
output_dataset = Dataset(output_dataset_args)
output_dataset = output_dataset.from_dict(output_dict)
return output_dataset
def _calc_response_lengths(
self,
outputs: List[Union[str, Dict[str, str]]],
dataset_type: str,
) -> List[int]:
all_lengths = []
if dataset_type == "text_to_scored_textlist":
all_lengths = [len(output["text"]) for output in outputs]
else:
raise NotImplementedError(
f"Unknown dataset type {dataset_type} when calculating the response length."
)
return all_lengths
def _calc_reward_with_length_penalty(
self,
rewards: List[float],
lengths: List[int],
length_penalty: float,
) -> List[float]:
"""When length_penalty > 0, penalize the longer sequence by subtracting
length_penalty * length from the reward. Vice versa when length_penalty < 0.
"""
assert len(rewards) == len(lengths), "The number of rewards and lengths should be the same."
return [reward - length_penalty * length for reward, length in zip(rewards, lengths)]
def sampling_paired_idx_from_rewards(
self,
rewards: List[float],
sampling_paired_method: str="random",
use_fast: bool=False,
) -> Tuple[int, int]:
"""Prepare the dataset for DPO training by rejection sampling.
We implement different strategies to select pairs, including
random: randomly select two instances
max_min: best v.s. worst
max_max: best v.s. second best
max_random: best v.s. random from the remaining
"""
if use_fast:
return self._sampling_paired_idx_from_rewards_fast(rewards, sampling_paired_method)
else:
return self._sampling_paired_idx_from_rewards(rewards, sampling_paired_method)
def _sampling_paired_idx_from_rewards(
self,
rewards: List[float],
sampling_paired_method: str="random"
) -> Tuple[int, int]:
idx_0, idx_1 = -1, -1
if sampling_paired_method == "random":
idx_0, idx_1 = np.random.choice(len(rewards), size=2, replace=False)
elif sampling_paired_method == "max_min":
idx_0, idx_1 = np.argmax(rewards), np.argmin(rewards)
elif sampling_paired_method == "max_max":
sorted_indices = np.argsort(rewards)
idx_0, idx_1 = sorted_indices[-1], sorted_indices[-2]
elif sampling_paired_method == "max_random":
idx_0 = np.argmax(rewards)
idx_1 = np.random.choice([i for i in range(len(rewards)) if i != idx_0])
else:
raise ValueError(f"Unknown sampling method: {sampling_paired_method}")
chosen_idx, rejected_idx = (idx_0, idx_1) if rewards[idx_0] > rewards[idx_1] else (idx_1, idx_0)
return chosen_idx, rejected_idx
def _sampling_paired_idx_from_rewards_fast(
self,
rewards: List[float],
sampling_paired_method: str="random"
) -> Tuple[int, int]:
idx_0, idx_1 = -1, -1
if sampling_paired_method == "random":
idx_0, idx_1 = 0, 1
elif sampling_paired_method == "max_min":
idx_0, idx_1 = np.argmax(rewards), np.argmin(rewards)
elif sampling_paired_method == "max_max":
sorted_indices = np.argsort(rewards)
idx_0, idx_1 = sorted_indices[-1], sorted_indices[-2]
elif sampling_paired_method == "max_random":
idx_0 = np.argmax(rewards)
idx_1 = 0 if idx_0 != 0 else 1
else:
raise ValueError(f"Unknown sampling method: {sampling_paired_method}")
chosen_idx, rejected_idx = (idx_0, idx_1) if rewards[idx_0] > rewards[idx_1] else (idx_1, idx_0)
return chosen_idx, rejected_idx
class MemorySafeDPOv2Aligner:
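    """Run DPOv2 alignment in a separate ``accelerate launch`` subprocess
    (``lmflow/pipeline/utils/memory_safe_dpov2_align.py``) so that the memory held
    during training is released when the subprocess exits.
    """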
def __init__(
self,
model_args: ModelArguments,
data_args: DatasetArguments,
aligner_args: DPOv2AlignerArguments,
ref_model_args: ModelArguments,
):
        self.model_args = model_args
        self.ref_model_args = ReferenceModelArguments(**add_dataclass_attr_prefix(ref_model_args, 'reference_'))
        self.data_args = data_args
        self.aligner_args = aligner_args
        self.aligner_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_dpov2_align.py"
def align(self):
aligner_args = make_shell_args_from_dataclass(
dataclass_objects=[
self.model_args,
self.data_args,
self.aligner_args,
self.ref_model_args
],
format="shell",
ignored_args_list=['accelerator_config', 'fsdp_config', '_n_gpu'],
)
cmd = (
f"accelerate launch --config_file {self.aligner_args.accelerate_config_file}"
+ " "
+ str(self.aligner_file_path)
+ " "
+ aligner_args
)
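        # Illustrative shape of the generated command (actual flags depend on the arguments):
        #   accelerate launch --config_file <accelerate_config_file> .../memory_safe_dpov2_align.py \
        #       --model_name_or_path ... --reference_model_name_or_path ... --output_dir ...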
current_env = os.environ.copy()
for var in MEMORY_SAFE_DPOV2_ALIGN_ENV_VAR_TO_REMOVE:
current_env.pop(var, None)
cli_res = subprocess.run(
args=cmd,
stdout=sys.stdout,
stderr=sys.stdout,
shell=True,
preexec_fn=os.setsid,
env=current_env,
)
logger.info(f"MemorySafeDPOv2Aligner subprocess run finished, info at finish: {cli_res}")
        if cli_res.returncode != 0:
            # stdout/stderr are streamed directly to the parent process above,
            # so cli_res does not capture them; report the CompletedProcess itself.
            raise RuntimeError(f"Error during MemorySafeDPOv2Aligner: {cli_res}")