# Source code for lmflow.utils.model
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
from typing import Dict, Any, List, Tuple, Union
from transformers import AutoTokenizer
from lmflow.args import ModelArguments
[docs]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
def check_homogeneity(model_args_list: List[ModelArguments]) -> bool:
    """Check whether all given models resolve to the same tokenizer class.

    A tokenizer is loaded (with ``use_fast=False``) from each entry's
    ``model_name_or_path`` and the class names of the loaded tokenizers are
    compared.

    Parameters
    ----------
    model_args_list : List[ModelArguments]
        Model arguments of the models to compare. Must contain at least two
        elements, all of them ``ModelArguments`` instances.

    Returns
    -------
    bool
        ``True`` if every loaded tokenizer has the same class name,
        ``False`` otherwise.

    Raises
    ------
    AssertionError
        If any element is not a ``ModelArguments`` instance, or if the list
        has fewer than two elements.
    """
    # Raised explicitly instead of via `assert` so the validation is not
    # stripped when Python runs with -O; AssertionError is kept so existing
    # callers observe the same exception type and message.
    if not all(isinstance(model_args, ModelArguments) for model_args in model_args_list):
        raise AssertionError("model_args_list should be a list of ModelArguments objects.")
    if len(model_args_list) <= 1:
        raise AssertionError("model_args_list should have at least two elements.")

    # Every tokenizer is loaded (no early exit); only the distinct class
    # names are retained for the comparison.
    tokenizer_class_names = {
        AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, use_fast=False
        ).__class__.__name__
        for model_args in model_args_list
    }
    return len(tokenizer_class_names) == 1