Source code for lmflow.pipeline.utils.memory_safe_dpov2_align

#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
import os
import sys
import copy

from transformers import (
    HfArgumentParser
)

from lmflow.datasets import Dataset
from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.pipeline.dpov2_aligner import DPOv2Aligner
from lmflow.args import (
    ModelArguments, 
    DatasetArguments, 
    DPOv2AlignerArguments,
)
from lmflow.utils.common import remove_dataclass_attr_prefix, create_copied_dataclass


# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
# A copy of the ModelArguments dataclass whose field names get a "reference_"
# prefix (and whose class name gets a "Reference" prefix). It is passed to the
# same HfArgumentParser as ModelArguments in main(), presumably so the
# reference model's options do not collide with the target model's. main()
# removes the prefix again via remove_dataclass_attr_prefix.
# NOTE: this name is bound to a dataclass *type* (it is used as a class in the
# HfArgumentParser tuple), not to a ModelArguments instance, so it is
# annotated as ``type`` rather than ``ModelArguments``.
ReferenceModelArguments: type = create_copied_dataclass(
    original_dataclass=ModelArguments,
    field_prefix="reference_",
    class_prefix="Reference",
)
def main():
    """Run DPOv2 alignment of a target model against a reference model.

    Steps:
      1. Parse ``ModelArguments`` (target model), ``ReferenceModelArguments``,
         ``DatasetArguments`` and ``DPOv2AlignerArguments`` with
         ``HfArgumentParser.parse_args_into_dataclasses``.
      2. Rebuild the ``reference_``-prefixed reference arguments as a plain
         ``ModelArguments`` instance.
      3. Load the target and reference models as ``HFDecoderModel`` objects.
      4. Build the training ``Dataset``, plus an evaluation set that is a deep
         copy of 100 examples sampled from it (seeded by the aligner
         arguments' ``random_seed``).
      5. Construct a ``DPOv2Aligner`` and call ``align`` with both models and
         both datasets.
    """
    cli_parser = HfArgumentParser((
        ModelArguments,
        ReferenceModelArguments,
        DatasetArguments,
        DPOv2AlignerArguments,
    ))
    (
        policy_args,
        prefixed_ref_args,
        dataset_args,
        dpo_args,
    ) = cli_parser.parse_args_into_dataclasses()

    # The reference model's fields carry a "reference_" prefix (see
    # ReferenceModelArguments). Strip it so an ordinary ModelArguments can be
    # constructed from them.
    reference_args = ModelArguments(
        **remove_dataclass_attr_prefix(prefixed_ref_args, "reference_")
    )

    # Load the model being trained first, then the reference model.
    policy_model = HFDecoderModel(policy_args)
    reference_model = HFDecoderModel(reference_args)

    training_data = Dataset(dataset_args)
    # NOTE(review): the evaluation set is sampled from the training data
    # itself, so it is not a held-out split.
    evaluation_data = copy.deepcopy(
        training_data.sample(n=100, seed=dpo_args.random_seed)
    )

    dpo_aligner = DPOv2Aligner(
        model_args=policy_args,
        data_args=training_data.data_args,
        aligner_args=dpo_args,
        ref_model_args=reference_model.model_args,
    )
    dpo_aligner.align(
        model=policy_model,
        ref_model=reference_model,
        train_dataset=training_data,
        eval_dataset=evaluation_data,
    )
if __name__ == "__main__":
    # Only run the alignment when this file is executed as a script.
    main()