from dataclasses import dataclass, fields from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from mistral_common.protocol.instruct.messages import ImageChunk from PIL import Image from transformers import PretrainedConfig from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask from aphrodite.attention import AttentionMetadata from aphrodite.common.config import CacheConfig, MultiModalConfig from aphrodite.common.sequence import IntermediateTensors, SequenceData from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs from aphrodite.modeling.layers.layernorm import RMSNorm from aphrodite.modeling.layers.sampler import SamplerOutput from aphrodite.modeling.model_loader.weight_utils import default_weight_loader from aphrodite.modeling.models.utils import merge_multimodal_embeddings from aphrodite.modeling.sampling_metadata import SamplingMetadata from aphrodite.multimodal import MULTIMODAL_REGISTRY from aphrodite.multimodal.base import MultiModalInputs from aphrodite.multimodal.utils import cached_get_tokenizer from aphrodite.quantization import QuantizationConfig from .interfaces import SupportsMultiModal from .utils import init_aphrodite_registered_model def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode, ) mm_encoder = tokenizer.instruct.mm_encoder max_image_size = mm_encoder.mm_config.max_image_size image_patch_size = mm_encoder.mm_config.image_patch_size return (max_image_size // image_patch_size) ** 2 def dummy_data_for_pixtral( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ): tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder patch_size = mm_encoder.mm_config.image_patch_size image_token_id = mm_encoder.special_ids.img mm_config = ctx.model_config.multimodal_config num_images = mm_config.limit_per_prompt.get("image", 1) # dummy size size = 256 image = Image.new("RGB", (size, size), color=0) image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images seq_data = SequenceData.from_token_counts( (image_token_id, num_image_tokens), (0, seq_len - num_image_tokens), ) mm_data = {"image": num_images * [image]} return seq_data, mm_data def input_mapper_for_pixtral( ctx: InputContext, data: object ) -> MultiModalInputs: """Maps the input data to its MultiModalInputs (if any). Args: ctx: Context of the loaded model. data: data potentially containing image/image embeddings to be mapped to pixel_values in .forward() for a visual QWenLMHeadModel model. Returns: MultiModalInputs containing the stacked normalized images tensor or image embeddings. """ # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode ) data_list = data if isinstance(data, list) else [data] images = [] for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) image = torch.from_numpy(encoding.image).to( device="cuda", dtype=torch.float16 ) images.append(image) return MultiModalInputs({"images": images}) def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is not None and "image" in multi_modal_data: tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img if image_token_id not in llm_inputs['prompt_token_ids']: raise ValueError( (f"You've passed {llm_inputs=} without {image_token_id=}" " Make sure to process your input via mistral_common's" " tokenizer or pass a chat completion request.")) return llm_inputs @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) @INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__( self, config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.multimodal_config = multimodal_config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} vision_args = { key: value for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields } self.vision_args = VisionEncoderArgs(**vision_args) # init MistralForCausalLM self.language_model = init_aphrodite_registered_model( config.text_config, cache_config, quant_config ) self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> SamplerOutput: """Run forward pass for pixtral. TODO """ image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.model.get_input_embeddings( input_ids ) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_args.image_token_id, ) input_ids = None else: inputs_embeds = None hidden_states = self.language_model.model( input_ids, positions, kv_caches, attn_metadata, None, inputs_embeds=inputs_embeds, ) return hidden_states def _parse_and_validate_image_input( self, images: Optional[ Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor] ] = None, ) -> Optional[List[torch.Tensor]]: if images is None: return None if isinstance(images, torch.Tensor): # if passed as batch take all images N, B, C, W, H = images.shape images = images.reshape(N * B, C, W, H) images = [images[i] for i in range(images.size(0))] elif isinstance(images, list): # if passed as list flatten lists of tensors flatten_images = [] for imgs_per_req in images: imgs_per_req = [ imgs_per_req[i] for i in range(imgs_per_req.size(0)) ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req flatten_images.extend(imgs_per_req) images = flatten_images return images def _process_image_input( self, image_input: List[torch.Tensor] ) -> torch.Tensor: return self.vision_language_adapter(self.vision_encoder(image_input)) def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: return self.language_model.compute_logits( hidden_states, sampling_metadata ) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): return weight[0].startswith("vision_language_adapter") def is_vision_weights(weight: Tuple[str, torch.Tensor]): return is_vision_encoder_weights( weight ) or is_vision_lang_adapter_weights(weight) llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee( weights, 3 ) # llm llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights) self.language_model.load_weights(llm_weights) # vision encoder vision_encoder_weights = filter( is_vision_encoder_weights, vision_encoder_weights ) vision_encoder_dict = dict(self.vision_encoder.named_parameters()) for name, loaded_weight in vision_encoder_weights: # cut 'vision_encoder.' name = ".".join(name.split(".")[1:]) param = vision_encoder_dict[name] default_weight_loader(param, loaded_weight) # adapter vision_lang_adapter_weights = filter( is_vision_lang_adapter_weights, vision_lang_adapter_weights ) vision_lang_adpter_dict = dict( self.vision_language_adapter.named_parameters() ) for name, loaded_weight in vision_lang_adapter_weights: # cut 'vision_language_adapter.' name = ".".join(name.split(".")[1:]) param = vision_lang_adpter_dict[name] default_weight_loader(param, loaded_weight) # Vision encoder @dataclass class VisionEncoderArgs: hidden_size: int num_channels: int image_size: int patch_size: int intermediate_size: int num_hidden_layers: int num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int def _reshape_for_broadcast( freqs_cis: torch.Tensor, x: torch.Tensor ) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) """ ndim = x.ndim assert ndim > 1 assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( freqs_cis.shape, (x.shape[1], x.shape[-1]), ) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def precompute_freqs_cis_2d( dim: int, height: int, width: int, theta: float, ) -> torch.Tensor: """ freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) freqs_h = torch.outer(h, freqs[::2]).float() freqs_w = torch.outer(w, freqs[1::2]).float() freqs_2d = torch.cat( [ freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1), ], dim=-1, ) return torch.polar(torch.ones_like(freqs_2d), freqs_2d) def apply_rotary_emb_vit( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) assert freqs_cis.dtype == torch.complex64 freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) class FeedForward(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None self.w1 = nn.Linear( args.hidden_size, args.intermediate_size, bias=False ) self.w2 = nn.Linear( args.intermediate_size, args.hidden_size, bias=False ) self.w3 = nn.Linear( args.hidden_size, args.intermediate_size, bias=False ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args assert not args.hidden_size % args.num_attention_heads self.n_heads = args.num_attention_heads self.head_dim = args.hidden_size // args.num_attention_heads self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) def forward( self, x: torch.Tensor, mask: BlockDiagonalMask, freqs_cis: torch.Tensor, ) -> torch.Tensor: batch, patches, _ = x.shape q, k, v = self.wq(x), self.wk(x), self.wv(x) q = q.reshape(batch, patches, self.n_heads, self.head_dim) k = k.reshape(batch, patches, self.n_heads, self.head_dim) v = v.reshape(batch, patches, self.n_heads, self.head_dim) q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) out = memory_efficient_attention(q, k, v, attn_bias=mask) out = out.reshape(batch, patches, self.n_heads * self.head_dim) return self.wo(out) class TransformerBlock(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.attention = Attention(args) self.feed_forward = FeedForward(args) self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) def forward( self, x: torch.Tensor, mask: BlockDiagonalMask, freqs_cis: torch.Tensor, ) -> torch.Tensor: r = self.attention.forward( self.attention_norm(x), mask=mask, freqs_cis=freqs_cis ) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out class Transformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() for _ in range(args.num_hidden_layers): self.layers.append(TransformerBlock(args)) def forward( self, x: torch.Tensor, mask: BlockDiagonalMask, freqs_cis: Optional[torch.Tensor], ) -> torch.Tensor: for layer in self.layers: x = layer(x, mask=mask, freqs_cis=freqs_cis) return x def position_meshgrid( patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: positions = torch.cat( [ torch.stack( torch.meshgrid( torch.arange(p.shape[-2]), torch.arange(p.shape[-1]), indexing="ij", ), dim=-1, ).reshape(-1, 2) for p in patch_embeds_list ] ) return positions class VisionTransformer(nn.Module): def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args self.patch_conv = nn.Conv2d( in_channels=args.num_channels, out_channels=args.hidden_size, kernel_size=args.patch_size, stride=args.patch_size, bias=False, ) self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) self.transformer = Transformer(args) head_dim = self.args.hidden_size // self.args.num_attention_heads assert head_dim % 2 == 0, "ROPE requires even head_dim" self._freqs_cis: Optional[torch.Tensor] = None @property def max_patches_per_side(self) -> int: return self.args.image_size // self.args.patch_size @property def device(self) -> torch.device: return next(self.parameters()).device @property def dtype(self) -> torch.device: return next(self.parameters()).dtype @property def freqs_cis(self) -> torch.Tensor: if self._freqs_cis is None: self._freqs_cis = precompute_freqs_cis_2d( dim=self.args.hidden_size // self.args.num_attention_heads, height=self.max_patches_per_side, width=self.max_patches_per_side, theta=self.args.rope_theta, ) if self._freqs_cis.device != self.device: self._freqs_cis = self._freqs_cis.to(device=self.device) return self._freqs_cis def forward( self, images: List[torch.Tensor], ) -> torch.Tensor: """ Args: images: list of N_img images of variable sizes, each of shape (C, H, W) Returns: image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] # flatten to a single sequence patch_embeds = torch.cat( [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1 ) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings positions = position_meshgrid(patch_embeds_list).to(self.device) freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] # pass through Transformer with a block diagonal mask delimiting images mask = BlockDiagonalMask.from_seqlens( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # remove batch dimension of the single sequence return out.squeeze(0) class VisionLanguageAdapter(nn.Module): def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) self.w_in = nn.Linear( args.hidden_size, dim, bias=True, ) self.gelu = nn.GELU() self.w_out = nn.Linear(dim, dim, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x)))