Source code for lmflow.pipeline.base_aligner
#!/usr/bin/env python
"""BaseAligner: a subclass of BasePipeline."""
from abc import abstractmethod
from lmflow.pipeline.base_pipeline import BasePipeline
[docs]
class BaseAligner(BasePipeline):
    """Base class for pipelines that are alignable.

    Concrete aligners derive from this class and must provide their own
    :meth:`align`; the version defined here only raises.
    """

    def __init__(self, *args, **kwargs):
        """Accept any positional/keyword arguments and ignore them.

        Nothing is stored on the instance, and ``BasePipeline.__init__`` is
        not called from here.
        """

    def _check_if_alignable(self, model, dataset, reward_model):
        """Placeholder compatibility check; currently validates nothing.

        TODO: check if the model is alignable and the dataset is compatible.
        TODO: take ``reward_model`` into account.
        """

    # NOTE(review): abstractness is only enforced at instantiation time if
    # BasePipeline's metaclass is ABCMeta -- confirm in base_pipeline.
    @abstractmethod
    def align(self, model, dataset, reward_model):
        """Alignment entry point that every subclass must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(".align is not implemented")