Source code for lmflow.models.vision_encoder.clip_encoder
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
from lmflow.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
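
# CLIP-based vision tower for LMFlow's vision-to-sequence models, adapted from
# the LLaVA code base (see prepare_inputs_labels_for_multimodal below).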
def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, "image_encoder_name_or_path", "openai/clip-vit-large-patch14")
    if vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    raise ValueError(f"Unknown vision tower: {vision_tower}")


# FIXME: check whether the BlipVisionEncoder can be used directly here.
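
# Example (illustrative sketch; `pil_images` is a placeholder list of PIL images,
# and any config namespace works because the encoder name falls back to
# "openai/clip-vit-large-patch14"):
#
#   from types import SimpleNamespace
#   tower = build_vision_tower(SimpleNamespace())
#   pixel_values = tower.image_processor(images=pil_images, return_tensors="pt").pixel_values
#   features = tower(pixel_values)  # (batch, num_patches, vision_hidden_size)
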
class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        # Feature-selection options; the attribute names and defaults follow the
        # LLaVA convention (penultimate hidden layer, patch tokens) and can be overridden via `args`.
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        # The CLIP encoder is used as a frozen feature extractor.
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def encode_images(self, images, language_projection):
        image_features = self(images)
        # FIXME: the language projection is registered in CustomAutoVision2SeqModel;
        # check how to move this code there to improve readability.
        if language_projection is not None:
            image_features = language_projection(image_features)
        return image_features

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            # Drop the leading CLS token and keep only the patch tokens.
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            # Keep the CLS token together with the patch tokens.
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            # Variable-sized inputs: encode each image separately.
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    # Convenience accessors for the underlying CLIP model's dtype and device.
    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        attention_mask,
        past_key_values,
        labels,
        images,
        language_projection=None,
        language_model=None,
        **kwargs,
    ):
        """
        Copied from the LLaVA code base.
        Should be polished.
        """
        vision_tower = self.vision_tower
        # Fast path, commonly hit inside model.generate (past_key_values is not None):
        # avoid forwarding the image multiple times when only one new token is decoded.
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                attention_mask = torch.ones(
                    (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
            return input_ids, attention_mask, past_key_values, None, labels
        if type(images) is list or images.ndim == 5:
            # Multiple images per sample: batch them through the encoder in one pass,
            # then split the features back per sample and flatten.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images, language_projection)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images, language_projection)
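        # Walk over the batch sample by sample, rebuilding each sequence of token
        # embeddings with the image features spliced in at the placeholder positions.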
        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # Multimodal LLM, but the current sample is text-only: add a zero-weighted
                # dummy image feature so the projection still participates in the graph.
                cur_input_embeds = language_model.embed_tokens(cur_input_ids)
                cur_input_embeds = cur_input_embeds + (0.0 * language_projection(self.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                # print("image token_start", image_token_start,
                #       "curr_input_ids", cur_input_ids.shape)
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    # <im_start>/<im_end> wrap the image token: keep the embeddings of those
                    # two boundary tokens in the graph and detach the surrounding text.
                    cur_new_input_embeds.append(
                        language_model.embed_tokens(cur_input_ids[: image_token_start - 1]).detach()
                    )
                    cur_new_input_embeds.append(
                        language_model.embed_tokens(cur_input_ids[image_token_start - 1 : image_token_start])
                    )
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(
                        language_model.embed_tokens(cur_input_ids[image_token_start + 1 : image_token_start + 2])
                    )
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype
                            )
                        )
                        cur_new_labels.append(cur_labels[image_token_start : image_token_start + 1])
                        cur_labels = cur_labels[image_token_start + 2 :]
                else:
                    cur_input_ids = cur_input_ids.to(device=language_model.device)
                    cur_new_input_embeds.append(language_model.embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype
                            )
                        )
                        cur_labels = cur_labels[image_token_start + 1 :]
                cur_image_idx += 1
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_input_ids = cur_input_ids[image_token_start + 2 :]
                else:
                    cur_input_ids = cur_input_ids[image_token_start + 1 :]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            # Embed any text that remains after the last image token.
            if cur_input_ids.numel() > 0:
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_new_input_embeds.append(language_model.embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(language_model.embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)
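        # Samples may now have different sequence lengths; right-pad the embeddings with
        # zeros, the labels with IGNORE_INDEX, and extend the attention mask to match.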
        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)
            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat(
                        (
                            cur_new_label,
                            torch.full(
                                (max_len - cur_new_label.shape[0],),
                                IGNORE_INDEX,
                                dtype=cur_new_label.dtype,
                                device=cur_new_label.device,
                            ),
                        ),
                        dim=0,
                    )
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)
            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
                    attention_mask, _new_labels, new_labels
                ):
                    # Pad with True on the left for the extra image positions and with
                    # False on the right for the alignment padding.
                    new_attn_mask_pad_left = torch.full(
                        (cur_new_labels.shape[0] - labels.shape[1],),
                        True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    new_attn_mask_pad_right = torch.full(
                        (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
                        False,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    cur_new_attention_mask = torch.cat(
                        (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0
                    )
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
new_input_embeds = torch.stack(new_input_embeds, dim=0)
if labels is not None:
new_labels = torch.stack(new_labels, dim=0)
if attention_mask is not None:
new_attn_mask_pad_left = torch.full(
(attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
True,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
assert attention_mask.shape == new_input_embeds.shape[:2]
return None, attention_mask, past_key_values, new_input_embeds, new_labels
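
# Example wiring (illustrative sketch; `cfg`, `projector`, `lm`, and the tensors
# below are assumptions, not part of this module. In LMFlow the projection is
# registered in CustomAutoVision2SeqModel, see the FIXME in encode_images):
#
#   tower = CLIPVisionTower("openai/clip-vit-large-patch14", args=cfg)
#   projector = nn.Linear(tower.hidden_size, lm.config.hidden_size)
#   _, attention_mask, past_key_values, inputs_embeds, labels = (
#       tower.prepare_inputs_labels_for_multimodal(
#           input_ids, attention_mask, None, labels, images,
#           language_projection=projector,
#           language_model=lm.model,  # decoder module exposing `.embed_tokens`
#       )
#   )
#   outputs = lm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)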