Source code for lmflow.utils.multimodal

import glob
import torch
from transformers import LlamaConfig
from tqdm import tqdm


def update_custom_config(config, model_args):
    """Update a multimodal model config in place from the given model arguments."""
    if model_args.llm_model_name_or_path is not None:
        # Swap in the text backbone config from the specified LLaMA model.
        text_config = LlamaConfig.from_pretrained(
            model_args.llm_model_name_or_path)
        config.text_config = text_config
    config.with_qformer = model_args.with_qformer
    config.custom_vision_model = model_args.custom_vision_model
    if model_args.custom_vision_model:
        config.image_encoder_name_or_path = \
            model_args.image_encoder_name_or_path
        config.vision_select_layer = model_args.vision_select_layer
        if getattr(model_args, "vision_select_feature", None) is not None:
            config.vision_select_feature = model_args.vision_select_feature
    return config
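
# Example: a minimal, hypothetical call to update_custom_config. Blip2Config
# and the SimpleNamespace fields below are stand-ins for LMFlow's real config
# and model-argument classes, chosen only so the sketch runs on its own.
#
#     from types import SimpleNamespace
#     from transformers import Blip2Config
#
#     # Hypothetical model arguments; field names mirror those read by
#     # update_custom_config, values are illustrative only.
#     model_args = SimpleNamespace(
#         llm_model_name_or_path=None,  # set to a LLaMA path to swap text_config
#         with_qformer=True,
#         custom_vision_model=False,
#         image_encoder_name_or_path=None,
#         vision_select_layer=-2,
#         vision_select_feature=None,
#     )
#
#     config = update_custom_config(Blip2Config(), model_args)
#     print(config.with_qformer, config.custom_vision_model)  # True False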
def load_llava_pretrain_model(model, checkpoint_path):
    """Load LLaVA pretrained weights matching a glob pattern into the model."""
    checkpoint_paths = glob.glob(checkpoint_path)
    for path in tqdm(checkpoint_paths):
        state_dict = torch.load(path, map_location="cpu")
        # Rename the checkpoint keys to LMFlow's module naming scheme.
        new_state_dict = adapt_llava_model_to_lmflow_type(state_dict)
        # Report any adapted keys that the target model does not expect.
        lmflow_keys = model.state_dict().keys()
        for key in new_state_dict.keys():
            if key not in lmflow_keys:
                print("key not in lmflow_keys: ", key)
        model.load_state_dict(new_state_dict, strict=False)
    return model
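
# Example: a self-contained toy run of load_llava_pretrain_model (assuming
# adapt_llava_model_to_lmflow_type below is already defined). The toy module
# and shard file are fabrications that only demonstrate the glob-then-load
# flow; real checkpoints would be sharded LLaVA files such as
# pytorch_model-*.bin, and the real target is an LMFlow multimodal model.
#
#     import torch
#     import torch.nn as nn
#
#     class ToyMultimodalModel(nn.Module):
#         """Toy stand-in whose parameter names match the adapted keys."""
#         def __init__(self):
#             super().__init__()
#             self.language_projection = nn.Linear(4, 4)
#
#     toy = ToyMultimodalModel()
#     torch.save(
#         {"model.mm_projector.weight": torch.eye(4),
#          "model.mm_projector.bias": torch.zeros(4)},
#         "toy_shard-00001.bin",
#     )
#     toy = load_llava_pretrain_model(toy, "toy_shard-*.bin")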
def adapt_llava_model_to_lmflow_type(state_dict):
    """Rename LLaVA state-dict keys to LMFlow names, dropping vision tower weights."""
    new_state_dict = {}
    for key, item in state_dict.items():
        # Map LLaVA module names onto the LMFlow model hierarchy.
        key = key.replace("model.layers", "language_model.model.layers")
        key = key.replace("model.embed_tokens",
                          "language_model.model.embed_tokens")
        key = key.replace("model.mm_projector", "language_projection")
        key = key.replace("lm_head", "language_model.lm_head")
        key = key.replace("model.norm", "language_model.model.norm")
        if "vision_tower" in key:
            # Vision tower weights are not loaded through this path.
            continue
        new_state_dict[key] = item
    return new_state_dict
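
# Example: the key renaming in isolation, on a handful of typical LLaVA
# checkpoint keys with dummy tensors.
#
#     import torch
#
#     llava_sd = {
#         "model.layers.0.self_attn.q_proj.weight": torch.zeros(1),
#         "model.mm_projector.weight": torch.zeros(1),
#         "model.vision_tower.vision_tower.embeddings.class_embedding":
#             torch.zeros(1),
#         "lm_head.weight": torch.zeros(1),
#     }
#     adapted = adapt_llava_model_to_lmflow_type(llava_sd)
#     print(sorted(adapted))
#     # ['language_model.lm_head.weight',
#     #  'language_model.model.layers.0.self_attn.q_proj.weight',
#     #  'language_projection.weight']  -- the vision tower entry was dropped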