123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- """Minimal implementation of CLIPVisionModel intended to be only used
- within a vision language model."""
- from array import array
- from typing import Optional
- import torch
- import torch.nn as nn
- from PIL import Image
- from transformers import CLIPVisionConfig
- from transformers.models.clip.modeling_clip import CLIPAttention
- from aphrodite.common.config import ModelConfig
- from aphrodite.common.sequence import SequenceData
- from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
- from aphrodite.inputs import LLMInputs
- from aphrodite.modeling.layers.activation import get_act_fn
- from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
- RowParallelLinear)
- from aphrodite.multimodal.utils import (cached_get_tokenizer,
- repeat_and_pad_placeholder_tokens)
- from aphrodite.quantization import QuantizationConfig
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the image.

    Args:
        image_size: Side length of the image.
        patch_size: Side length of a single patch.

    Returns:
        ``image_size // patch_size``.

    Raises:
        AssertionError: If ``image_size`` is not an exact multiple of
            ``patch_size``.
    """
    grid_length, leftover = divmod(image_size, patch_size)
    # The patches must tile the image side exactly.
    assert leftover == 0
    return grid_length
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total patch count: the per-side grid length squared.

    Args:
        image_size: Side length of the image.
        patch_size: Side length of a single patch.

    Returns:
        The grid length from ``get_clip_patch_grid_length`` multiplied by
        itself.
    """
    side = get_clip_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )
    return side * side
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Return the number of patches implied by ``hf_config``.

    Args:
        hf_config: Vision config; its ``image_size`` and ``patch_size``
            attributes are read.

    Returns:
        The result of ``get_clip_num_patches`` for those two sizes.
    """
    image_size = hf_config.image_size
    patch_size = hf_config.patch_size
    return get_clip_num_patches(image_size=image_size, patch_size=patch_size)
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of image tokens for ``hf_config``.

    This is exactly the value of ``get_clip_image_feature_size(hf_config)``.
    """
    max_image_tokens = get_clip_image_feature_size(hf_config)
    return max_image_tokens
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy sequence data made of image tokens followed by padding.

    Args:
        hf_config: Vision config used to derive the per-image feature size
            when no override is given.
        seq_len: Target length used to size the zero padding.
        num_images: Number of images whose placeholder tokens are emitted.
        image_token_id: Token id repeated for every image feature slot.
        image_feature_size_override: Per-image feature size to use instead
            of the value derived from ``hf_config``.

    Returns:
        A ``SequenceData`` wrapping the generated token-id array.
    """
    if image_feature_size_override is not None:
        image_feature_size = image_feature_size_override
    else:
        image_feature_size = get_clip_image_feature_size(hf_config)

    image_tokens = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [image_token_id])
    image_tokens = image_tokens * image_feature_size * num_images

    # Fill the rest with token id 0; a non-positive count adds nothing, so
    # the result can exceed ``seq_len`` when the image tokens alone do.
    padding_len = seq_len - image_feature_size * num_images
    padding = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * padding_len

    return SequenceData(image_tokens + padding)
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy multi-modal image data.

    Args:
        hf_config: Vision config; ``image_size`` is the default for both
            width and height.
        num_images: Number of images to return.
        image_width_override: Width to use instead of ``image_size``.
        image_height_override: Height to use instead of ``image_size``.

    Returns:
        A dict with key ``"image"`` holding a single RGB image (color 0)
        when ``num_images == 1``, otherwise a list that repeats the same
        image object ``num_images`` times.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    blank = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Expand the image placeholder token to the image feature size.

    Args:
        model_config: Model config; its ``tokenizer`` is passed to
            ``cached_get_tokenizer``.
        hf_config: Vision config used to derive the feature size for a PIL
            image input.
        llm_inputs: Incoming inputs; ``multi_modal_data``, ``prompt`` and
            ``prompt_token_ids`` are read.
        image_token_id: Placeholder token id to repeat.
        image_feature_size_override: Repeat count to use instead of the
            one derived from the image data.

    Returns:
        ``llm_inputs`` unchanged when it carries no ``"image"`` data,
        otherwise a new ``LLMInputs`` with the expanded prompt and token
        ids and the same ``multi_modal_data``.

    Raises:
        TypeError: If no override is given and the image data is neither
            a PIL image nor a ``torch.Tensor``.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    has_image = multi_modal_data is not None and "image" in multi_modal_data
    if not has_image:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is not None:
        image_feature_size = image_feature_size_override
    else:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            # NOTE(review): the first tensor dimension is taken as the
            # feature size -- presumably precomputed embeddings; confirm.
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        prompt,
        prompt_token_ids,
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
- # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Patch, class and position embeddings for the CLIP vision tower.

    Args:
        config: Vision config supplying ``hidden_size``, ``image_size``,
            ``patch_size`` and ``num_channels``.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: the stride equals the kernel size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(
            image_size=self.image_size,
            patch_size=self.patch_size,
        )
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(
            self.num_positions,
            self.embed_dim,
        )

        position_ids = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", position_ids, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a token sequence.

        Args:
            pixel_values: Image batch; the first dimension is used as the
                batch size. Assumed to be (batch, channels, height,
                width) as expected by the ``Conv2d`` -- TODO confirm.

        Returns:
            The class embedding concatenated in front of the flattened
            patch embeddings, plus the position embeddings.
        """
        batch_size = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        pixels = pixel_values.to(dtype=weight_dtype)
        # shape = [*, width, grid, grid]
        patches = self.patch_embedding(pixels)
        patches = patches.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)
        return tokens + self.position_embedding(self.position_ids)
class CLIPMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    Args:
        config: Vision config supplying ``hidden_act``, ``hidden_size``
            and ``intermediate_size``.
        quant_config: Optional quantization config forwarded to both
            linear layers.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` in sequence.

        Each linear layer returns a pair; only the first element is used.
        """
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated)
        return output
class CLIPEncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention then MLP, each residual.

    Args:
        config: Vision config supplying ``hidden_size`` and
            ``layer_norm_eps``; also passed to the attention and MLP.
        quant_config: Optional quantization config forwarded to the MLP
            only.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_eps

        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with residual additions."""
        # Attention sub-block; the second return value is discarded.
        attn_input = self.layer_norm1(hidden_states)
        attn_output, _ = self.self_attn(hidden_states=attn_input)
        hidden_states = hidden_states + attn_output

        # MLP sub-block.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_output
class CLIPEncoder(nn.Module):
    """
    Stack of [`CLIPEncoderLayer`] modules applied one after another.

    Args:
        config: Vision config; ``num_hidden_layers`` sets the layer count
            unless overridden.
        quant_config: Optional quantization config forwarded to every
            layer.
        num_hidden_layers_override: Number of layers to build instead of
            ``config.num_hidden_layers``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        self.layers = nn.ModuleList()
        for _ in range(num_hidden_layers):
            self.layers.append(
                CLIPEncoderLayer(config=config, quant_config=quant_config))

    def forward(self, inputs_embeds: torch.Tensor):
        """Pass ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
class CLIPVisionTransformer(nn.Module):
    """Embeddings, a pre-encoder layer norm and the encoder stack.

    Args:
        config: Vision config shared by all sub-modules.
        quant_config: Optional quantization config forwarded to the
            encoder.
        num_hidden_layers_override: Optional layer count forwarded to the
            encoder.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(hidden_size,
                                         eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Embed, normalize and encode ``pixel_values``.

        Returns:
            The encoder output, returned directly.
        """
        embedded = self.embeddings(pixel_values)
        normalized = self.pre_layrnorm(embedded)
        return self.encoder(inputs_embeds=normalized)
class CLIPVisionModel(nn.Module):
    """Thin wrapper exposing ``CLIPVisionTransformer`` as ``vision_model``.

    Args:
        config: Vision config forwarded to the transformer.
        quant_config: Optional quantization config forwarded to the
            transformer.
        num_hidden_layers_override: Optional layer count forwarded to the
            transformer.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def forward(self, pixel_values: Optional[torch.Tensor] = None):
        """Return the output of ``vision_model`` for ``pixel_values``."""
        output = self.vision_model(pixel_values=pixel_values)
        return output

    @property
    def device(self):
        """Device of the first parameter yielded by ``parameters()``."""
        first_param = next(self.parameters())
        return first_param.device
|