# Source code for lmflow.models.text_regression_model

#!/usr/bin/env python
# coding=utf-8
"""
A model that maps "text_only" data to float values.
"""

from lmflow.models.regression_model import RegressionModel
from lmflow.datasets.dataset import Dataset


class TextRegressionModel(RegressionModel):
    r"""
    A regression model that maps "text_only" data to float values by
    delegating to a user-registered inference function.

    Parameters
    ------------
    model_args :
        Model arguments such as model name, path, revision, etc.

    args : Optional.
        Positional arguments.

    kwargs : Optional.
        Keyword arguments.
    """

    def __init__(
        self,
        model_args,
        *args,
        **kwargs
    ):
        """
        Initializes a TextRegressionModel instance.

        :param model_args: model arguments such as model name, path,
            revision, etc. Currently neither stored nor used by this class.
        """
        # NOTE(review): super().__init__() is not called here; confirm that
        # RegressionModel requires no initialization of its own.
        # No inference function is available until
        # ``register_inference_function`` is called.
        self.inference_func = None

    def register_inference_function(self, inference_func):
        """
        Registers the function used by :meth:`inference`.

        :param inference_func: callable that is invoked with the input
            dataset and whose return value becomes the regression result.
            Replaces any previously registered function.
        """
        self.inference_func = inference_func

    def inference(self, inputs: Dataset):
        """
        Gets regression results of a given dataset.

        :param inputs: Dataset object, expected to be of type "text_only"
            (this is not validated here).
        :return: whatever the registered inference function returns for
            ``inputs``.
        :raises RuntimeError: if no inference function has been registered.
        """
        if self.inference_func is None:
            # Previously this path silently returned None. That hid a missing
            # registration until it failed confusingly downstream, so fail
            # fast with an actionable message instead.
            raise RuntimeError(
                "No inference function registered; call "
                "register_inference_function() before inference()."
            )
        return self.inference_func(inputs)