from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       IntermediateTensors, SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.utils import merge_multimodal_embeddings
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import init_aphrodite_registered_model

def get_max_pixtral_image_tokens(ctx: InputContext):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode,
    )
    mm_encoder = tokenizer.instruct.mm_encoder

    max_image_size = mm_encoder.mm_config.max_image_size
    image_patch_size = mm_encoder.mm_config.image_patch_size

    return (max_image_size // image_patch_size) ** 2
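
# Hedged, illustrative instance of the bound computed above (the numbers are
# assumptions for this sketch, not values read from any particular checkpoint):
# with max_image_size = 1024 and image_patch_size = 16, the result is
# (1024 // 16) ** 2 == 64 ** 2 == 4096 image tokens per image.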

def dummy_data_for_pixtral(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    patch_size = mm_encoder.mm_config.image_patch_size
    image_token_id = mm_encoder.special_ids.img

    mm_config = ctx.model_config.multimodal_config
    num_images = mm_config.limit_per_prompt.get("image", 1)

    # dummy size
    size = 256
    image = Image.new("RGB", (size, size), color=0)

    image_feature_size = (size**2) // (patch_size**2)
    num_image_tokens = image_feature_size * num_images

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * num_image_tokens
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - num_image_tokens)
    seq_data = SequenceData(token_ids)

    mm_data = {"image": num_images * [image]}
    return seq_data, mm_data
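
# Illustrative sketch of the dummy profiling data above, assuming an image
# patch size of 16 (an assumption for this example only): the 256x256 dummy
# image yields (256 ** 2) // (16 ** 2) == 256 image tokens per image, so a
# prompt carrying one image starts with 256 copies of the image token id,
# followed by padding token 0 up to seq_len.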

def input_mapper_for_pixtral(
    ctx: InputContext, data: object
) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing images to be mapped to the
            ``images`` input of PixtralForConditionalGeneration.forward().

    Returns:
        MultiModalInputs containing the processed, normalized image tensors.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode
    )

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        image = torch.from_numpy(encoding.image).to(
            device="cuda", dtype=torch.float16
        )
        images.append(image)

    return MultiModalInputs({"images": images})

def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is not None and "image" in multi_modal_data:
        tokenizer = cached_get_tokenizer(
            ctx.model_config.tokenizer,
            tokenizer_mode=ctx.model_config.tokenizer_mode)

        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
        image_token_id = mm_encoder.special_ids.img

        if image_token_id not in llm_inputs['prompt_token_ids']:
            raise ValueError(
                f"You've passed {llm_inputs=} without {image_token_id=}."
                " Make sure to process your input via mistral_common's"
                " tokenizer or pass a chat completion request.")

    return llm_inputs

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for Pixtral.

        If images are passed via ``kwargs``, they are encoded by the vision
        encoder, projected by the vision-language adapter, and merged into
        the text embeddings at the image-token positions before being fed to
        the language model.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids
            )

            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                vision_embeddings,
                self.vision_args.image_token_id,
            )

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            None,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[
            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
        ] = None,
    ) -> Optional[List[torch.Tensor]]:
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            images = images.reshape(N * B, C, W, H)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        return images

    def _process_image_input(
        self, image_input: List[torch.Tensor]
    ) -> torch.Tensor:
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(
            hidden_states, sampling_metadata
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight
            ) or is_vision_lang_adapter_weights(weight)

        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3
        )

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(
            is_vision_encoder_weights, vision_encoder_weights
        )
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = ".".join(name.split(".")[1:])

            param = vision_encoder_dict[name]
            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(
            is_vision_lang_adapter_weights, vision_lang_adapter_weights
        )
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters()
        )
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = ".".join(name.split(".")[1:])

            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)

# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int

def _reshape_for_broadcast(
    freqs_cis: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)

def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
    to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
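
# Hedged shape example for precompute_freqs_cis_2d (values chosen only to
# illustrate the docstring above, not taken from any config):
#
#   freqs_cis = precompute_freqs_cis_2d(dim=64, height=4, width=4,
#                                       theta=10000.0)
#   freqs_cis.shape  # torch.Size([4, 4, 32]), dtype torch.complex64
#
# Row frequencies use the even-indexed bases and column frequencies the
# odd-indexed ones, so each (h, w) position gets dim // 2 complex rotations.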

def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
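
# Shape walk-through for apply_rotary_emb_vit (a sketch of the call made by
# Attention.forward below; shapes are generic, not tied to a config): xq and
# xk arrive as (batch, patches, n_heads, head_dim), are viewed as complex
# pairs of shape (batch, patches, n_heads, head_dim // 2), rotated by the
# broadcast freqs_cis, then flattened back to their original real shape.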

class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.w2 = nn.Linear(
            args.intermediate_size, args.hidden_size, bias=False
        )
        self.w3 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(w1(x)) gated by w3(x), projected back down by w2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args

        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)

class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(
            self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
        )
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out

class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x

def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    positions = torch.cat(
        [
            torch.stack(
                torch.meshgrid(
                    torch.arange(p.shape[-2]),
                    torch.arange(p.shape[-1]),
                    indexing="ij",
                ),
                dim=-1,
            ).reshape(-1, 2)
            for p in patch_embeds_list
        ]
    )
    return positions
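
# Hedged example for position_meshgrid (patch grid sizes are made up for
# illustration): for two patch-embedding maps with grids of 2x3 and 1x2
# patches, the returned tensor has shape (2 * 3 + 1 * 2, 2) == (8, 2) and
# holds the (row, column) coordinate of every patch, image by image.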

class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1
        )
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
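
# Hedged usage sketch for VisionTransformer.forward (image sizes are
# assumptions for illustration): two RGB images of 64x64 and 32x48 pixels with
# patch_size = 16 produce patch grids of 4x4 and 2x3, i.e. sequence lengths
# [16, 6] for BlockDiagonalMask.from_seqlens, and the returned features have
# shape (16 + 6, hidden_size) == (22, hidden_size). The block-diagonal mask
# keeps patches of different images from attending to each other.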

class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))
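
# Hedged end-to-end sketch (dimension values are assumptions, not checkpoint
# values): with a vision hidden_size of 1024 and a text hidden size dim of
# 5120, the adapter maps vision features of shape (N_toks, 1024) through
# Linear -> GELU -> Linear to (N_toks, 5120), which is what
# PixtralForConditionalGeneration._process_image_input hands to
# merge_multimodal_embeddings.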