from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       IntermediateTensors, SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.utils import merge_multimodal_embeddings
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import init_aphrodite_registered_model


def get_max_pixtral_image_tokens(ctx: InputContext):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode,
    )
    mm_encoder = tokenizer.instruct.mm_encoder

    max_image_size = mm_encoder.mm_config.max_image_size
    image_patch_size = mm_encoder.mm_config.image_patch_size

    return (max_image_size // image_patch_size) ** 2
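
# Note on get_max_pixtral_image_tokens above: the bound is pure arithmetic.
# For example, a hypothetical config with max_image_size=1024 and
# image_patch_size=16 would give (1024 // 16) ** 2 = 64 ** 2 = 4096 tokens.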


def dummy_data_for_pixtral(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    patch_size = mm_encoder.mm_config.image_patch_size
    image_token_id = mm_encoder.special_ids.img

    mm_config = ctx.model_config.multimodal_config
    num_images = mm_config.limit_per_prompt.get("image", 1)

    # dummy size
    size = 256
    image = Image.new("RGB", (size, size), color=0)

    image_feature_size = (size**2) // (patch_size**2)
    num_image_tokens = image_feature_size * num_images

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * num_image_tokens
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - num_image_tokens)

    seq_data = SequenceData(token_ids)
    mm_data = {"image": num_images * [image]}
    return seq_data, mm_data
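
# Note on dummy_data_for_pixtral above: the profiling image is a fixed
# 256x256 black square, so with a hypothetical patch size of 16 each dummy
# image contributes (256 ** 2) // (16 ** 2) = 256 image tokens; the rest of
# the sequence is padded with token id 0 up to seq_len.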


def input_mapper_for_pixtral(
    ctx: InputContext, data: object
) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing images to be encoded and passed
            as the `images` keyword argument of
            PixtralForConditionalGeneration.forward().

    Returns:
        MultiModalInputs containing the list of normalized image tensors.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode
    )

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        image = torch.from_numpy(encoding.image).to(
            device="cuda", dtype=torch.float16
        )
        images.append(image)

    return MultiModalInputs({"images": images})


def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is not None and "image" in multi_modal_data:
        tokenizer = cached_get_tokenizer(
            ctx.model_config.tokenizer,
            tokenizer_mode=ctx.model_config.tokenizer_mode)

        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
        image_token_id = mm_encoder.special_ids.img

        if image_token_id not in llm_inputs['prompt_token_ids']:
            raise ValueError(
                (f"You've passed {llm_inputs=} without {image_token_id=}"
                 " Make sure to process your input via mistral_common's"
                 " tokenizer or pass a chat completion request."))

    return llm_inputs
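
# The decorators below wire this model into Aphrodite's multimodal pipeline:
# the image input mapper and max-token estimator are registered with
# MULTIMODAL_REGISTRY, and the dummy-data factory and input processor with
# INPUT_REGISTRY. Note that input_processor_for_pixtral only validates that
# the prompt already contains the image token id; it does not insert tokens
# itself, which is why prompts must come from mistral_common's tokenizer or
# a chat completion request.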


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for Pixtral.

        If images are present in ``kwargs``, they are encoded by the vision
        encoder, projected by the vision-language adapter, and merged into
        the text embeddings at the image-token positions before running the
        language model.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids
            )

            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                vision_embeddings,
                self.vision_args.image_token_id,
            )

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            None,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[
            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
        ] = None,
    ) -> Optional[List[torch.Tensor]]:
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            images = images.reshape(N * B, C, W, H)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        return images
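
    # _parse_and_validate_image_input above normalizes every accepted layout
    # to a flat List[torch.Tensor] with one tensor per image: a 5D batched
    # tensor of shape (N, B, C, W, H) is reshaped and split per image, and
    # list entries that are themselves batched tensors are split along their
    # first dimension before being flattened into a single list.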

    def _process_image_input(
        self, image_input: List[torch.Tensor]
    ) -> torch.Tensor:
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(
            hidden_states, sampling_metadata
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight
            ) or is_vision_lang_adapter_weights(weight)

        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3
        )

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(
            is_vision_encoder_weights, vision_encoder_weights
        )
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = ".".join(name.split(".")[1:])
            param = vision_encoder_dict[name]
            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(
            is_vision_lang_adapter_weights, vision_lang_adapter_weights
        )
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters()
        )
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = ".".join(name.split(".")[1:])
            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)


# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int


def _reshape_for_broadcast(
    freqs_cis: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
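
# Note on precompute_freqs_cis_2d above: the (dim // 2) rotary frequency bases
# are split by interleaving, with even-indexed bases rotated by a patch's row
# index (height) and odd-indexed bases by its column index (width). The result
# is a complex tensor of shape (height, width, dim // 2) that is gathered with
# (row, col) patch positions; e.g. for a hypothetical head_dim=64 and a 64x64
# patch grid it has shape (64, 64, 32).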


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
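
# Note on apply_rotary_emb_vit above: query/key vectors are viewed as complex
# numbers (pairs of adjacent channels), multiplied elementwise by the
# precomputed unit-magnitude freqs_cis, and flattened back, which applies the
# 2D rotary position rotation without any learned parameters.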


class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.w2 = nn.Linear(
            args.intermediate_size, args.hidden_size, bias=False
        )
        self.w3 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args

        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(
            self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
        )
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


def position_meshgrid(
    patch_embeds_list: List[torch.Tensor],
) -> torch.Tensor:
    positions = torch.cat(
        [
            torch.stack(
                torch.meshgrid(
                    torch.arange(p.shape[-2]),
                    torch.arange(p.shape[-1]),
                    indexing="ij",
                ),
                dim=-1,
            ).reshape(-1, 2)
            for p in patch_embeds_list
        ]
    )
    return positions
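
# Note on position_meshgrid above: it emits the row-major (row, col) position
# of every patch in every image's grid. For instance, a single image whose
# patch grid is 2x3 yields
#   [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
# which is later used to gather the matching 2D rotary frequencies.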


class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1
        )
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
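
    # forward above packs all images into one long patch sequence; the
    # BlockDiagonalMask built from the per-image patch counts keeps attention
    # within each image, so variable-sized images can be encoded together
    # without padding.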


class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))