#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 7/4/2024 21:12
# @Author : Yu Li
# @Site :
# @File : dpo_pipeline.py
import os
from pathlib import Path
from typing import Dict, List, Optional
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from lmflow.pipeline.base_aligner import BaseAligner
from lmflow.utils.versioning import is_trl_available
if is_trl_available():
    from trl import DPOTrainer
else:
    raise ImportError("Please install the `trl` package to use the DPO aligner.")
def get_paired_dataset(
    data_root: str,
    data_dir: str,
    sanity_check: bool = False,
    cache_dir: Optional[str] = None,
    num_proc: int = 24,
) -> Dataset:
    """Load the paired preference dataset and convert it to the format DPO expects.

    The dataset is converted to a dictionary with the following structure:

        {
            'prompt': List[str],
            'chosen': List[str],
            'rejected': List[str],
        }

    Prompts are structured as follows:

        "Question: " + <prompt> + "\n\nAnswer: "
    """
    data_path = Path(data_root) / data_dir
    data_files = [
        x.absolute().as_posix()
        for x in data_path.glob("*.json")
    ]
    dataset = load_dataset(
        path=data_root,
        split="train",
        data_files=data_files,
        cache_dir=cache_dir,
    )
    original_columns = dataset.column_names

    if sanity_check:
        # Keep only a small subset for quick sanity-check runs.
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, List[str]]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],    # preferred response
            "rejected": samples["response_k"],  # dispreferred response
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )
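

# Illustrative sketch (assumed input schema, not enforced by this module): each
# *.json file under <data_root>/<data_dir> is expected to carry "question",
# "response_j" (preferred) and "response_k" (dispreferred) fields, e.g.
#
#   {"question": "What does DPO optimize?",
#    "response_j": "It directly optimizes a preference objective ...",
#    "response_k": "No idea."}
#
# which get_paired_dataset maps to
#
#   {"prompt": "Question: What does DPO optimize?\n\nAnswer: ",
#    "chosen": "It directly optimizes a preference objective ...",
#    "rejected": "No idea."}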
class DPOAligner(BaseAligner):
    def __init__(self, model_args, data_args, aligner_args):
        self.model_args = model_args
        self.data_args = data_args
        self.aligner_args = aligner_args
        self.train_dataset = None
        self.eval_dataset = None
    def _initialize_trainer(self, model, tokenizer):
        peft_config = LoraConfig(
            r=self.model_args.lora_r,
            lora_alpha=self.model_args.lora_alpha,
            lora_dropout=self.model_args.lora_dropout,
            target_modules=[
                "q_proj",
                "v_proj",
                "k_proj",
                "out_proj",
                "fc_in",
                "fc_out",
                "wte",
            ],
            bias="none",
            task_type="CAUSAL_LM",
        )
        training_args = TrainingArguments(
            per_device_train_batch_size=self.aligner_args.per_device_train_batch_size,
            per_device_eval_batch_size=self.aligner_args.per_device_eval_batch_size,
            max_steps=self.aligner_args.max_steps,
            logging_steps=self.aligner_args.logging_steps,
            save_steps=self.aligner_args.save_steps,
            gradient_accumulation_steps=self.aligner_args.gradient_accumulation_steps,
            gradient_checkpointing=self.aligner_args.gradient_checkpointing,
            learning_rate=self.aligner_args.learning_rate,
            evaluation_strategy="steps",
            eval_steps=self.aligner_args.eval_steps,
            output_dir=self.aligner_args.output_dir,
            report_to=self.aligner_args.report_to,
            lr_scheduler_type=self.aligner_args.lr_scheduler_type,
            warmup_steps=self.aligner_args.warmup_steps,
            optim=self.aligner_args.optimizer_type,
            bf16=True,
            remove_unused_columns=False,
            run_name=self.aligner_args.run_name,
            ddp_find_unused_parameters=False,
            # gradient_checkpointing_kwargs=dict(use_reentrant=self.aligner_args.gradient_checkpointing_use_reentrant),
            seed=self.aligner_args.seed,
        )
        dpo_trainer = DPOTrainer(
            model,
            ref_model=None,  # with a PEFT config, the frozen base model serves as the implicit reference
            args=training_args,
            beta=self.aligner_args.beta,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset if self.eval_dataset else None,
            tokenizer=tokenizer,
            peft_config=peft_config,
            max_prompt_length=self.aligner_args.max_prompt_length,
            max_length=self.aligner_args.max_length,
        )
        return dpo_trainer
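
    # Note on the DPOTrainer call above: passing beta / max_prompt_length /
    # max_length and tokenizer directly to DPOTrainer matches the older trl API
    # assumed here; newer trl releases move these settings into a DPOConfig
    # object passed as `args` (an assumption about library versions, not
    # something this module checks).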
    def _load_dataset(self):
        # load training set
        self.train_dataset = get_paired_dataset(
            data_root=self.data_args.dataset_path,
            data_dir="train",
            sanity_check=self.aligner_args.sanity_check,
        )
        # Heuristic filter on character length to drop overly long pairs
        # (note: DPOTrainer's max_length is measured in tokens).
        self.train_dataset = self.train_dataset.filter(
            lambda x: len(x["prompt"]) + len(x["chosen"]) <= self.aligner_args.max_length
            and len(x["prompt"]) + len(x["rejected"]) <= self.aligner_args.max_length
        )
        # load evaluation set
        if self.aligner_args.eval_dataset_path:
            self.eval_dataset = get_paired_dataset(
                data_root=self.aligner_args.eval_dataset_path,
                data_dir="test",
                sanity_check=True,
            )
            self.eval_dataset = self.eval_dataset.filter(
                lambda x: len(x["prompt"]) + len(x["chosen"]) <= self.aligner_args.max_length
                and len(x["prompt"]) + len(x["rejected"]) <= self.aligner_args.max_length
            )
    def align(self, model, dataset, reward_model):
        # DPO optimizes directly on preference pairs, so `dataset` and
        # `reward_model` are unused here; paired data is loaded via _load_dataset().
        tokenizer = model.get_tokenizer()
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

        self._load_dataset()

        # Unwrap the LMFlow model wrapper and train on the underlying HF backend model.
        wrapped_model = model
        model = model.get_backend_model()

        dpo_trainer = self._initialize_trainer(model, tokenizer)
        dpo_trainer.train()
        dpo_trainer.save_model(self.aligner_args.output_dir)

        # save the final (LoRA) checkpoint
        output_dir = os.path.join(self.aligner_args.output_dir, "final_checkpoint")
        dpo_trainer.model.save_pretrained(output_dir)
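

# Minimal wiring sketch (illustrative only; the argument dataclasses and model
# class below are assumptions about the surrounding LMFlow codebase, not a
# definitive entry point):
#
#   from transformers import HfArgumentParser
#   from lmflow.args import ModelArguments, DatasetArguments, DPOAlignerArguments  # assumed names
#   from lmflow.models.hf_decoder_model import HFDecoderModel                      # assumed class
#
#   model_args, data_args, aligner_args = HfArgumentParser(
#       (ModelArguments, DatasetArguments, DPOAlignerArguments)
#   ).parse_args_into_dataclasses()
#
#   model = HFDecoderModel(model_args)
#   aligner = DPOAligner(model_args, data_args, aligner_args)
#   # DPO needs no separate reward model; preference pairs come from _load_dataset().
#   aligner.align(model=model, dataset=None, reward_model=None)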