Source code for lmflow.models.text_regression_model
#!/usr/bin/env python
# coding=utf-8
"""
A model maps "text_only" data to float.
"""
from lmflow.models.regression_model import RegressionModel
from lmflow.datasets.dataset import Dataset
class TextRegressionModel(RegressionModel):
    r"""
    A regression model that maps "text_only" data to float values.

    No regression logic is built in. A callable must be supplied via
    :meth:`register_inference_function`, and :meth:`inference` delegates
    to it.

    Parameters
    ------------
    model_args :
        Model arguments such as model name, path, revision, etc.
        Currently accepted but neither stored nor used by this class.

    args : Optional.
        Positional arguments (accepted and ignored).

    kwargs : Optional.
        Keyword arguments (accepted and ignored).
    """

    def __init__(
        self,
        model_args,
        *args,
        **kwargs
    ):
        """
        Initializes a TextRegressionModel instance with no inference function
        registered.

        :param model_args: model arguments such as model name, path, revision,
            etc. Not stored or used by this constructor.
        """
        # NOTE(review): super().__init__() is not called; confirm that
        # RegressionModel needs no initialization of its own.
        # No regression backend exists until register_inference_function() is
        # called; inference() checks for this None sentinel.
        self.inference_func = None

    def register_inference_function(self, inference_func):
        """
        Registers the callable used by :meth:`inference`.

        :param inference_func: callable invoked as ``inference_func(inputs)``
            with the dataset passed to :meth:`inference`. Its return value is
            returned unchanged. Replaces any previously registered function.
        """
        self.inference_func = inference_func

    def inference(self, inputs: Dataset):
        """
        Gets regression results of a given dataset.

        :param inputs: Dataset object, expected to be of type "text_only".
            This is not checked here; any validation is up to the registered
            inference function.
        :return: whatever the registered inference function returns, or
            ``None`` if no inference function has been registered.
        """
        if self.inference_func is None:
            # Preserved behavior: with nothing registered, return None rather
            # than raising.
            # NOTE(review): callers cannot tell "nothing registered" apart from
            # a registered function that returns None; consider raising.
            return None
        return self.inference_func(inputs)