Source code for lmflow.models.auto_model

#!/usr/bin/env python
# coding=utf-8
"""Automatically get correct model type.
"""

from lmflow.models.hf_decoder_model import HFDecoderModel
from lmflow.models.hf_text_regression_model import HFTextRegressionModel
from lmflow.models.hf_encoder_decoder_model import HFEncoderDecoderModel

[docs] class AutoModel: @classmethod
[docs] def get_model(self, model_args, *args, **kwargs): arch_type = model_args.arch_type if arch_type == "decoder_only": return HFDecoderModel(model_args, *args, **kwargs) elif arch_type == "text_regression": return HFTextRegressionModel(model_args, *args, **kwargs) elif arch_type == "encoder_decoder" or \ arch_type == "vision_encoder_decoder": return HFEncoderDecoderModel(model_args, *args, **kwargs) else: raise NotImplementedError( f"model architecture type \"{arch_type}\" is not supported" )