pixtral.py

from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.utils import merge_multimodal_embeddings
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import init_aphrodite_registered_model


def get_max_pixtral_image_tokens(ctx: InputContext):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode,
    )
    mm_encoder = tokenizer.instruct.mm_encoder
    max_image_size = mm_encoder.mm_config.max_image_size
    image_patch_size = mm_encoder.mm_config.image_patch_size
    return (max_image_size // image_patch_size) ** 2
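
# Worked example for the token-count formula above (hypothetical config values,
# not read from any particular checkpoint): with max_image_size=1024 and
# image_patch_size=16, the largest image maps to (1024 // 16) ** 2 = 4096
# image tokens.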


def dummy_data_for_pixtral(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)

    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
    patch_size = mm_encoder.mm_config.image_patch_size
    image_token_id = mm_encoder.special_ids.img

    mm_config = ctx.model_config.multimodal_config
    num_images = mm_config.limit_per_prompt.get("image", 1)

    # dummy size
    size = 256
    image = Image.new("RGB", (size, size), color=0)

    image_feature_size = (size**2) // (patch_size**2)
    num_image_tokens = image_feature_size * num_images

    seq_data = SequenceData.from_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )
    mm_data = {"image": num_images * [image]}
    return seq_data, mm_data
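
# Sanity check on the dummy sizing above, assuming an image_patch_size of 16:
# a 256x256 dummy image yields (256 ** 2) // (16 ** 2) = 256 image tokens per
# image, so with the default limit of one image the dummy sequence holds 256
# image tokens followed by (seq_len - 256) padding token ids.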


def input_mapper_for_pixtral(
    ctx: InputContext, data: object
) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing images to be processed into the
            normalized image tensors consumed by
            PixtralForConditionalGeneration.forward().

    Returns:
        MultiModalInputs containing the list of normalized image tensors.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode
    )

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        image = torch.from_numpy(encoding.image).to(
            device="cuda", dtype=torch.float16
        )
        images.append(image)

    return MultiModalInputs({"images": images})
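
# Minimal usage sketch for the mapper above (illustrative only: `ctx` is
# normally constructed by the engine, and the tokenizer must be a
# mistral_common tokenizer exposing `instruct.mm_encoder`):
#
#     pil_image = Image.new("RGB", (256, 256))
#     mm_inputs = input_mapper_for_pixtral(ctx, pil_image)
#     # mm_inputs["images"] is a list with one float16 CUDA tensor holding the
#     # normalized image as produced by the mm_encoder.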


def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is not None and "image" in multi_modal_data:
        tokenizer = cached_get_tokenizer(
            ctx.model_config.tokenizer,
            tokenizer_mode=ctx.model_config.tokenizer_mode)

        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
        image_token_id = mm_encoder.special_ids.img

        if image_token_id not in llm_inputs['prompt_token_ids']:
            raise ValueError(
                (f"You've passed {llm_inputs=} without {image_token_id=}."
                 " Make sure to process your input via mistral_common's"
                 " tokenizer or pass a chat completion request."))

    return llm_inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config
        )

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for Pixtral.

        If image tensors are provided, they are encoded by the vision encoder,
        projected by the vision-language adapter, and merged into the text
        embeddings at the image-token positions before the language model runs.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids
            )

            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                vision_embeddings,
                self.vision_args.image_token_id,
            )

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            None,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[
            Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
        ] = None,
    ) -> Optional[List[torch.Tensor]]:
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # if passed as batch take all images
            N, B, C, W, H = images.shape
            images = images.reshape(N * B, C, W, H)
            images = [images[i] for i in range(images.size(0))]
        elif isinstance(images, list):
            # if passed as list flatten lists of tensors
            flatten_images = []
            for imgs_per_req in images:
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

                flatten_images.extend(imgs_per_req)

            images = flatten_images

        return images
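
    # Shape note for the parser above: a batched tensor of shape
    # (N, B, C, W, H) is flattened into N * B single-image tensors, and nested
    # lists are flattened one level, so downstream code always receives a flat
    # List[torch.Tensor] with one tensor per image.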

    def _process_image_input(
        self, image_input: List[torch.Tensor]
    ) -> torch.Tensor:
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(
            hidden_states, sampling_metadata
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight
            ) or is_vision_lang_adapter_weights(weight)

        # split the weights iterator three ways so each filter pass below can
        # consume its own copy independently
        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3
        )

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(
            is_vision_encoder_weights, vision_encoder_weights
        )
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = ".".join(name.split(".")[1:])

            param = vision_encoder_dict[name]
            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(
            is_vision_lang_adapter_weights, vision_lang_adapter_weights
        )
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters()
        )
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = ".".join(name.split(".")[1:])

            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)


# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int


def _reshape_for_broadcast(
    freqs_cis: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
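
# Shape example for the broadcast helper above: for xq_ of shape
# (bsz, seq_len, n_heads, head_dim // 2), a freqs_cis of shape
# (seq_len, head_dim // 2) is viewed as (1, seq_len, 1, head_dim // 2) so it
# broadcasts over the batch and head dimensions.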


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
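
# Worked example for the 2D frequencies above: with dim=4 there are two
# frequency bases; the even-indexed base (freqs[::2]) rotates with the row
# (height) coordinate and the odd-indexed base (freqs[1::2]) with the column
# (width) coordinate, giving a complex tensor of shape (height, width, 2)
# whose entry at (h, w) depends only on that patch's 2D position.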


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.w2 = nn.Linear(
            args.intermediate_size, args.hidden_size, bias=False
        )
        self.w3 = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
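
# The feed-forward above is the gated-SiLU (SwiGLU-style) MLP used across the
# Mistral family: w2(silu(w1(x)) * w3(x)), expanding hidden_size to
# intermediate_size and back.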


class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args

        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(
            self.attention_norm(x), mask=mask, freqs_cis=freqs_cis
        )
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    positions = torch.cat(
        [
            torch.stack(
                torch.meshgrid(
                    torch.arange(p.shape[-2]),
                    torch.arange(p.shape[-1]),
                    indexing="ij",
                ),
                dim=-1,
            ).reshape(-1, 2)
            for p in patch_embeds_list
        ]
    )
    return positions
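
# Example for the meshgrid above: patch grids with spatial shapes 2x3 and 1x2
# contribute 6 + 2 = 8 (row, col) pairs, i.e. a positions tensor of shape
# (8, 2) that indexes the precomputed 2D freqs_cis table per patch.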


class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1
        )
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
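
    # Shape walk-through for forward() above with hypothetical inputs: two
    # images of 64x48 and 32x32 pixels with patch_size=16 give patch grids of
    # 4x3 and 2x2, i.e. 12 + 4 = 16 patch tokens in one concatenated sequence;
    # the block-diagonal mask keeps attention within each image, and the
    # returned features have shape (16, hidden_size).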


class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))
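
# End-to-end sketch of the vision path defined in this file: VisionTransformer
# turns a list of images into a (N_toks, vision hidden_size) sequence of patch
# features, VisionLanguageAdapter projects them to the text hidden size via a
# two-layer GELU MLP, and PixtralForConditionalGeneration.forward() merges the
# result into the text embeddings at the positions of image_token_id.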