# qwen.py
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
  4. # Copyright (c) Alibaba Cloud.
  5. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
  6. """Inference-only QWen model compatible with HuggingFace weights."""
  7. import math
  8. import re
  9. from array import array
  10. from functools import partial
  11. from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
  12. Optional, Tuple, TypedDict, Union)
  13. import numpy as np
  14. import torch
  15. from loguru import logger
  16. from PIL import Image
  17. from torch import nn
  18. from torchvision import transforms
  19. from torchvision.transforms import InterpolationMode
  20. from transformers import PretrainedConfig
  21. from aphrodite.attention import Attention, AttentionMetadata
  22. from aphrodite.common.config import CacheConfig, MultiModalConfig
  23. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  24. IntermediateTensors, SequenceData)
  25. from aphrodite.common.utils import is_list_of
  26. from aphrodite.distributed import (get_pp_group,
  27. get_tensor_model_parallel_world_size)
  28. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  29. from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
  30. from aphrodite.modeling.layers.layernorm import RMSNorm
  31. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  32. MergedColumnParallelLinear,
  33. QKVParallelLinear,
  34. RowParallelLinear)
  35. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  36. from aphrodite.modeling.layers.resampler import Resampler2, get_abs_pos
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  42. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  45. from aphrodite.multimodal.base import MultiModalInputs
  46. from aphrodite.multimodal.utils import cached_get_tokenizer
  47. from aphrodite.quantization.base_config import QuantizationConfig
  48. from .utils import flatten_bn, is_pp_missing_parameter, make_layers
  49. # NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
  50. # for the time being, these tags are not considered as special at encoding
  51. # time. This may change as APHRODITEs multimodal API changes in the future.
  52. IMG_START = "<img>"
  53. IMG_END = "</img>"
  54. IMG_PAD = "<imgpad>"
  55. # Image context is fixed at 256 for all images
  56. MAX_QWEN_IMG_TOKENS = 256
  57. # Image normalization params
  58. CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
  59. CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
  60. class QwenImagePixelInputs(TypedDict):
  61. type: Literal["pixel_values"]
  62. data: torch.Tensor
  63. """
  64. Shape: `(batch_size * num_images, 3, image_size, image_size)`
  65. Note that image_size is the value in the vision config to which we resize
  66. the image to in the normalization transform. Currently multi-image support
  67. can only be leveraged by passing image embeddings directly.
  68. """
  69. class QwenImageEmbeddingInputs(TypedDict):
  70. type: Literal["image_embeds"]
  71. data: torch.Tensor
  72. """Shape: `(batch_size * num_images, 256, hidden_size)`
  73. `hidden_size` must match the hidden size of the language model backbone
  74. and is stored in the visual config of the model if we have one.
  75. """
  76. QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
  77. class VisualAttention(nn.Module):
  78. """self-attention layer class.
  79. Self-attention layer takes input with size [s, b, h]
  80. and returns output of the same size.
  81. """
  82. def __init__(
  83. self,
  84. embed_dim: int,
  85. num_heads: int,
  86. bias: bool = True,
  87. kdim: Optional[int] = None,
  88. vdim: Optional[int] = None,
  89. ):
  90. super().__init__()
  91. self.embed_dim = embed_dim
  92. self.kdim = kdim if kdim is not None else embed_dim
  93. self.vdim = vdim if vdim is not None else embed_dim
  94. self._qkv_same_embed_dim = self.kdim == embed_dim \
  95. and self.vdim == embed_dim
  96. self.num_heads = num_heads
  97. # Per attention head and per partition values.
  98. assert embed_dim % num_heads == 0
  99. self.hidden_size_per_attention_head = embed_dim // num_heads
  100. self.num_attention_heads_per_partition = num_heads
  101. self.hidden_size_per_partition = embed_dim
  102. # Strided linear layer.
  103. assert self._qkv_same_embed_dim, \
  104. 'Visual Attention implementation only supports self-attention'
  105. self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
  106. self.out_proj = nn.Linear(embed_dim, embed_dim)
  107. self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
  108. def forward(
  109. self,
  110. x: torch.Tensor,
  111. attn_mask: Optional[torch.Tensor] = None,
  112. ) -> torch.Tensor:
  113. # query/key/value: [sq, b, h]
  114. sq, b, _ = x.size()
  115. mixed_x_layer = self.in_proj(x)
  116. # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
  117. new_tensor_shape = mixed_x_layer.size()[:-1] + \
  118. (self.num_attention_heads_per_partition,
  119. 3 * self.hidden_size_per_attention_head)
  120. mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
  121. # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
  122. query_layer, key_layer, value_layer = mixed_x_layer.split(
  123. self.hidden_size_per_attention_head, dim=-1)
  124. # [sq, b, np, hn] -> [sq, b * np, hn]
  125. query_layer = query_layer.view(
  126. sq, b * self.num_attention_heads_per_partition,
  127. self.hidden_size_per_attention_head).transpose(0, 1)
  128. # [sk, b, np, hn] -> [sk, b * np, hn]
  129. key_layer = key_layer.view(
  130. sq, b * self.num_attention_heads_per_partition,
  131. self.hidden_size_per_attention_head).transpose(0, 1)
  132. q_scaled = query_layer / self.norm_factor
  133. if attn_mask is not None:
  134. attention_probs = torch.baddbmm(attn_mask, q_scaled,
  135. key_layer.transpose(-2, -1))
  136. else:
  137. attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
  138. attention_probs = attention_probs.softmax(dim=-1)
  139. value_layer = value_layer.view(
  140. sq, b * self.num_attention_heads_per_partition,
  141. self.hidden_size_per_attention_head).transpose(0, 1)
  142. # matmul: [b * np, sq, hn]
  143. context_layer = torch.bmm(attention_probs, value_layer)
  144. # change view [b, np, sq, hn]
  145. context_layer = context_layer.view(
  146. b, self.num_attention_heads_per_partition, sq,
  147. self.hidden_size_per_attention_head)
  148. # [b, np, sq, hn] --> [sq, b, np, hn]
  149. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  150. # [sq, b, np, hn] --> [sq, b, hp]
  151. new_context_layer_shape = context_layer.size()[:-2] + \
  152. (self.hidden_size_per_partition,)
  153. context_layer = context_layer.view(*new_context_layer_shape)
  154. output = self.out_proj(context_layer)
  155. return output
  156. class QwenVMLP(nn.Module):
  157. """MLP for the visual component of the Qwen model."""
  158. def __init__(
  159. self,
  160. hidden_size: int,
  161. intermediate_size: int,
  162. quant_config: Optional[QuantizationConfig] = None,
  163. ):
  164. super().__init__()
  165. self.c_fc = ColumnParallelLinear(hidden_size,
  166. intermediate_size,
  167. bias=True,
  168. quant_config=quant_config)
  169. self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
  170. self.c_proj = RowParallelLinear(
  171. intermediate_size,
  172. hidden_size,
  173. bias=True,
  174. quant_config=quant_config,
  175. )
  176. def forward(self, x):
  177. x, _ = self.c_fc(x)
  178. x = self.act_fn(x)
  179. x, _ = self.c_proj(x)
  180. return x
  181. class VisualAttentionBlock(nn.Module):
  182. def __init__(
  183. self,
  184. d_model: int,
  185. n_head: int,
  186. mlp_ratio: float = 4.0,
  187. norm_layer: Callable = nn.LayerNorm,
  188. quant_config: Optional[QuantizationConfig] = None,
  189. ):
  190. super().__init__()
  191. self.ln_1 = norm_layer(d_model)
  192. self.ln_2 = norm_layer(d_model)
  193. mlp_width = int(d_model * mlp_ratio)
  194. self.attn = VisualAttention(d_model, n_head)
  195. self.mlp = QwenVMLP(
  196. hidden_size=d_model,
  197. intermediate_size=mlp_width,
  198. quant_config=quant_config,
  199. )
  200. def attention(
  201. self,
  202. x: torch.Tensor,
  203. attn_mask: Optional[torch.Tensor] = None,
  204. ) -> torch.Tensor:
  205. attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
  206. return self.attn(x, attn_mask=attn_mask)
  207. def forward(
  208. self,
  209. x: torch.Tensor,
  210. attn_mask: Optional[torch.Tensor] = None,
  211. ) -> torch.Tensor:
  212. x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
  213. x = x + self.mlp(self.ln_2(x))
  214. return x
  215. class TransformerBlock(nn.Module):
  216. def __init__(
  217. self,
  218. width: int,
  219. layers: int,
  220. heads: int,
  221. mlp_ratio: float = 4.0,
  222. norm_layer: Callable = nn.LayerNorm,
  223. quant_config: Optional[QuantizationConfig] = None,
  224. ):
  225. super().__init__()
  226. self.width = width
  227. self.layers = layers
  228. self.resblocks = nn.ModuleList([
  229. VisualAttentionBlock(width,
  230. heads,
  231. mlp_ratio,
  232. norm_layer=norm_layer,
  233. quant_config=quant_config)
  234. for _ in range(layers)
  235. ])
  236. def get_cast_dtype(self) -> torch.dtype:
  237. return self.resblocks[0].mlp.c_fc.weight.dtype
  238. def get_cast_device(self) -> torch.device:
  239. return self.resblocks[0].mlp.c_fc.weight.device
  240. def forward(self,
  241. x: torch.Tensor,
  242. attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  243. for r in self.resblocks:
  244. x = r(x, attn_mask=attn_mask)
  245. return x
  246. class VisionTransformer(nn.Module):
  247. def __init__(self,
  248. image_size: int,
  249. patch_size: int,
  250. width: int,
  251. layers: int,
  252. heads: int,
  253. mlp_ratio: float,
  254. n_queries: int = 256,
  255. output_dim: int = 512,
  256. image_start_id: int = 151857,
  257. quant_config: Optional[QuantizationConfig] = None,
  258. **kwargs):
  259. super().__init__()
  260. image_height, image_width = self.image_size = (image_size, image_size)
  261. patch_height, patch_width = self.patch_size = (patch_size, patch_size)
  262. self.grid_size = (image_height // patch_height,
  263. image_width // patch_width)
  264. self.output_dim = output_dim
  265. self.conv1 = nn.Conv2d(in_channels=3,
  266. out_channels=width,
  267. kernel_size=patch_size,
  268. stride=patch_size,
  269. bias=False)
  270. # class embeddings and positional embeddings
  271. scale = width**-0.5
  272. self.positional_embedding = nn.Parameter(scale *
  273. torch.randn(256, width))
  274. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  275. self.ln_pre = norm_layer(width)
  276. self.transformer = TransformerBlock(width,
  277. layers,
  278. heads,
  279. mlp_ratio,
  280. norm_layer=norm_layer,
  281. quant_config=quant_config)
  282. self.attn_pool = Resampler2(
  283. grid_size=int(math.sqrt(n_queries)),
  284. embed_dim=output_dim,
  285. num_heads=output_dim // 128,
  286. kv_dim=width,
  287. norm_layer=norm_layer,
  288. adaptive=False,
  289. do_post_projection=False,
  290. ).to(
  291. device=self.positional_embedding.device,
  292. dtype=self.positional_embedding.dtype,
  293. )
  294. self.ln_post = norm_layer(output_dim)
  295. self.proj = nn.Parameter(
  296. (output_dim**-0.5) * torch.randn(output_dim, output_dim))
  297. self.image_start_id = image_start_id
  298. self.image_end_id = image_start_id + 1
  299. def forward(self, x: torch.Tensor) -> torch.Tensor:
  300. x = x.to(
  301. dtype=self.transformer.get_cast_dtype(),
  302. device=self.transformer.get_cast_device(),
  303. )
  304. # to patches
  305. x = self.conv1(x) # shape = [*, width, grid, grid]
  306. x = x.reshape(x.shape[0], x.shape[1],
  307. -1) # shape = [*, width, grid ** 2]
  308. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  309. x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
  310. x.size(1))))
  311. x = self.ln_pre(x)
  312. x = x.permute(1, 0, 2) # NLD -> LND
  313. x = self.transformer(x)
  314. x = x.permute(1, 0, 2) # LND -> NLD
  315. x = self.attn_pool(x)
  316. x = self.ln_post(x)
  317. x = x @ self.proj
  318. return x
  319. def get_image_positions(self,
  320. input_ids: torch.Tensor) -> Optional[torch.Tensor]:
  321. """Given the input IDs, extracts start/stop points corresponding to
  322. images.
  323. args:
  324. Returns:
  325. Optional torch tensor corresponding to start/stop pairs of images.
  326. """
  327. if torch.any(input_ids == self.image_start_id):
  328. bos_pos = torch.where(input_ids == self.image_start_id)
  329. eos_pos = torch.where(input_ids == self.image_end_id)
  330. return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
  331. return None
  332. class QWenMLP(nn.Module):
  333. """MLP for the language component of the Qwen model, which contains a
  334. MergedColumnParallelLinear merging 2 outputs via silu activation."""
  335. def __init__(
  336. self,
  337. hidden_size: int,
  338. intermediate_size: int,
  339. hidden_act: str = "silu",
  340. quant_config: Optional[QuantizationConfig] = None,
  341. ):
  342. super().__init__()
  343. self.gate_up_proj = MergedColumnParallelLinear(
  344. hidden_size, [intermediate_size] * 2,
  345. bias=False,
  346. quant_config=quant_config)
  347. self.c_proj = RowParallelLinear(intermediate_size,
  348. hidden_size,
  349. bias=False,
  350. quant_config=quant_config)
  351. if hidden_act != "silu":
  352. raise ValueError(f"Unsupported activation: {hidden_act}. "
  353. "Only silu is supported for now.")
  354. self.act_fn = SiluAndMul()
  355. def forward(self, x: torch.Tensor) -> torch.Tensor:
  356. gate_up, _ = self.gate_up_proj(x)
  357. x = self.act_fn(gate_up)
  358. x, _ = self.c_proj(x)
  359. return x
  360. class QWenAttention(nn.Module):
  361. def __init__(
  362. self,
  363. hidden_size: int,
  364. num_heads: int,
  365. max_position_embeddings: int,
  366. rope_theta: float = 10000,
  367. rope_scaling: Optional[Dict[str, Any]] = None,
  368. cache_config: Optional[CacheConfig] = None,
  369. quant_config: Optional[QuantizationConfig] = None,
  370. ):
  371. super().__init__()
  372. self.hidden_size = hidden_size
  373. tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
  374. )
  375. self.total_num_heads = num_heads
  376. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  377. self.num_heads = (self.total_num_heads //
  378. tensor_model_parallel_world_size)
  379. self.head_dim = hidden_size // self.total_num_heads
  380. self.c_attn = QKVParallelLinear(
  381. hidden_size,
  382. self.head_dim,
  383. self.total_num_heads,
  384. bias=True,
  385. quant_config=quant_config,
  386. )
  387. self.c_proj = RowParallelLinear(
  388. self.total_num_heads * self.head_dim,
  389. hidden_size,
  390. bias=False,
  391. quant_config=quant_config,
  392. )
  393. self.scaling = self.head_dim**-0.5
  394. self.rotary_emb = get_rope(
  395. self.head_dim,
  396. rotary_dim=self.head_dim,
  397. max_position=max_position_embeddings,
  398. base=rope_theta,
  399. rope_scaling=rope_scaling,
  400. )
  401. self.attn = Attention(self.num_heads,
  402. self.head_dim,
  403. self.scaling,
  404. cache_config=cache_config,
  405. quant_config=quant_config)
  406. def forward(
  407. self,
  408. positions: torch.Tensor,
  409. hidden_states: torch.Tensor,
  410. kv_cache: torch.Tensor,
  411. attn_metadata: AttentionMetadata,
  412. ) -> torch.Tensor:
  413. qkv, _ = self.c_attn(hidden_states)
  414. q, k, v = qkv.chunk(chunks=3, dim=-1)
  415. q, k = self.rotary_emb(positions, q, k)
  416. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  417. output, _ = self.c_proj(attn_output)
  418. return output
  419. class QWenBlock(nn.Module):
  420. def __init__(
  421. self,
  422. config: PretrainedConfig,
  423. cache_config: Optional[CacheConfig] = None,
  424. quant_config: Optional[QuantizationConfig] = None,
  425. ):
  426. super().__init__()
  427. self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  428. rope_theta = getattr(config, "rope_theta", 10000)
  429. rope_scaling = getattr(config, "rope_scaling", None)
  430. self.attn = QWenAttention(config.hidden_size,
  431. config.num_attention_heads,
  432. config.max_position_embeddings,
  433. rope_theta=rope_theta,
  434. rope_scaling=rope_scaling,
  435. cache_config=cache_config,
  436. quant_config=quant_config)
  437. self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  438. self.mlp = QWenMLP(config.hidden_size,
  439. config.intermediate_size // 2,
  440. quant_config=quant_config)
  441. def forward(
  442. self,
  443. positions: torch.Tensor,
  444. hidden_states: torch.Tensor,
  445. kv_cache: torch.Tensor,
  446. attn_metadata: AttentionMetadata,
  447. residual: Optional[torch.Tensor],
  448. ) -> Tuple[torch.Tensor, torch.Tensor]:
  449. # Self Attention
  450. if residual is None:
  451. residual = hidden_states
  452. hidden_states = self.ln_1(hidden_states)
  453. else:
  454. hidden_states, residual = self.ln_1(hidden_states, residual)
  455. hidden_states = self.attn(
  456. positions=positions,
  457. hidden_states=hidden_states,
  458. kv_cache=kv_cache,
  459. attn_metadata=attn_metadata,
  460. )
  461. # Fully Connected
  462. hidden_states, residual = self.ln_2(hidden_states, residual)
  463. hidden_states = self.mlp(hidden_states)
  464. return hidden_states, residual
  465. class QWenModel(nn.Module):
  466. def __init__(
  467. self,
  468. config: PretrainedConfig,
  469. cache_config: Optional[CacheConfig] = None,
  470. quant_config: Optional[QuantizationConfig] = None,
  471. prefix: str = "",
  472. ):
  473. super().__init__()
  474. self.config = config
  475. self.vocab_size = config.vocab_size
  476. self.wte = VocabParallelEmbedding(
  477. config.vocab_size,
  478. config.hidden_size,
  479. )
  480. self.start_layer, self.end_layer, self.h = make_layers(
  481. config.num_hidden_layers,
  482. lambda prefix: QWenBlock(config, cache_config, quant_config),
  483. prefix=f"{prefix}.h")
  484. self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  485. self.visual = VisionTransformer(**config.visual,
  486. quant_config=quant_config) if hasattr(
  487. config, "visual") else None
  488. def forward(
  489. self,
  490. input_ids: torch.Tensor,
  491. positions: torch.Tensor,
  492. kv_caches: List[torch.Tensor],
  493. attn_metadata: AttentionMetadata,
  494. intermediate_tensors: Optional[IntermediateTensors],
  495. pixel_values: Optional[QwenImageInputs],
  496. ) -> torch.Tensor:
  497. img_pos = None
  498. # If pixel / visual embeddings are provided, this is a visual model
  499. if pixel_values is not None and self.visual is not None:
  500. if pixel_values["type"] != "image_embeds":
  501. image_embeds = self.visual(pixel_values["data"])
  502. else:
  503. image_embeds = pixel_values["data"]
  504. # features should be of shape (# images, 256, hidden_dim)
  505. img_pos = self.visual.get_image_positions(input_ids)
  506. if isinstance(
  507. img_pos,
  508. np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
  509. raise ValueError(
  510. f"Number of placeholders: {img_pos.shape[0]} "
  511. f"does not match number of images {image_embeds.shape[0]}."
  512. )
  513. if get_pp_group().is_first_rank:
  514. hidden_states = self.wte(input_ids)
  515. # Merge the image embeddings into the hidden states if actually have
  516. # visual features and the corresponding image tokens
  517. if img_pos is not None:
  518. for idx, (img_bos, img_eos) in enumerate(img_pos):
  519. hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
  520. residual = None
  521. else:
  522. assert intermediate_tensors is not None
  523. hidden_states = intermediate_tensors["hidden_states"]
  524. residual = intermediate_tensors["residual"]
  525. for i in range(self.start_layer, self.end_layer):
  526. layer = self.h[i]
  527. hidden_states, residual = layer(
  528. positions,
  529. hidden_states,
  530. kv_caches[i - self.start_layer],
  531. attn_metadata,
  532. residual,
  533. )
  534. if not get_pp_group().is_last_rank:
  535. return IntermediateTensors({
  536. "hidden_states": hidden_states,
  537. "residual": residual
  538. })
  539. hidden_states, _ = self.ln_f(hidden_states, residual)
  540. return hidden_states
  541. def get_image_text(image_num: int, padding: bool) -> str:
  542. """Retrieves a placeholder text that when tokenized, will be expanded with
  543. image pads.
  544. Args:
  545. image_num: The number of the image that we want a text prompt for.
  546. Images should be indexed starting at 1.
  547. padding: Whether or not padding should be manually added.
  548. Returns:
  549. Text placeholder prompt for the image being considered.
  550. """
  551. image_start = f"Picture {image_num}: {IMG_START}"
  552. image_end = f"{IMG_END}\n"
  553. if not padding:
  554. return f"{image_start}{image_end}"
  555. return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
  556. def input_processor_for_qwen(ctx: InputContext,
  557. llm_inputs: LLMInputs) -> LLMInputs:
  558. """Processes the inputs, which may or may not be multimodal.
  559. Multimodal inputs will only be processed if the model has a "visual"
  560. component in its model config, otherwise they'll be ignored.
  561. Args:
  562. ctx: Context of the loaded model.
  563. llm_inputs: LLM inputs which may have a multi_modal_data attribute.
  564. Returns:
  565. If the model is language only or not multimodal inputs were provided,
  566. returns llm_inputs unmodified. Otherwise, processes the multimodal
  567. images / image embeddings and adds the fixed-length image placeholders.
  568. """
  569. multi_modal_data = llm_inputs.get("multi_modal_data")
  570. # Only process images if we have multimodal data and a visual config
  571. hf_config = ctx.get_hf_config()
  572. if (multi_modal_data is None or "image" not in multi_modal_data
  573. or not hasattr(hf_config, "visual")):
  574. return llm_inputs
  575. prompt = llm_inputs.get("prompt")
  576. prompt_token_ids = llm_inputs["prompt_token_ids"]
  577. model_config = ctx.model_config
  578. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  579. trust_remote_code=True)
  580. image_data = multi_modal_data["image"]
  581. if isinstance(image_data, torch.Tensor):
  582. num_dims = len(image_data.shape)
  583. if num_dims < 2 or num_dims > 3:
  584. raise ValueError(
  585. f"Expected img embeds to be have 3 dimensions, got {num_dims}")
  586. num_images = 1 if num_dims == 2 else image_data.shape[0]
  587. elif isinstance(image_data, Image.Image):
  588. num_images = 1
  589. elif is_list_of(image_data, Image.Image):
  590. num_images = len(image_data)
  591. else:
  592. raise TypeError(f"Invalid image type: {type(image_data)}")
  593. if prompt is None:
  594. prompt = tokenizer.decode(prompt_token_ids)
  595. # Drops anything between <img>/</img> tags; encoding with the tokenizer
  596. # will automatically add the image pads for the context.
  597. new_prompt, num_matched_images = re.subn(
  598. r"(Picture \d*: <img>).*?(<\/img>\n)",
  599. r"\1\2",
  600. prompt,
  601. )
  602. if num_matched_images != num_images:
  603. logger.warning(
  604. f"Number of matched image placeholders {num_matched_images} "
  605. f"doesn't match the number of expected images {num_images}; "
  606. "check your placeholder formatting."
  607. )
  608. new_prompt_token_ids = tokenizer.encode(new_prompt)
  609. return LLMInputs(prompt=new_prompt,
  610. prompt_token_ids=new_prompt_token_ids,
  611. multi_modal_data=multi_modal_data)
  612. def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
  613. """Maps the input data to its MultiModalInputs (if any).
  614. Args:
  615. ctx: Context of the loaded model.
  616. data: data potentially containing image/image embeddings to be mapped
  617. to pixel_values in .forward() for a visual QWenLMHeadModel model.
  618. Returns:
  619. MultiModalInputs containing the stacked normalized images tensor or
  620. image embeddings.
  621. """
  622. # Early exit if we have provided an image to a language only Qwen model
  623. hf_config = ctx.get_hf_config()
  624. if not hasattr(hf_config, "visual"):
  625. logger.warning(
  626. "Images were provided but this model has no visual config; "
  627. "multimodal inputs will not be forwarded to the model.")
  628. return MultiModalInputs()
  629. model_config = ctx.model_config
  630. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  631. trust_remote_code=True)
  632. image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
  633. add_special_tokens=False,
  634. return_tensors="pt").squeeze()
  635. image_start_id = image_pair_tok[0]
  636. image_end_id = image_pair_tok[-1]
  637. if (image_start_id + 1) != image_end_id:
  638. raise ValueError(
  639. f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
  640. if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
  641. raise ValueError(
  642. f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
  643. f"but got {image_pair_tok - 2}")
  644. hf_config = ctx.get_hf_config()
  645. image_size = hf_config.visual["image_size"]
  646. img_emb_size = hf_config.visual["output_dim"]
  647. if isinstance(data, torch.Tensor):
  648. # It's expected that our values have already been processed
  649. # by the visual transformer; shape is expected to be:
  650. # (# images, 256, hidden_size)
  651. if len(data.shape) == 2:
  652. # Assume only one image embed was provided; unsqueeze the extra dim
  653. data = data.unsqueeze(0)
  654. if len(data.shape) != 3 or data.shape[
  655. 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
  656. raise ValueError(
  657. "Expected image embeds to be a tensor of shape"
  658. f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
  659. f"received shape [{data.shape}]")
  660. pixel_values = data
  661. else:
  662. transform = build_normalization_transform(image_size)
  663. if not isinstance(data, (list, tuple)):
  664. data = [data]
  665. transformed_images = [transform(datum) for datum in data]
  666. pixel_values = torch.stack(transformed_images, dim=0)
  667. return MultiModalInputs({"pixel_values": pixel_values})
  668. def build_normalization_transform(image_size: int) -> transforms.Compose:
  669. """Builds a normalization transform which can be applied to one or
  670. more input images from which we want to extract visual features.
  671. Args:
  672. image_size: size of the image to be processed for visual embeddings.
  673. Returns:
  674. Callable transform for normalizing and resizing one RGB image.
  675. """
  676. return transforms.Compose([
  677. transforms.Resize((image_size, image_size),
  678. interpolation=InterpolationMode.BICUBIC),
  679. transforms.ToTensor(),
  680. transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
  681. ])
  def dummy_data_for_qwen(
      ctx: InputContext,
      seq_len: int,
      mm_counts: Mapping[str, int],
  ) -> Tuple[SequenceData, Optional[Dict]]:
      """Create the dummy inputs used to warm up a Qwen model.

      A model whose HF config has no ``visual`` attribute is treated as a
      text-only LLM: it receives ``seq_len`` zero token ids and no multimodal
      data. Otherwise the prompt contains one image section per requested
      image and is right-padded with zeros until it holds at least
      ``seq_len`` tokens.

      Args:
          ctx: Context of the loaded model.
          seq_len: Minimum number of tokens in the returned sequence.
          mm_counts: Multimodal data counts; the ``"image"`` entry is read
              when the model has a visual config.

      Returns:
          Tuple of the sequence data and the multimodal data (``None`` for
          text-only models).

      Raises:
          ValueError: If the tokenized prompt does not contain exactly
              ``MAX_QWEN_IMG_TOKENS`` image pad tokens per image.
      """
      hf_config = ctx.get_hf_config()

      # No visual config -> plain LLM warmup with zero-filled token ids.
      if not hasattr(hf_config, "visual"):
          return SequenceData(
              array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)), None

      # Multimodal model: warm up with images.
      num_images = mm_counts["image"]
      tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer,
                                       trust_remote_code=True)

      # The prompt is built without img pads; the tokenizer is expected to
      # insert them while encoding.
      image_prompt = ''.join(
          get_image_text(image_idx, False)
          for image_idx in range(1, num_images + 1))
      toks = tokenizer.encode(image_prompt, add_special_tokens=False)

      # Verify that every image expanded to the fixed number of pad tokens.
      pad_token_id = tokenizer.encode(IMG_PAD)[0]
      num_pads = toks.count(pad_token_id)
      expected_pads = num_images * MAX_QWEN_IMG_TOKENS
      if num_pads != expected_pads:
          raise ValueError(
              f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
              f" per image, but got {num_pads} pads for {num_images} image(s)"
              " in total. Are you using a qwen tokenizer?")

      # Pad with zeros so the sequence is never shorter than seq_len.
      shortfall = seq_len - len(toks)
      if shortfall > 0:
          toks += [0] * shortfall

      # The image dimensions are irrelevant: the data gets resized and the
      # number of tokens per image is constant.
      image = Image.new("RGB", (224, 224), color=0)
      if num_images == 1:
          image_payload = image
      else:
          image_payload = [image] * num_images
      mm_data = {"image": image_payload}

      return SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
  @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
  @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
  class QWenLMHeadModel(nn.Module, SupportsMultiModal):
      """Qwen transformer with a language-modeling head.

      The decorators register the image input mapper, the maximum number of
      image tokens, the dummy-data builder and the input processor for this
      model class.
      """

      def __init__(
          self,
          config: PretrainedConfig,
          multimodal_config: MultiModalConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          """Build the transformer, LM head, logits processor and sampler.

          Args:
              config: Model configuration; ``vocab_size``, ``hidden_size`` and
                  ``tie_word_embeddings`` are read here.
              multimodal_config: Multimodal configuration, stored on the
                  instance.
              cache_config: Optional cache configuration forwarded to the
                  transformer.
              quant_config: Optional quantization configuration forwarded to
                  the transformer and the LM head.
          """
          super().__init__()
          self.config = config
          self.multimodal_config = multimodal_config
          self.quant_config = quant_config

          self.transformer = QWenModel(config, cache_config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size,
              config.hidden_size,
              quant_config=quant_config,
          )
          # With tied embeddings the LM head shares the token embedding weight.
          if self.config.tie_word_embeddings:
              self.lm_head.weight = self.transformer.wte.weight

          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def _get_image_input_type(
              self,
              pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
          """Classify ``pixel_values`` as image embeddings or raw pixel values.

          Args:
              pixel_values: Optional data to be processed into visual
                  embeddings.

          Returns:
              None when no data is given or the transformer has no visual
              component; otherwise the QwenImageInputs variant that tells
              whether the visual transformer still needs to process the data.
          """
          if pixel_values is None or self.transformer.visual is None:
              return None

          pixel_values = flatten_bn(pixel_values)
          shape = pixel_values.shape
          looks_like_embeddings = (
              len(shape) == 3 and shape[1] == MAX_QWEN_IMG_TOKENS
              and shape[2] == self.config.visual["output_dim"])
          if looks_like_embeddings:
              return QwenImageEmbeddingInputs(
                  type="image_embeds",
                  data=pixel_values,
              )
          # Any other shape is assumed to still require visual processing.
          return QwenImagePixelInputs(
              type="pixel_values",
              data=pixel_values,
          )

      def forward(self,
                  input_ids: torch.Tensor,
                  positions: torch.Tensor,
                  kv_caches: List[torch.Tensor],
                  attn_metadata: AttentionMetadata,
                  intermediate_tensors: Optional[IntermediateTensors] = None,
                  pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
          """Run the transformer and return its output.

          ``pixel_values`` is first wrapped into a typed image input (or
          ``None``) and then passed to the transformer as its last argument.
          """
          image_input = self._get_image_input_type(pixel_values)
          return self.transformer(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  image_input)

      def make_empty_intermediate_tensors(
              self, batch_size: int, dtype: torch.dtype,
              device: torch.device) -> IntermediateTensors:
          """Return zero-filled ``hidden_states`` and ``residual`` tensors.

          Each tensor has shape ``(batch_size, hidden_size)`` with the given
          dtype and device.
          """
          shape = (batch_size, self.config.hidden_size)
          return IntermediateTensors({
              key: torch.zeros(shape, dtype=dtype, device=device)
              for key in ("hidden_states", "residual")
          })

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Apply the logits processor with the LM head to ``hidden_states``."""
          return self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Run the sampler on ``logits`` and return its output."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Copy checkpoint tensors into this module's parameters.

          Weights whose name contains ``w2`` / ``w1`` are loaded as shards 0 / 1
          of the merged ``gate_up_proj`` parameter; every other weight is
          loaded by name with the parameter's own ``weight_loader`` (or
          ``default_weight_loader`` when it has none).

          Args:
              weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
          """
          # (param_name, shard_name, shard_id)
          stacked_params_mapping = (
              ("gate_up_proj", "w2", 0),
              ("gate_up_proj", "w1", 1),
          )
          params = dict(self.named_parameters())

          def should_skip(candidate: str) -> bool:
              # Skip extra biases (GPTQ models) and layers on other devices.
              return ((candidate.endswith(".bias") and candidate not in params)
                      or is_pp_missing_parameter(candidate, self))

          for name, tensor in weights:
              if "rotary_emb.inv_freq" in name:
                  continue

              for merged_name, shard_name, shard_id in stacked_params_mapping:
                  if shard_name not in name:
                      continue
                  # NOTE: `name` is rewritten in place. If the weight is then
                  # skipped, the loop moves on with the rewritten name and the
                  # `else` branch below checks that rewritten name again.
                  name = name.replace(shard_name, merged_name)
                  if should_skip(name):
                      continue
                  target = params[name]
                  target.weight_loader(target, tensor, shard_id)
                  break
              else:
                  # No stacked shard was loaded; treat it as a regular weight.
                  if should_skip(name):
                      continue
                  target = params[name]
                  loader = getattr(target, "weight_loader",
                                   default_weight_loader)
                  loader(target, tensor)