Source code for lmflow.pipeline.base_aligner

#!/usr/bin/env python
# coding=utf-8
""" BaseAligner: a subclass of BasePipeline.
"""

from lmflow.pipeline.base_pipeline import BasePipeline


class BaseAligner(BasePipeline):
    """A subclass of BasePipeline which is alignable.

    This base class only fixes the method names and signatures:
    ``align`` raises ``NotImplementedError`` and ``_check_if_alignable``
    is still a placeholder that performs no checks.
    """

    def __init__(self, *args, **kwargs):
        """Accept any positional/keyword arguments and ignore them.

        No instance state is set up here.
        """
        # NOTE(review): BasePipeline.__init__ is not invoked; its signature
        # is not visible from this module -- confirm that is intended.
        pass

    def _check_if_alignable(self, model, dataset, reward_model):
        """Placeholder for a model/dataset/reward-model compatibility check.

        Currently performs no validation: none of the arguments are
        inspected and the method implicitly returns ``None``.

        Parameters
        ----------
        model
            Model that would be aligned (currently unused).
        dataset
            Dataset that would be used for alignment (currently unused).
        reward_model
            Reward model for the alignment (currently unused).
        """
        # TODO: check if the model is alignable and dataset is compatible
        # TODO: add reward_model
        pass

    def align(self, model, dataset, reward_model):
        """Align ``model`` using ``dataset`` and ``reward_model``.

        Not implemented in this base class.

        Parameters
        ----------
        model
            Model to align.
        dataset
            Dataset to use for alignment.
        reward_model
            Reward model to use for alignment.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError(".align is not implemented")