qwen2_vl.py 41 KB


  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
  26. from array import array
  27. from functools import lru_cache, partial
  28. from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
  29. Union)
  30. import torch
  31. import torch.nn as nn
  32. import torch.nn.functional as F
  33. from einops import rearrange, repeat
  34. from loguru import logger
  35. from PIL import Image
  36. from transformers.image_utils import (get_image_size,
  37. infer_channel_dimension_format,
  38. to_numpy_array)
  39. from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
  40. make_batched_images, make_batched_videos, smart_resize)
  41. import aphrodite.common.envs as envs
  42. from aphrodite.attention import AttentionMetadata
  43. from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
  44. get_global_forced_attn_backend)
  45. from aphrodite.common.config import CacheConfig, MultiModalConfig
  46. from aphrodite.common.logger import log_once
  47. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  48. IntermediateTensors, SequenceData)
  49. from aphrodite.distributed import parallel_state
  50. from aphrodite.distributed import utils as dist_utils
  51. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  52. from aphrodite.modeling.layers.activation import QuickGELU
  53. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  54. RowParallelLinear)
  55. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  56. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  57. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  58. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  59. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  60. from aphrodite.modeling.models.qwen2 import Qwen2Model
  61. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  62. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
  63. MultiModalInputs)
  64. from aphrodite.multimodal.base import MultiModalData
  65. from aphrodite.multimodal.image import cached_get_image_processor
  66. from aphrodite.platforms import current_platform
  67. from aphrodite.quantization import QuantizationConfig
  68. from aphrodite.transformers_utils.configs import (Qwen2VLConfig,
  69. Qwen2VLVisionConfig)
  70. from aphrodite.transformers_utils.processor import get_processor
  71. # === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
    """Image tensors handed to the Qwen2-VL vision encoder.

    Both fields are forwarded to the vision transformer by
    `_process_image_input`.
    """
    # Flattened image patches (one row per patch).
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """
    # Per-image patch grid sizes.
    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """
class Qwen2VLVideoInputs(TypedDict):
    """Video tensors handed to the Qwen2-VL vision encoder.

    Both fields are forwarded to the vision transformer by
    `_process_video_input`.
    """
    # Flattened video patches (one row per patch).
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches,
    num_channels * temporal_patch_size * patch_size * patch_size)`
    """
    # Per-video patch grid sizes.
    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`
    This should be in `(grid_t, grid_h, grid_w)` format.
    """
  91. # === Vision Encoder === #
  92. class Qwen2VisionMLP(nn.Module):
  93. def __init__(
  94. self,
  95. in_features: int,
  96. hidden_features: int = None,
  97. act_layer: Type[nn.Module] = QuickGELU,
  98. quant_config: Optional[QuantizationConfig] = None,
  99. ):
  100. super().__init__()
  101. self.fc1 = ColumnParallelLinear(
  102. in_features, hidden_features, quant_config=quant_config
  103. )
  104. self.act = act_layer()
  105. self.fc2 = RowParallelLinear(
  106. hidden_features, in_features, quant_config=quant_config
  107. )
  108. def forward(self, x: torch.Tensor) -> torch.Tensor:
  109. x_parallel, _ = self.fc1(x)
  110. x_parallel = self.act(x_parallel)
  111. x, _ = self.fc2(x_parallel)
  112. return x
  113. def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
  114. if not interleaved:
  115. x1, x2 = x.chunk(2, dim=-1)
  116. return torch.cat((-x2, x1), dim=-1)
  117. else:
  118. x1, x2 = x[..., ::2], x[..., 1::2]
  119. return rearrange(
  120. torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
  121. )
  122. def apply_rotary_emb_torch(
  123. x: torch.Tensor,
  124. cos: torch.Tensor,
  125. sin: torch.Tensor,
  126. interleaved: bool = False,
  127. ) -> torch.Tensor:
  128. """
  129. x: (batch_size, seqlen, nheads, headdim)
  130. cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
  131. """
  132. ro_dim = cos.shape[-1] * 2
  133. assert ro_dim <= x.shape[-1]
  134. cos = repeat(
  135. cos,
  136. "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
  137. )
  138. sin = repeat(
  139. sin,
  140. "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
  141. )
  142. return torch.cat(
  143. [
  144. x[..., :ro_dim] * cos
  145. + rotate_half(x[..., :ro_dim], interleaved) * sin,
  146. x[..., ro_dim:],
  147. ],
  148. dim=-1,
  149. )
  150. def apply_rotary_pos_emb_vision(
  151. t: torch.Tensor, freqs: torch.Tensor
  152. ) -> torch.Tensor:
  153. t_ = t.float()
  154. cos = freqs.cos()
  155. sin = freqs.sin()
  156. output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
  157. return output
  158. class Qwen2VisionAttention(nn.Module):
  159. def __init__(
  160. self,
  161. embed_dim: Optional[int] = None,
  162. num_heads: Optional[int] = None,
  163. projection_size: Optional[int] = None,
  164. quant_config: Optional[QuantizationConfig] = None,
  165. ) -> None:
  166. super().__init__()
  167. # Per attention head and per partition values.
  168. world_size = parallel_state.get_tensor_model_parallel_world_size()
  169. self.hidden_size_per_attention_head = dist_utils.divide(
  170. projection_size, num_heads
  171. )
  172. self.num_attention_heads_per_partition = dist_utils.divide(
  173. num_heads, world_size
  174. )
  175. self.qkv = ColumnParallelLinear(
  176. input_size=embed_dim,
  177. output_size=3 * projection_size,
  178. quant_config=quant_config,
  179. )
  180. self.proj = RowParallelLinear(
  181. input_size=projection_size,
  182. output_size=embed_dim,
  183. quant_config=quant_config,
  184. )
  185. # Detect attention implementation.
  186. selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
  187. if selected_backend is None:
  188. backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
  189. if backend_by_env_var is not None:
  190. selected_backend = backend_name_to_enum(backend_by_env_var)
  191. if selected_backend is None:
  192. # For Volta and Turing GPUs, use xformers instead.
  193. device_available = current_platform.get_device_capability()[0] >= 8
  194. if device_available:
  195. from transformers.utils import is_flash_attn_2_available
  196. if is_flash_attn_2_available():
  197. self._use_flash_attn = True
  198. else:
  199. log_once(
  200. level="WARNING",
  201. message=
  202. "Current Qwen2-VL implementation has a bug with "
  203. "`aphrodite-flash-attn` inside vision module, so we use"
  204. " xformers backend instead. You can run `pip install "
  205. "flash-attn to use flash-attention backend."
  206. )
  207. self._use_flash_attn = False
  208. else:
  209. self._use_flash_attn = False
  210. else:
  211. if selected_backend == _Backend.FLASH_ATTN:
  212. self._use_flash_attn = True
  213. elif selected_backend == _Backend.XFORMERS:
  214. self._use_flash_attn = False
  215. else:
  216. raise RuntimeError(
  217. f"Qwen2-VL does not support {selected_backend} backend now."
  218. )
  219. def forward(
  220. self,
  221. x: torch.Tensor,
  222. cu_seqlens: torch.Tensor,
  223. rotary_pos_emb: torch.Tensor = None,
  224. ) -> torch.Tensor:
  225. # [s, b, c] --> [s, b, head * 3 * head_dim]
  226. x, _ = self.qkv(x)
  227. # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
  228. new_x_shape = x.size()[:-1] + (
  229. self.num_attention_heads_per_partition,
  230. 3 * self.hidden_size_per_attention_head,
  231. )
  232. x = x.view(*new_x_shape)
  233. # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
  234. q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
  235. batch_size = q.shape[1]
  236. q, k, v = [
  237. rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
  238. ]
  239. if rotary_pos_emb is not None:
  240. q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
  241. k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
  242. if self._use_flash_attn:
  243. # from aphrodite_flash_attn.flash_attn_interface import (
  244. # flash_attn_varlen_func)
  245. from flash_attn import flash_attn_varlen_func
  246. q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
  247. max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
  248. output = flash_attn_varlen_func(
  249. q,
  250. k,
  251. v,
  252. cu_seqlens_q=cu_seqlens,
  253. cu_seqlens_k=cu_seqlens,
  254. max_seqlen_q=max_seqlen,
  255. max_seqlen_k=max_seqlen,
  256. dropout_p=0,
  257. causal=False,
  258. )
  259. context_layer = rearrange(
  260. output, "(b s) ... -> b s ...", b=batch_size
  261. )
  262. else:
  263. from xformers import ops as xops
  264. from xformers.ops.fmha.attn_bias import BlockDiagonalMask
  265. seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
  266. attn_bias = BlockDiagonalMask.from_seqlens(
  267. q_seqlen=seqlens, kv_seqlen=None
  268. )
  269. context_layer = xops.memory_efficient_attention_forward(
  270. q, k, v, attn_bias=attn_bias, p=0, scale=None
  271. )
  272. context_layer = rearrange(
  273. context_layer, "b s h d -> s b (h d)"
  274. ).contiguous()
  275. output, _ = self.proj(context_layer)
  276. return output
  277. class Qwen2VisionBlock(nn.Module):
  278. def __init__(
  279. self,
  280. dim: int,
  281. num_heads: int,
  282. mlp_ratio: float,
  283. act_layer: Type[nn.Module] = QuickGELU,
  284. norm_layer: Type[nn.Module] = None,
  285. quant_config: Optional[QuantizationConfig] = None,
  286. ) -> None:
  287. super().__init__()
  288. if norm_layer is None:
  289. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  290. self.norm1 = norm_layer(dim)
  291. self.norm2 = norm_layer(dim)
  292. mlp_hidden_dim = int(dim * mlp_ratio)
  293. self.attn = Qwen2VisionAttention(
  294. embed_dim=dim,
  295. num_heads=num_heads,
  296. projection_size=dim,
  297. quant_config=quant_config,
  298. )
  299. self.mlp = Qwen2VisionMLP(
  300. dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
  301. )
  302. def forward(
  303. self,
  304. x: torch.Tensor,
  305. cu_seqlens: torch.Tensor,
  306. rotary_pos_emb: torch.Tensor,
  307. ) -> torch.Tensor:
  308. x = x + self.attn(
  309. self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
  310. )
  311. x = x + self.mlp(self.norm2(x))
  312. return x
  313. class Qwen2VisionPatchEmbed(nn.Module):
  314. def __init__(
  315. self,
  316. patch_size: int = 14,
  317. temporal_patch_size: int = 2,
  318. in_chans: int = 3,
  319. embed_dim: int = 1152,
  320. ) -> None:
  321. super().__init__()
  322. self.patch_size = patch_size
  323. self.temporal_patch_size = temporal_patch_size
  324. self.embed_dim = embed_dim
  325. kernel_size = [temporal_patch_size, patch_size, patch_size]
  326. self.proj = nn.Conv3d(
  327. in_chans,
  328. embed_dim,
  329. kernel_size=kernel_size,
  330. stride=kernel_size,
  331. bias=False,
  332. )
  333. def forward(self, x: torch.Tensor) -> torch.Tensor:
  334. L, C = x.shape
  335. x = x.view(
  336. L, -1, self.temporal_patch_size, self.patch_size, self.patch_size
  337. )
  338. x = self.proj(x).view(L, self.embed_dim)
  339. return x
  340. class Qwen2VisionPatchMerger(nn.Module):
  341. def __init__(
  342. self,
  343. d_model: int,
  344. context_dim: int,
  345. norm_layer: Type[nn.Module] = None,
  346. spatial_merge_size: int = 2,
  347. quant_config: Optional[QuantizationConfig] = None,
  348. ) -> None:
  349. super().__init__()
  350. self.hidden_size = context_dim * (spatial_merge_size**2)
  351. if norm_layer is None:
  352. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  353. self.ln_q = norm_layer(context_dim)
  354. self.mlp = nn.ModuleList(
  355. [
  356. ColumnParallelLinear(
  357. self.hidden_size,
  358. self.hidden_size,
  359. bias=True,
  360. quant_config=quant_config,
  361. ),
  362. nn.GELU(),
  363. RowParallelLinear(
  364. self.hidden_size,
  365. d_model,
  366. bias=True,
  367. quant_config=quant_config,
  368. ),
  369. ]
  370. )
  371. def forward(self, x: torch.Tensor) -> torch.Tensor:
  372. x = self.ln_q(x)
  373. x = x.view(-1, self.hidden_size)
  374. mlp_fc1, mlp_act, mlp_fc2 = self.mlp
  375. x_parallel, _ = mlp_fc1(x)
  376. x_parallel = mlp_act(x_parallel)
  377. out, _ = mlp_fc2(x_parallel)
  378. return out
  379. class Qwen2VisionRotaryEmbedding(nn.Module):
  380. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  381. super().__init__()
  382. self.dim = dim
  383. self.theta = theta
  384. inv_freq = 1.0 / (
  385. theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
  386. )
  387. self.register_buffer("inv_freq", inv_freq, persistent=False)
  388. self._seq_len_cached = 0
  389. self._freqs_cached = None
  390. def update_freqs_cache(self, seqlen: int) -> None:
  391. if seqlen > self._seq_len_cached:
  392. seqlen *= 2
  393. self._seq_len_cached = seqlen
  394. self.inv_freq = 1.0 / (
  395. self.theta
  396. ** (
  397. torch.arange(
  398. 0,
  399. self.dim,
  400. 2,
  401. dtype=torch.float,
  402. device=self.inv_freq.device,
  403. )
  404. / self.dim
  405. )
  406. )
  407. seq = torch.arange(
  408. seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
  409. )
  410. freqs = torch.outer(seq, self.inv_freq)
  411. self._freqs_cached = freqs
  412. def forward(self, seqlen: int) -> torch.Tensor:
  413. self.update_freqs_cache(seqlen)
  414. return self._freqs_cached[:seqlen]
  415. class Qwen2VisionTransformer(nn.Module):
  416. def __init__(
  417. self,
  418. vision_config: Qwen2VLVisionConfig,
  419. norm_eps: float = 1e-6,
  420. quant_config: Optional[QuantizationConfig] = None,
  421. ) -> None:
  422. super().__init__()
  423. patch_size: int = vision_config.patch_size
  424. temporal_patch_size: int = vision_config.temporal_patch_size
  425. spatial_merge_size: int = vision_config.spatial_merge_size
  426. in_chans: int = vision_config.in_chans
  427. hidden_size: int = vision_config.hidden_size
  428. embed_dim: int = vision_config.embed_dim
  429. depth: int = vision_config.depth
  430. num_heads: int = vision_config.num_heads
  431. mlp_ratio: float = vision_config.mlp_ratio
  432. self.spatial_merge_size = spatial_merge_size
  433. self.patch_embed = Qwen2VisionPatchEmbed(
  434. patch_size=patch_size,
  435. temporal_patch_size=temporal_patch_size,
  436. in_chans=in_chans,
  437. embed_dim=embed_dim,
  438. )
  439. norm_layer = partial(nn.LayerNorm, eps=norm_eps)
  440. head_dim = embed_dim // num_heads
  441. self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
  442. self.blocks = nn.ModuleList(
  443. [
  444. Qwen2VisionBlock(
  445. dim=embed_dim,
  446. num_heads=num_heads,
  447. mlp_ratio=mlp_ratio,
  448. norm_layer=norm_layer,
  449. quant_config=quant_config,
  450. )
  451. for _ in range(depth)
  452. ]
  453. )
  454. self.merger = Qwen2VisionPatchMerger(
  455. d_model=hidden_size,
  456. context_dim=embed_dim,
  457. norm_layer=norm_layer,
  458. quant_config=quant_config,
  459. )
  460. @property
  461. def dtype(self) -> torch.dtype:
  462. return self.blocks[0].mlp.fc2.weight.dtype
  463. @property
  464. def device(self) -> torch.device:
  465. return self.blocks[0].mlp.fc2.weight.device
  466. def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
  467. pos_ids = []
  468. for t, h, w in grid_thw:
  469. hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
  470. wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
  471. hpos_ids = (
  472. hpos_ids.reshape(
  473. h // self.spatial_merge_size,
  474. self.spatial_merge_size,
  475. w // self.spatial_merge_size,
  476. self.spatial_merge_size,
  477. )
  478. .permute(0, 2, 1, 3)
  479. .flatten()
  480. )
  481. wpos_ids = (
  482. wpos_ids.reshape(
  483. h // self.spatial_merge_size,
  484. self.spatial_merge_size,
  485. w // self.spatial_merge_size,
  486. self.spatial_merge_size,
  487. )
  488. .permute(0, 2, 1, 3)
  489. .flatten()
  490. )
  491. pos_ids.append(
  492. torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
  493. )
  494. pos_ids = torch.cat(pos_ids, dim=0)
  495. max_grid_size = grid_thw[:, 1:].max()
  496. rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
  497. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
  498. return rotary_pos_emb
  499. def forward(
  500. self,
  501. x: torch.Tensor,
  502. grid_thw: torch.Tensor,
  503. ) -> torch.Tensor:
  504. # patchify
  505. x = x.to(device=self.device, dtype=self.dtype)
  506. x = self.patch_embed(x)
  507. # compute position embedding
  508. rotary_pos_emb = self.rot_pos_emb(grid_thw)
  509. # compute cu_seqlens
  510. cu_seqlens = torch.repeat_interleave(
  511. grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
  512. ).cumsum(dim=0, dtype=torch.int32)
  513. cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
  514. # transformers
  515. x = x.unsqueeze(1)
  516. for blk in self.blocks:
  517. x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
  518. # adapter
  519. x = self.merger(x)
  520. return x
  521. # === Vision input helpers === #
# Memoized wrapper around `get_processor` (functools.lru_cache with its
# default settings), so repeated lookups with the same arguments reuse the
# previously built processor.
cached_get_processor = lru_cache(get_processor)
  523. def mm_input_mapper_for_qwen2_vl(
  524. ctx: InputContext,
  525. data: MultiModalData[object],
  526. data_type_key: str,
  527. ) -> MultiModalInputs:
  528. """Input mapper for Qwen2-VL."""
  529. model_config = ctx.model_config
  530. image_processor = cached_get_image_processor(
  531. model_config.model, trust_remote_code=model_config.trust_remote_code
  532. )
  533. if image_processor is None:
  534. raise RuntimeError(
  535. "No HuggingFace processor is available "
  536. "to process the image object"
  537. )
  538. images = None
  539. videos = None
  540. if data_type_key == "image":
  541. images = data
  542. else:
  543. assert data_type_key == "video"
  544. videos = data
  545. try:
  546. batch_data = image_processor.preprocess(
  547. images=images, videos=videos, return_tensors="pt"
  548. ).data
  549. except Exception:
  550. logger.error(f"Failed to process image ({data})")
  551. raise
  552. return MultiModalInputs(batch_data)
# Modality-specific input mappers: `mm_input_mapper_for_qwen2_vl` with
# `data_type_key` pre-bound to "image" and "video" respectively.
image_input_mapper_for_qwen2_vl = partial(
    mm_input_mapper_for_qwen2_vl, data_type_key="image"
)
video_input_mapper_for_qwen2_vl = partial(
    mm_input_mapper_for_qwen2_vl, data_type_key="video"
)
  559. def _get_vision_info(
  560. image_processor,
  561. height: int,
  562. width: int,
  563. min_pixels: int,
  564. max_pixels: int,
  565. do_resize: bool = True,
  566. data_type_key: str = "image",
  567. mm_count: int = 1,
  568. ):
  569. """Get information (resized height / width and number of vision tokens)
  570. of input image / video frame."""
  571. if do_resize:
  572. resized_height, resized_width = smart_resize(
  573. height=height,
  574. width=width,
  575. factor=image_processor.patch_size * image_processor.merge_size,
  576. min_pixels=min_pixels,
  577. max_pixels=max_pixels,
  578. )
  579. else:
  580. resized_height, resized_width = height, width
  581. if data_type_key == "image":
  582. grid_t = mm_count
  583. else:
  584. assert data_type_key == "video"
  585. grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
  586. grid_h = resized_height // image_processor.patch_size
  587. grid_w = resized_width // image_processor.patch_size
  588. vision_tokens = grid_t * grid_h * grid_w
  589. llm_num_vision_tokens = (
  590. vision_tokens
  591. // image_processor.merge_size
  592. // image_processor.merge_size
  593. )
  594. return resized_height, resized_width, llm_num_vision_tokens
  595. def _get_max_image_info(
  596. image_processor,
  597. data_type_key: str = "image",
  598. mm_count: int = 1,
  599. ):
  600. return _get_vision_info(
  601. image_processor,
  602. height=9999999,
  603. width=9999999,
  604. # Limit min / max pixels.
  605. min_pixels=max(image_processor.min_pixels, 28 * 28),
  606. max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
  607. data_type_key=data_type_key,
  608. mm_count=mm_count,
  609. )
  610. def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
  611. image_processor = cached_get_image_processor(ctx.model_config.model)
  612. (
  613. max_resized_height,
  614. max_resized_width,
  615. max_llm_image_tokens,
  616. ) = _get_max_image_info(
  617. image_processor, data_type_key=data_type_key, mm_count=1
  618. )
  619. return max_llm_image_tokens
# Modality-specific max-token callbacks: `get_max_qwen2_vl_mm_tokens` with
# `data_type_key` pre-bound to "image" and "video" respectively.
get_max_qwen2_vl_image_tokens = partial(
    get_max_qwen2_vl_mm_tokens, data_type_key="image"
)
get_max_qwen2_vl_video_tokens = partial(
    get_max_qwen2_vl_mm_tokens, data_type_key="video"
)
  626. def dummy_data_for_qwen2_vl(
  627. ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
  628. ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
  629. image_processor = cached_get_image_processor(ctx.model_config.model)
  630. num_images = mm_counts["image"]
  631. (
  632. max_resized_height,
  633. max_resized_width,
  634. max_llm_image_tokens,
  635. ) = _get_max_image_info(
  636. image_processor, data_type_key="image", mm_count=num_images
  637. )
  638. if seq_len - max_llm_image_tokens - 2 < 0:
  639. raise RuntimeError(
  640. f"Qwen2-VL cannot process {num_images} images in a prompt, "
  641. "please increase max_model_len or reduce image limit by "
  642. "--limit-mm-per-prompt."
  643. )
  644. # Check video counts.
  645. num_videos = mm_counts["video"]
  646. (
  647. max_resized_height,
  648. max_resized_width,
  649. max_llm_video_tokens,
  650. ) = _get_max_image_info(
  651. image_processor, data_type_key="video", mm_count=num_videos
  652. )
  653. if seq_len - max_llm_video_tokens - 2 < 0:
  654. raise RuntimeError(
  655. f"Qwen2-VL cannot process {num_images} videos in a prompt, "
  656. "please increase max_model_len or reduce video limit by "
  657. "--limit-mm-per-prompt."
  658. )
  659. hf_config = ctx.get_hf_config(Qwen2VLConfig)
  660. token_ids = array(
  661. APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.vision_start_token_id]
  662. )
  663. token_ids += (
  664. array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.image_token_id])
  665. * max_llm_image_tokens
  666. )
  667. token_ids += array(
  668. APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.vision_end_token_id]
  669. )
  670. token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * (
  671. seq_len - max_llm_image_tokens - 2
  672. )
  673. dummy_seqdata = SequenceData(token_ids)
  674. dummy_image = Image.new(
  675. "RGB", (max_resized_width, max_resized_height), color=0
  676. )
  677. return dummy_seqdata, {
  678. "image": dummy_image if num_images == 1 else [dummy_image] * num_images
  679. }
  680. def _get_llm_num_vision_tokens(
  681. mm_inputs: list,
  682. data_type_key: str,
  683. image_processor,
  684. ):
  685. """Get number of vision tokens of multimodal inputs.
  686. This method is derived from `transformers.models.qwen2_vl.
  687. image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
  688. """
  689. image = to_numpy_array(mm_inputs[0])
  690. input_data_format = infer_channel_dimension_format(image)
  691. height, width = get_image_size(image, channel_dim=input_data_format)
  692. _, _, llm_num_vision_tokens = _get_vision_info(
  693. image_processor,
  694. height=height,
  695. width=width,
  696. min_pixels=image_processor.min_pixels,
  697. max_pixels=image_processor.max_pixels,
  698. do_resize=image_processor.do_resize,
  699. data_type_key=data_type_key,
  700. mm_count=len(mm_inputs),
  701. )
  702. return llm_num_vision_tokens
  703. def input_processor_for_qwen2_vl(
  704. ctx: InputContext, llm_inputs: LLMInputs
  705. ) -> LLMInputs:
  706. multi_modal_data = llm_inputs.get("multi_modal_data", None)
  707. if multi_modal_data is None:
  708. return llm_inputs
  709. image_inputs = multi_modal_data.get("image", None)
  710. video_inputs = multi_modal_data.get("video", None)
  711. processor = cached_get_processor(ctx.model_config.model)
  712. image_processor = processor.image_processor
  713. hf_config = ctx.get_hf_config(Qwen2VLConfig)
  714. # To avoid redundant processing of vision objects (resize, rescale, etc.),
  715. # we extract code of calculating number of vision tokens from
  716. # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
  717. #
  718. # The following code is equivalent to:
  719. # prompt = llm_inputs["prompt"]
  720. # inputs = processor(text=[prompt],
  721. # images=image_inputs,
  722. # videos=video_inputs,
  723. # padding=True,
  724. # return_tensors="pt")
  725. # prompt_token_ids = inputs["input_ids"][0].tolist()
  726. prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
  727. if prompt_token_ids is None:
  728. prompt = llm_inputs["prompt"]
  729. prompt_token_ids = processor.tokenizer(
  730. prompt,
  731. padding=True,
  732. return_tensors=None,
  733. )["input_ids"]
  734. # Expand image pad tokens.
  735. if image_inputs is not None:
  736. image_indices = [
  737. idx
  738. for idx, token in enumerate(prompt_token_ids)
  739. if token == hf_config.image_token_id
  740. ]
  741. image_inputs = make_batched_images(image_inputs)
  742. assert len(image_indices) == len(image_inputs)
  743. prompt_token_ids_with_image = []
  744. for image_cnt, image in enumerate(image_inputs):
  745. num_image_tokens = _get_llm_num_vision_tokens(
  746. [image],
  747. data_type_key="image",
  748. image_processor=image_processor,
  749. )
  750. if image_cnt == 0:
  751. non_image_tokens = prompt_token_ids[: image_indices[image_cnt]]
  752. else:
  753. non_image_tokens = prompt_token_ids[
  754. image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
  755. ]
  756. prompt_token_ids_with_image.extend(non_image_tokens)
  757. prompt_token_ids_with_image.extend(
  758. hf_config.image_token_id for _ in range(num_image_tokens)
  759. )
  760. prompt_token_ids_with_image.extend(
  761. prompt_token_ids[image_indices[-1] + 1 :]
  762. )
  763. prompt_token_ids = prompt_token_ids_with_image
  764. # Expand video pad tokens.
  765. if video_inputs is not None:
  766. video_indices = [
  767. idx
  768. for idx, token in enumerate(prompt_token_ids)
  769. if token == hf_config.video_token_id
  770. ]
  771. video_inputs = make_batched_videos(video_inputs)
  772. assert len(video_indices) == len(video_inputs)
  773. prompt_token_ids_with_video = []
  774. for video_cnt, video in enumerate(video_inputs):
  775. num_video_tokens = _get_llm_num_vision_tokens(
  776. video,
  777. data_type_key="video",
  778. image_processor=image_processor,
  779. )
  780. if video_cnt == 0:
  781. non_video_tokens = prompt_token_ids[: video_indices[video_cnt]]
  782. else:
  783. non_video_tokens = prompt_token_ids[
  784. video_indices[video_cnt - 1] + 1 : video_indices[video_cnt]
  785. ]
  786. prompt_token_ids_with_video.extend(non_video_tokens)
  787. prompt_token_ids_with_video.extend(
  788. hf_config.video_token_id for _ in range(num_video_tokens)
  789. )
  790. prompt_token_ids_with_video.extend(
  791. prompt_token_ids[video_indices[-1] + 1 :]
  792. )
  793. prompt_token_ids = prompt_token_ids_with_video
  794. return LLMInputs(
  795. prompt_token_ids=prompt_token_ids,
  796. prompt=llm_inputs["prompt"],
  797. multi_modal_data=multi_modal_data,
  798. )
  799. @MULTIMODAL_REGISTRY.register_image_input_mapper(
  800. image_input_mapper_for_qwen2_vl
  801. )
  802. @MULTIMODAL_REGISTRY.register_input_mapper(
  803. "video", video_input_mapper_for_qwen2_vl
  804. )
  805. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
  806. @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
  807. "video", get_max_qwen2_vl_video_tokens
  808. )
  809. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
  810. @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
  811. class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
  812. def __init__(
  813. self,
  814. config: Qwen2VLConfig,
  815. multimodal_config: MultiModalConfig,
  816. cache_config: Optional[CacheConfig] = None,
  817. quant_config: Optional[QuantizationConfig] = None,
  818. ) -> None:
  819. super().__init__()
  820. assert (
  821. not cache_config.enable_prefix_caching
  822. ), "Qwen2-VL currently does not support prefix caching"
  823. self.config = config
  824. self.multimodal_config = multimodal_config
  825. self.visual = Qwen2VisionTransformer(
  826. config.vision_config,
  827. norm_eps=getattr(config, "rms_norm_eps", 1e-6),
  828. # NOTE: Qwen2-VL vision encoder does not support any
  829. # quantization method now.
  830. quant_config=None,
  831. )
  832. self.model = Qwen2Model(config, cache_config, quant_config)
  833. if config.tie_word_embeddings:
  834. self.lm_head = self.model.embed_tokens
  835. else:
  836. self.lm_head = ParallelLMHead(
  837. config.vocab_size, config.hidden_size, quant_config=quant_config
  838. )
  839. self.logits_processor = LogitsProcessor(config.vocab_size)
  840. self.sampler = Sampler()
  841. def _validate_and_reshape_mm_tensor(
  842. self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str
  843. ) -> torch.Tensor:
  844. if not isinstance(mm_input, (torch.Tensor, list)):
  845. raise ValueError(
  846. f"Incorrect type of {name}. " f"Got type: {type(mm_input)}"
  847. )
  848. if isinstance(mm_input, torch.Tensor):
  849. if mm_input.ndim == 2:
  850. return mm_input
  851. if mm_input.ndim != 3:
  852. raise ValueError(
  853. f"{name} should be 2D or batched 3D tensor. "
  854. f"Got ndim: {mm_input.ndim}"
  855. )
  856. return torch.concat(list(mm_input))
  857. else:
  858. return torch.concat(mm_input)
  def _parse_and_validate_image_input(
      self, **kwargs: object
  ) -> Optional[Qwen2VLImageInputs]:
      """Extract and normalize image inputs from the forward kwargs.

      Reads the ``pixel_values`` and ``image_grid_thw`` keyword arguments.

      Returns:
          ``None`` when ``pixel_values`` is absent or ``None``; otherwise a
          ``Qwen2VLImageInputs`` holding both values after they have been
          passed through ``_validate_and_reshape_mm_tensor``.

      Raises:
          ValueError: Propagated from ``_validate_and_reshape_mm_tensor``
              when either value has an unsupported type or rank (this
              includes ``image_grid_thw`` being ``None`` while
              ``pixel_values`` is provided).
      """
      pixel_values = kwargs.pop("pixel_values", None)
      image_grid_thw = kwargs.pop("image_grid_thw", None)

      if pixel_values is None:
          return None

      # The helper either returns a tensor or raises, so no further type
      # check is needed afterwards.
      pixel_values = self._validate_and_reshape_mm_tensor(
          pixel_values, "image pixel values"
      )
      image_grid_thw = self._validate_and_reshape_mm_tensor(
          image_grid_thw, "image grid_thw"
      )

      return Qwen2VLImageInputs(
          pixel_values=pixel_values, image_grid_thw=image_grid_thw
      )
  def _parse_and_validate_video_input(
      self, **kwargs: object
  ) -> Optional[Qwen2VLVideoInputs]:
      """Extract and normalize video inputs from the forward kwargs.

      Reads the ``pixel_values_videos`` and ``video_grid_thw`` keyword
      arguments.

      Returns:
          ``None`` when ``pixel_values_videos`` is absent or ``None``;
          otherwise a ``Qwen2VLVideoInputs`` holding both values after they
          have been passed through ``_validate_and_reshape_mm_tensor``.
      """
      raw_frames = kwargs.pop("pixel_values_videos", None)
      raw_grid = kwargs.pop("video_grid_thw", None)
      if raw_frames is None:
          return None

      normalize = self._validate_and_reshape_mm_tensor
      # Keyword arguments are evaluated left to right, so the pixel values
      # are validated before the grid.
      return Qwen2VLVideoInputs(
          pixel_values_videos=normalize(raw_frames, "video pixel values"),
          video_grid_thw=normalize(raw_grid, "video grid_thw"),
      )
  def _process_image_input(
      self, image_input: Qwen2VLImageInputs
  ) -> torch.Tensor:
      """Run the vision encoder on parsed image inputs.

      The pixel values are cast to the vision encoder's dtype, then passed
      to ``self.visual`` together with ``image_grid_thw``.

      Returns:
          Whatever ``self.visual`` returns for these inputs.
      """
      pixels = image_input["pixel_values"]
      pixels = pixels.type(self.visual.dtype)
      grid = image_input["image_grid_thw"]
      return self.visual(pixels, grid_thw=grid)
  def _process_video_input(
      self, video_input: Qwen2VLVideoInputs
  ) -> torch.Tensor:
      """Run the vision encoder on parsed video inputs.

      The video pixel values are cast to the vision encoder's dtype, then
      passed to ``self.visual`` together with ``video_grid_thw``.

      Returns:
          Whatever ``self.visual`` returns for these inputs.
      """
      frames = video_input["pixel_values_videos"]
      frames = frames.type(self.visual.dtype)
      grid = video_input["video_grid_thw"]
      return self.visual(frames, grid_thw=grid)
  def _merge_multimodal_embeddings(
      self,
      input_ids: torch.Tensor,
      inputs_embeds: torch.Tensor,
      multimodal_embeddings: torch.Tensor,
      placeholder_token_id: int,
  ) -> torch.Tensor:
      """Write multimodal embeddings over the placeholder-token positions.

      ``inputs_embeds`` is modified in place: every position where
      ``input_ids`` equals ``placeholder_token_id`` is overwritten with
      ``multimodal_embeddings``.

      Returns:
          The same ``inputs_embeds`` tensor object that was passed in.
      """
      # NOTE(review): assumes the number of placeholder tokens matches the
      # number of rows in ``multimodal_embeddings`` -- confirm with caller.
      is_placeholder = torch.eq(input_ids, placeholder_token_id)
      inputs_embeds[is_placeholder, :] = multimodal_embeddings
      return inputs_embeds
  def forward(
      self,
      input_ids: torch.Tensor,
      positions: torch.Tensor,
      kv_caches: List[torch.Tensor],
      attn_metadata: AttentionMetadata,
      intermediate_tensors: Optional[IntermediateTensors] = None,
      **kwargs: object,
  ) -> SamplerOutput:
      """Run forward pass for Qwen2-VL.

      Args:
          input_ids: Flattened (concatenated) input_ids corresponding to a
              batch.
          positions: Flattened (concatenated) position ids corresponding to a
              batch.
              **NOTE**: If mrope is enabled (default setting for Qwen2-VL
              opensource models), the shape will be `(3, seq_len)`,
              otherwise it will be `(seq_len,)`.
          kv_caches: Forwarded unchanged to the language model.
          attn_metadata: Forwarded unchanged to the language model.
          intermediate_tensors: Accepted for interface compatibility; not
              used by this method.
          pixel_values: Pixel values to be fed to a model.
              `None` if no images are passed.
          image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
              `None` if no images are passed.
          pixel_values_videos: Pixel values of videos to be fed to a model.
              `None` if no videos are passed.
          video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
              `None` if no videos are passed.

      Returns:
          The hidden states returned by ``self.model``.
          NOTE(review): the declared return annotation is
          ``SamplerOutput``, which does not match -- confirm and correct.
      """
      image_input = self._parse_and_validate_image_input(**kwargs)
      video_input = self._parse_and_validate_video_input(**kwargs)

      if image_input is None and video_input is None:
          # Text-only batch: let the language model embed ``input_ids``.
          inputs_embeds = None
      else:
          # ``rope_scaling`` may be missing *or* explicitly ``None`` on the
          # config; ``getattr``'s default only covers the former.
          rope_scaling = getattr(self.config, "rope_scaling", None) or {}
          if rope_scaling.get("type", None) == "mrope":
              assert positions.ndim == 2 and positions.size(0) == 3, (
                  "multimodal section rotary embedding requires "
                  f"(3, seq_len) positions, but got {positions.size()}"
              )

          inputs_embeds = self.model.embed_tokens(input_ids)

          if image_input is not None:
              image_embeds = self._process_image_input(image_input)
              inputs_embeds = self._merge_multimodal_embeddings(
                  input_ids,
                  inputs_embeds,
                  image_embeds,
                  placeholder_token_id=self.config.image_token_id,
              )

          if video_input is not None:
              video_embeds = self._process_video_input(video_input)
              inputs_embeds = self._merge_multimodal_embeddings(
                  input_ids,
                  inputs_embeds,
                  video_embeds,
                  placeholder_token_id=self.config.video_token_id,
              )

          # The merged embeddings replace the token ids from here on.
          input_ids = None

      hidden_states = self.model(
          input_ids=input_ids,
          positions=positions,
          kv_caches=kv_caches,
          attn_metadata=attn_metadata,
          inputs_embeds=inputs_embeds,
      )
      return hidden_states
  def compute_logits(
      self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
  ) -> torch.Tensor:
      """Compute logits from hidden states.

      Delegates to ``self.logits_processor``, passing the LM head, the
      hidden states and the sampling metadata, and returns its result.
      """
      return self.logits_processor(
          self.lm_head, hidden_states, sampling_metadata
      )
  def sample(
      self,
      logits: torch.Tensor,
      sampling_metadata: SamplingMetadata,
  ) -> Optional[SamplerOutput]:
      """Sample next tokens from ``logits``.

      Delegates to ``self.sampler`` with the logits and the sampling
      metadata, and returns its result.
      """
      return self.sampler(logits, sampling_metadata)
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
      """Load checkpoint tensors into this model's parameters.

      Handling per checkpoint entry, in order:

      * Names containing ``rotary_emb.inv_freq`` are skipped.
      * ``lm_head.weight`` is skipped when the config ties word embeddings.
      * Names containing ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are loaded into the fused ``qkv_proj`` /
        ``gate_up_proj`` parameter through that parameter's
        ``weight_loader``, with the matching shard id.
      * Vision ``qkv.weight`` / ``qkv.bias`` tensors are regrouped from a
        ``(3, num_heads, head_size, ...)`` layout to
        ``(num_heads, 3, head_size, ...)`` before loading.
      * Everything else is loaded through the parameter's own
        ``weight_loader`` if it has one, else ``default_weight_loader``.

      A ``.bias`` entry with no matching parameter is skipped silently
      (extra bias in GPTQ checkpoints).

      Args:
          weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

      Raises:
          ValueError: If a non-stacked, non-bias name has no matching
              parameter.
          KeyError: If a stacked, non-bias name has no matching fused
              parameter.
      """
      stacked_params_mapping = [
          # (param_name, shard_name, shard_id)
          ("qkv_proj", "q_proj", "q"),
          ("qkv_proj", "k_proj", "k"),
          ("qkv_proj", "v_proj", "v"),
          ("gate_up_proj", "up_proj", 1),
          ("gate_up_proj", "gate_proj", 0),
      ]
      params_dict = dict(self.named_parameters(remove_duplicate=False))

      for name, loaded_weight in weights:
          if "rotary_emb.inv_freq" in name:
              continue
          if self.config.tie_word_embeddings and "lm_head.weight" in name:
              continue

          for param_name, weight_name, shard_id in stacked_params_mapping:
              if weight_name not in name:
                  continue
              # Keep ``name`` intact; only the lookup key is rewritten.
              stacked_name = name.replace(weight_name, param_name)
              # Skip loading extra bias for GPTQ models. ``break`` leaves
              # the loop without running the ``else`` branch, so this
              # checkpoint entry is dropped explicitly.
              if (
                  stacked_name.endswith(".bias")
                  and stacked_name not in params_dict
              ):
                  break
              param = params_dict[stacked_name]
              param.weight_loader(param, loaded_weight, shard_id)
              break
          else:
              # No stacked mapping matched: load the tensor directly.
              is_visual = "visual" in name
              if is_visual and ("qkv.weight" in name or "qkv.bias" in name):
                  vision_config = self.config.vision_config
                  num_heads = vision_config.num_heads
                  embed_dim = vision_config.embed_dim
                  head_size = embed_dim // num_heads
                  # Regroup from (3, heads, ...) to (heads, 3, ...), then
                  # flatten back to the original rank.
                  if "qkv.weight" in name:
                      loaded_weight = (
                          loaded_weight.view(3, num_heads, head_size, embed_dim)
                          .transpose(0, 1)
                          .reshape(-1, embed_dim)
                      )
                  else:
                      loaded_weight = (
                          loaded_weight.view(3, num_heads, head_size)
                          .transpose(0, 1)
                          .reshape(-1)
                      )

              if name not in params_dict:
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias"):
                      continue
                  raise ValueError(f"Unexpected weight: {name}")

              param = params_dict[name]
              weight_loader = getattr(
                  param, "weight_loader", default_weight_loader
              )
              weight_loader(param, loaded_weight)