# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
                    Union)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from loguru import logger
from PIL import Image
from transformers.image_utils import (get_image_size,
                                      infer_channel_dimension_format,
                                      to_numpy_array)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    make_batched_images, make_batched_videos, smart_resize)

import aphrodite.common.envs as envs
from aphrodite.attention import AttentionMetadata
from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
                                          get_global_forced_attn_backend)
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.logger import log_once
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.distributed import get_pp_group, parallel_state
from aphrodite.distributed import utils as dist_utils
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import QuickGELU
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.qwen2 import Qwen2Model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                                  MultiModalInputs)
from aphrodite.multimodal.base import MultiModalData
from aphrodite.multimodal.image import cached_get_image_processor
from aphrodite.platforms import current_platform
from aphrodite.quantization import QuantizationConfig
from aphrodite.transformers_utils.configs import (Qwen2VLConfig,
                                                  Qwen2VLVisionConfig)
from aphrodite.transformers_utils.processor import get_processor

from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory)

# === Vision Inputs === #


class Qwen2VLImageInputs(TypedDict):
    pixel_values: torch.Tensor
    """Shape:
    `(num_patches, num_channels * patch_size * patch_size)`
    """

    image_grid_thw: torch.Tensor
    """Shape: `(num_images, 3)`

    This should be in `(grid_t, grid_h, grid_w)` format.
    """


class Qwen2VLVideoInputs(TypedDict):
    pixel_values_videos: torch.Tensor
    """Shape:
    `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`
    """

    video_grid_thw: torch.Tensor
    """Shape: `(num_videos, 3)`

    This should be in `(grid_t, grid_h, grid_w)` format.
    """

# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, quant_config=quant_config
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features, in_features, quant_config=quant_config
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x
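
# The helpers below implement rotary position embeddings in plain PyTorch for
# the vision tower. `rotate_half` supports both the half-split layout
# (first/second half of the head dim) and the interleaved layout (even/odd
# channels); `apply_rotary_emb_torch` only rotates the leading `rotary_dim`
# channels of each head and passes the remaining channels through unchanged.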

def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
    )
    sin = repeat(
        sin,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos
            + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(
    t: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
    return output
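
# Qwen2VisionAttention picks its attention kernel at construction time: a
# globally forced backend (or the APHRODITE_ATTENTION_BACKEND environment
# variable) takes precedence; otherwise flash-attention is used on GPUs with
# compute capability >= 8.0 when the `flash-attn` package is importable, and
# xformers memory-efficient attention is used as the fallback.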

class Qwen2VisionAttention(nn.Module):

    def __init__(
        self,
        embed_dim: Optional[int] = None,
        num_heads: Optional[int] = None,
        projection_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
        )

        # Detect attention implementation.
        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
        if selected_backend is None:
            backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
            if backend_by_env_var is not None:
                selected_backend = backend_name_to_enum(backend_by_env_var)
        if selected_backend is None:
            # For Volta and Turing GPUs, use xformers instead.
            device_available = current_platform.get_device_capability()[0] >= 8
            if device_available:
                from transformers.utils import is_flash_attn_2_available

                if is_flash_attn_2_available():
                    self._use_flash_attn = True
                else:
                    log_once(
                        level="WARNING",
                        message=
                        "Current Qwen2-VL implementation has a bug with "
                        "`aphrodite-flash-attn` inside vision module, so we use"
                        " xformers backend instead. You can run `pip install "
                        "flash-attn` to use the flash-attention backend.",
                    )
                    self._use_flash_attn = False
            else:
                self._use_flash_attn = False
        else:
            if selected_backend == _Backend.FLASH_ATTN:
                self._use_flash_attn = True
            elif selected_backend == _Backend.XFORMERS:
                self._use_flash_attn = False
            else:
                raise RuntimeError(
                    f"Qwen2-VL does not support {selected_backend} backend now."
                )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        x = x.view(*new_x_shape)

        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
        batch_size = q.shape[1]

        q, k, v = [
            rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
        ]
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self._use_flash_attn:
            # from aphrodite_flash_attn.flash_attn_interface import (
            #     flash_attn_varlen_func)
            from flash_attn import flash_attn_varlen_func

            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(
                output, "(b s) ... -> b s ...", b=batch_size
            )
        else:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None
            )

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
        context_layer = rearrange(
            context_layer, "b s h d -> s b (h d)"
        ).contiguous()

        output, _ = self.proj(context_layer)
        return output
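
# Each vision block below follows the standard pre-norm transformer layout:
# LayerNorm -> attention -> residual add, then LayerNorm -> MLP -> residual
# add, with the same cu_seqlens and rotary embeddings threaded through every
# block.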

class Qwen2VisionBlock(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = Qwen2VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
        )
        self.mlp = Qwen2VisionMLP(
            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
    ) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        x = x + self.mlp(self.norm2(x))
        return x

class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x = x.view(
            L, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        x = self.proj(x).view(L, self.embed_dim)
        return x
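
# Qwen2VisionPatchMerger maps the vision encoder output into the language
# model's hidden size. With spatial_merge_size=2 (the default here) it
# concatenates groups of four consecutive patch features, which the
# preprocessing lays out as 2x2 spatial neighbours (context_dim * 4 values),
# and passes them through a two-layer MLP, so the LLM sees one vision token
# per four image patches.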

class Qwen2VisionPatchMerger(nn.Module):

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Type[nn.Module] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out

class Qwen2VisionRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        if seqlen > self._seq_len_cached:
            seqlen *= 2
            self._seq_len_cached = seqlen
            self.inv_freq = 1.0 / (
                self.theta
                ** (
                    torch.arange(
                        0,
                        self.dim,
                        2,
                        dtype=torch.float,
                        device=self.inv_freq.device,
                    )
                    / self.dim
                )
            )
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
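
# Qwen2VisionTransformer ties the pieces above together: pixel patches are
# embedded with a Conv3d, 2D rotary embeddings are built from each item's
# (t, h, w) grid, and cu_seqlens (cumulative patch counts per temporal frame)
# lets the varlen / block-diagonal attention kernels treat the concatenated
# images and videos as independent sequences within one batch. The merger
# then projects the result into the language model's hidden size.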

class Qwen2VisionTransformer(nn.Module):

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        embed_dim: int = vision_config.embed_dim
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        mlp_ratio: float = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                )
                for _ in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
            )
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # adapter
        x = self.merger(x)
        return x

# === Vision input helpers === #

cached_get_processor = lru_cache(get_processor)


def mm_input_mapper_for_qwen2_vl(
    ctx: InputContext,
    data: MultiModalData[object],
    data_type_key: str,
) -> MultiModalInputs:
    """Input mapper for Qwen2-VL."""
    model_config = ctx.model_config
    image_processor = cached_get_image_processor(
        model_config.model, trust_remote_code=model_config.trust_remote_code
    )
    if image_processor is None:
        raise RuntimeError(
            "No HuggingFace processor is available "
            "to process the image object"
        )

    images = None
    videos = None
    if data_type_key == "image":
        images = data
    else:
        assert data_type_key == "video"
        videos = data

    try:
        batch_data = image_processor.preprocess(
            images=images, videos=videos, return_tensors="pt"
        ).data
    except Exception:
        logger.error(f"Failed to process image ({data})")
        raise

    return MultiModalInputs(batch_data)


image_input_mapper_for_qwen2_vl = partial(
    mm_input_mapper_for_qwen2_vl, data_type_key="image"
)
video_input_mapper_for_qwen2_vl = partial(
    mm_input_mapper_for_qwen2_vl, data_type_key="video"
)
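
# _get_vision_info below mirrors the HF image processor's resizing logic. As a
# worked example (assuming the usual Qwen2-VL values of patch_size=14 and
# merge_size=2): a 672x672 image yields a 48x48 patch grid, i.e. 2304 vision
# patches, which the 2x2 merge reduces to 2304 // 4 = 576 LLM tokens.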

def _get_vision_info(
    image_processor,
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    do_resize: bool = True,
    data_type_key: str = "image",
    mm_count: int = 1,
):
    """Get information (resized height / width and number of vision tokens)
    of input image / video frame."""

    if do_resize:
        resized_height, resized_width = smart_resize(
            height=height,
            width=width,
            factor=image_processor.patch_size * image_processor.merge_size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    else:
        resized_height, resized_width = height, width

    if data_type_key == "image":
        grid_t = mm_count
    else:
        assert data_type_key == "video"
        grid_t = max(mm_count // image_processor.temporal_patch_size, 1)

    grid_h = resized_height // image_processor.patch_size
    grid_w = resized_width // image_processor.patch_size
    vision_tokens = grid_t * grid_h * grid_w
    llm_num_vision_tokens = (
        vision_tokens
        // image_processor.merge_size
        // image_processor.merge_size
    )

    return resized_height, resized_width, llm_num_vision_tokens


def _get_max_image_info(
    image_processor,
    data_type_key: str = "image",
    mm_count: int = 1,
):
    return _get_vision_info(
        image_processor,
        height=9999999,
        width=9999999,
        # Limit min / max pixels.
        min_pixels=max(image_processor.min_pixels, 28 * 28),
        max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
        data_type_key=data_type_key,
        mm_count=mm_count,
    )


def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
    image_processor = cached_get_image_processor(ctx.model_config.model)
    (
        max_resized_height,
        max_resized_width,
        max_llm_image_tokens,
    ) = _get_max_image_info(
        image_processor, data_type_key=data_type_key, mm_count=1
    )
    return max_llm_image_tokens


get_max_qwen2_vl_image_tokens = partial(
    get_max_qwen2_vl_mm_tokens, data_type_key="image"
)
get_max_qwen2_vl_video_tokens = partial(
    get_max_qwen2_vl_mm_tokens, data_type_key="video"
)
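
# dummy_data_for_qwen2_vl builds a worst-case dummy sequence (a vision start
# token, the maximum number of image placeholder tokens, a vision end token,
# then padding) plus a black PIL image at the maximum resolved resolution.
# Engines in the vLLM family typically consume this dummy data during memory
# profiling, so the function also checks up front that the configured
# image/video limits actually fit within max_model_len.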

def dummy_data_for_qwen2_vl(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    image_processor = cached_get_image_processor(ctx.model_config.model)

    num_images = mm_counts["image"]
    (
        max_resized_height,
        max_resized_width,
        max_llm_image_tokens,
    ) = _get_max_image_info(
        image_processor, data_type_key="image", mm_count=num_images
    )
    if seq_len - max_llm_image_tokens - 2 < 0:
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_images} images in a prompt, "
            "please increase max_model_len or reduce image limit by "
            "--limit-mm-per-prompt."
        )

    # Check video counts.
    num_videos = mm_counts["video"]
    (
        max_resized_height,
        max_resized_width,
        max_llm_video_tokens,
    ) = _get_max_image_info(
        image_processor, data_type_key="video", mm_count=num_videos
    )
    if seq_len - max_llm_video_tokens - 2 < 0:
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
            "please increase max_model_len or reduce video limit by "
            "--limit-mm-per-prompt."
        )

    hf_config = ctx.get_hf_config(Qwen2VLConfig)
    dummy_seqdata = SequenceData.from_token_counts(
        (hf_config.vision_start_token_id, 1),
        (hf_config.image_token_id, max_llm_image_tokens),
        (hf_config.vision_end_token_id, 1),
        (0, seq_len - max_llm_image_tokens - 2),
    )
    dummy_image = Image.new(
        "RGB", (max_resized_width, max_resized_height), color=0
    )

    return dummy_seqdata, {
        "image": dummy_image if num_images == 1 else [dummy_image] * num_images
    }

def _get_llm_num_vision_tokens(
    mm_inputs: list,
    data_type_key: str,
    image_processor,
):
    """Get number of vision tokens of multimodal inputs.

    This method is derived from `transformers.models.qwen2_vl.
    image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
    """
    image = to_numpy_array(mm_inputs[0])
    input_data_format = infer_channel_dimension_format(image)
    height, width = get_image_size(image, channel_dim=input_data_format)
    _, _, llm_num_vision_tokens = _get_vision_info(
        image_processor,
        height=height,
        width=width,
        min_pixels=image_processor.min_pixels,
        max_pixels=image_processor.max_pixels,
        do_resize=image_processor.do_resize,
        data_type_key=data_type_key,
        mm_count=len(mm_inputs),
    )
    return llm_num_vision_tokens
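
# input_processor_for_qwen2_vl expands each single image/video placeholder
# token in the tokenized prompt into the exact number of placeholder tokens
# the vision tower will emit for that item. Illustratively, a prompt with one
# image placeholder and a 576-token image becomes a prompt containing 576
# copies of the image token id at that position, so the merged vision
# embeddings line up one-to-one with the placeholder positions at forward
# time.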

def input_processor_for_qwen2_vl(
    ctx: InputContext, llm_inputs: LLMInputs
) -> LLMInputs:
    multi_modal_data = llm_inputs.get("multi_modal_data", None)
    if multi_modal_data is None:
        return llm_inputs

    image_inputs = multi_modal_data.get("image", None)
    video_inputs = multi_modal_data.get("video", None)

    processor = cached_get_processor(ctx.model_config.model)
    image_processor = processor.image_processor
    hf_config = ctx.get_hf_config(Qwen2VLConfig)

    # To avoid redundant processing of vision objects (resize, rescale, etc.),
    # we extract code of calculating number of vision tokens from
    # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
    #
    # The following code is equivalent to:
    #    prompt = llm_inputs["prompt"]
    #    inputs = processor(text=[prompt],
    #                       images=image_inputs,
    #                       videos=video_inputs,
    #                       padding=True,
    #                       return_tensors="pt")
    #    prompt_token_ids = inputs["input_ids"][0].tolist()

    prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
    if prompt_token_ids is None:
        prompt = llm_inputs["prompt"]
        prompt_token_ids = processor.tokenizer(
            prompt,
            padding=True,
            return_tensors=None,
        )["input_ids"]

    # Expand image pad tokens.
    if image_inputs is not None:
        image_indices = [
            idx
            for idx, token in enumerate(prompt_token_ids)
            if token == hf_config.image_token_id
        ]
        image_inputs = make_batched_images(image_inputs)
        assert len(image_indices) == len(image_inputs)

        prompt_token_ids_with_image = []
        for image_cnt, image in enumerate(image_inputs):
            num_image_tokens = _get_llm_num_vision_tokens(
                [image],
                data_type_key="image",
                image_processor=image_processor,
            )
            if image_cnt == 0:
                non_image_tokens = prompt_token_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = prompt_token_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            prompt_token_ids_with_image.extend(non_image_tokens)
            prompt_token_ids_with_image.extend(
                hf_config.image_token_id for _ in range(num_image_tokens)
            )
        prompt_token_ids_with_image.extend(
            prompt_token_ids[image_indices[-1] + 1 :]
        )
        prompt_token_ids = prompt_token_ids_with_image

    # Expand video pad tokens.
    if video_inputs is not None:
        video_indices = [
            idx
            for idx, token in enumerate(prompt_token_ids)
            if token == hf_config.video_token_id
        ]
        video_inputs = make_batched_videos(video_inputs)
        assert len(video_indices) == len(video_inputs)

        prompt_token_ids_with_video = []
        for video_cnt, video in enumerate(video_inputs):
            num_video_tokens = _get_llm_num_vision_tokens(
                video,
                data_type_key="video",
                image_processor=image_processor,
            )
            if video_cnt == 0:
                non_video_tokens = prompt_token_ids[: video_indices[video_cnt]]
            else:
                non_video_tokens = prompt_token_ids[
                    video_indices[video_cnt - 1] + 1 : video_indices[video_cnt]
                ]
            prompt_token_ids_with_video.extend(non_video_tokens)
            prompt_token_ids_with_video.extend(
                hf_config.video_token_id for _ in range(num_video_tokens)
            )
        prompt_token_ids_with_video.extend(
            prompt_token_ids[video_indices[-1] + 1 :]
        )
        prompt_token_ids = prompt_token_ids_with_video

    return LLMInputs(
        prompt_token_ids=prompt_token_ids,
        prompt=llm_inputs["prompt"],
        multi_modal_data=multi_modal_data,
    )
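
# The decorators below wire this model into Aphrodite's multimodal machinery:
# they register the image/video input mappers (which run the HF image
# processor), the hooks reporting the worst-case number of vision tokens per
# item, the dummy-data factory, and the prompt-expansion input processor
# defined above.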

@MULTIMODAL_REGISTRY.register_image_input_mapper(
    image_input_mapper_for_qwen2_vl
)
@MULTIMODAL_REGISTRY.register_input_mapper(
    "video", video_input_mapper_for_qwen2_vl
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_qwen2_vl_video_tokens
)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: Qwen2VLConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        assert (
            not cache_config.enable_prefix_caching
        ), "Qwen2-VL currently does not support prefix caching"

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
        )
        self.model = Qwen2Model(config, cache_config, quant_config)

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config)
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def _validate_and_reshape_mm_tensor(
        self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str
    ) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of {name}. Got type: {type(mm_input)}"
            )
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim}"
                )
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)
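
    # The two parsers below pull the mapper outputs (pixel values plus their
    # (t, h, w) grids) out of **kwargs. Batched 3D tensors and lists of 2D
    # tensors are both flattened into a single 2D patch tensor so the vision
    # tower can process all items in one call.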

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Qwen2VLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None:
            return None

        pixel_values = self._validate_and_reshape_mm_tensor(
            pixel_values, "image pixel values"
        )
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw"
        )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of image pixel values. "
                f"Got type: {type(pixel_values)}"
            )

        return Qwen2VLImageInputs(
            pixel_values=pixel_values, image_grid_thw=image_grid_thw
        )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Optional[Qwen2VLVideoInputs]:
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None:
            return None

        pixel_values_videos = self._validate_and_reshape_mm_tensor(
            pixel_values_videos, "video pixel values"
        )
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw"
        )

        return Qwen2VLVideoInputs(
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

    def _process_image_input(
        self, image_input: Qwen2VLImageInputs
    ) -> torch.Tensor:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(
            pixel_values, grid_thw=image_input["image_grid_thw"]
        )
        return image_embeds

    def _process_video_input(
        self, video_input: Qwen2VLVideoInputs
    ) -> torch.Tensor:
        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.visual.dtype
        )
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def _merge_multimodal_embeddings(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: torch.Tensor,
        placeholder_token_id: int,
    ) -> torch.Tensor:
        mask = input_ids == placeholder_token_id
        inputs_embeds[mask, :] = multimodal_embeddings
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                open-source models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                `None` if no images are passed.
            pixel_values_videos: Pixel values of videos to be fed to a model.
                `None` if no videos are passed.
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)

        if (image_input is None
                and video_input is None) or not get_pp_group().is_first_rank:
            inputs_embeds = None
        else:
            if (
                getattr(self.config, "rope_scaling", {}).get("type", None)
                == "mrope"
            ):
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

            inputs_embeds = self.model.embed_tokens(input_ids)

            if image_input is not None:
                image_embeds = self._process_image_input(image_input)
                inputs_embeds = self._merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    image_embeds,
                    placeholder_token_id=self.config.image_token_id,
                )

            if video_input is not None:
                video_embeds = self._process_video_input(video_input)
                inputs_embeds = self._merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    video_embeds,
                    placeholder_token_id=self.config.video_token_id,
                )

            input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
    ) -> torch.Tensor:
        logits = self.logits_processor(
            self.lm_head, hidden_states, sampling_metadata
        )
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
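
    # load_weights maps HuggingFace checkpoint names onto this module tree.
    # Language-model q/k/v and gate/up projections are folded into the stacked
    # qkv_proj / gate_up_proj parameters, while the vision tower's fused qkv
    # weight and bias are re-ordered from "all q heads, then all k, then all v"
    # to a per-head (q, k, v) grouping so the tensor splits in
    # Qwen2VisionAttention.forward line up with the loaded layout.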

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(
                        3, visual_num_heads, head_size, visual_embed_dim
                    )
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
                elif "visual" in name and "qkv.bias" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(
                        3, visual_num_heads, head_size
                    )
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                except KeyError:
                    raise ValueError(f"Unexpected weight: {name}") from None

                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)