qwen.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
  4. # Copyright (c) Alibaba Cloud.
  5. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
  6. """Inference-only QWen model compatible with HuggingFace weights."""
  7. import math
  8. import re
  9. from functools import partial
  10. from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
  11. Optional, Tuple, TypedDict, Union)
  12. import numpy as np
  13. import torch
  14. from loguru import logger
  15. from PIL import Image
  16. from torch import nn
  17. from torchvision import transforms
  18. from torchvision.transforms import InterpolationMode
  19. from transformers import PretrainedConfig
  20. from aphrodite.attention import Attention, AttentionMetadata
  21. from aphrodite.common.config import CacheConfig, MultiModalConfig
  22. from aphrodite.common.sequence import IntermediateTensors, SequenceData
  23. from aphrodite.common.utils import is_list_of
  24. from aphrodite.distributed import (get_pp_group,
  25. get_tensor_model_parallel_world_size)
  26. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  27. from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
  28. from aphrodite.modeling.layers.layernorm import RMSNorm
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. MergedColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.resampler import Resampler2, get_abs_pos
  35. from aphrodite.modeling.layers.rotary_embedding import get_rope
  36. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  37. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  38. ParallelLMHead, VocabParallelEmbedding)
  39. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  40. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  41. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  42. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  43. from aphrodite.multimodal.base import MultiModalInputs
  44. from aphrodite.multimodal.utils import cached_get_tokenizer
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. from .utils import flatten_bn, is_pp_missing_parameter, make_layers
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
# for the time being, these tags are not considered as special at encoding
# time. This may change as Aphrodite's multimodal API changes in the future.

# Text tags delimiting an image span in a prompt (`<img>` ... `</img>`) and
# the pad tag repeated inside that span.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_PAD = "<imgpad>"
# Image context is fixed at 256 for all images
MAX_QWEN_IMG_TOKENS = 256
# Image normalization params (per-channel mean / std handed to
# `transforms.Normalize` when building the image transform).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
  58. class QwenImagePixelInputs(TypedDict):
  59. type: Literal["pixel_values"]
  60. data: torch.Tensor
  61. """
  62. Shape: `(batch_size * num_images, 3, image_size, image_size)`
  63. Note that image_size is the value in the vision config to which we resize
  64. the image to in the normalization transform. Currently multi-image support
  65. can only be leveraged by passing image embeddings directly.
  66. """
  67. class QwenImageEmbeddingInputs(TypedDict):
  68. type: Literal["image_embeds"]
  69. data: torch.Tensor
  70. """Shape: `(batch_size * num_images, 256, hidden_size)`
  71. `hidden_size` must match the hidden size of the language model backbone
  72. and is stored in the visual config of the model if we have one.
  73. """
# Union of the two accepted image-input variants; the "type" key tells them
# apart ("pixel_values" vs. "image_embeds").
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
  75. class VisualAttention(nn.Module):
  76. """self-attention layer class.
  77. Self-attention layer takes input with size [s, b, h]
  78. and returns output of the same size.
  79. """
  80. def __init__(
  81. self,
  82. embed_dim: int,
  83. num_heads: int,
  84. bias: bool = True,
  85. kdim: Optional[int] = None,
  86. vdim: Optional[int] = None,
  87. ):
  88. super().__init__()
  89. self.embed_dim = embed_dim
  90. self.kdim = kdim if kdim is not None else embed_dim
  91. self.vdim = vdim if vdim is not None else embed_dim
  92. self._qkv_same_embed_dim = self.kdim == embed_dim \
  93. and self.vdim == embed_dim
  94. self.num_heads = num_heads
  95. # Per attention head and per partition values.
  96. assert embed_dim % num_heads == 0
  97. self.hidden_size_per_attention_head = embed_dim // num_heads
  98. self.num_attention_heads_per_partition = num_heads
  99. self.hidden_size_per_partition = embed_dim
  100. # Strided linear layer.
  101. assert self._qkv_same_embed_dim, \
  102. 'Visual Attention implementation only supports self-attention'
  103. self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
  104. self.out_proj = nn.Linear(embed_dim, embed_dim)
  105. self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
  106. def forward(
  107. self,
  108. x: torch.Tensor,
  109. attn_mask: Optional[torch.Tensor] = None,
  110. ) -> torch.Tensor:
  111. # query/key/value: [sq, b, h]
  112. sq, b, _ = x.size()
  113. mixed_x_layer = self.in_proj(x)
  114. # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
  115. new_tensor_shape = mixed_x_layer.size()[:-1] + \
  116. (self.num_attention_heads_per_partition,
  117. 3 * self.hidden_size_per_attention_head)
  118. mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
  119. # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
  120. query_layer, key_layer, value_layer = mixed_x_layer.split(
  121. self.hidden_size_per_attention_head, dim=-1)
  122. # [sq, b, np, hn] -> [sq, b * np, hn]
  123. query_layer = query_layer.view(
  124. sq, b * self.num_attention_heads_per_partition,
  125. self.hidden_size_per_attention_head).transpose(0, 1)
  126. # [sk, b, np, hn] -> [sk, b * np, hn]
  127. key_layer = key_layer.view(
  128. sq, b * self.num_attention_heads_per_partition,
  129. self.hidden_size_per_attention_head).transpose(0, 1)
  130. q_scaled = query_layer / self.norm_factor
  131. if attn_mask is not None:
  132. attention_probs = torch.baddbmm(attn_mask, q_scaled,
  133. key_layer.transpose(-2, -1))
  134. else:
  135. attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
  136. attention_probs = attention_probs.softmax(dim=-1)
  137. value_layer = value_layer.view(
  138. sq, b * self.num_attention_heads_per_partition,
  139. self.hidden_size_per_attention_head).transpose(0, 1)
  140. # matmul: [b * np, sq, hn]
  141. context_layer = torch.bmm(attention_probs, value_layer)
  142. # change view [b, np, sq, hn]
  143. context_layer = context_layer.view(
  144. b, self.num_attention_heads_per_partition, sq,
  145. self.hidden_size_per_attention_head)
  146. # [b, np, sq, hn] --> [sq, b, np, hn]
  147. context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
  148. # [sq, b, np, hn] --> [sq, b, hp]
  149. new_context_layer_shape = context_layer.size()[:-2] + \
  150. (self.hidden_size_per_partition,)
  151. context_layer = context_layer.view(*new_context_layer_shape)
  152. output = self.out_proj(context_layer)
  153. return output
  154. class QwenVMLP(nn.Module):
  155. """MLP for the visual component of the Qwen model."""
  156. def __init__(
  157. self,
  158. hidden_size: int,
  159. intermediate_size: int,
  160. quant_config: Optional[QuantizationConfig] = None,
  161. ):
  162. super().__init__()
  163. self.c_fc = ColumnParallelLinear(hidden_size,
  164. intermediate_size,
  165. bias=True,
  166. quant_config=quant_config)
  167. self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
  168. self.c_proj = RowParallelLinear(
  169. intermediate_size,
  170. hidden_size,
  171. bias=True,
  172. quant_config=quant_config,
  173. )
  174. def forward(self, x):
  175. x, _ = self.c_fc(x)
  176. x = self.act_fn(x)
  177. x, _ = self.c_proj(x)
  178. return x
  179. class VisualAttentionBlock(nn.Module):
  180. def __init__(
  181. self,
  182. d_model: int,
  183. n_head: int,
  184. mlp_ratio: float = 4.0,
  185. norm_layer: Callable = nn.LayerNorm,
  186. quant_config: Optional[QuantizationConfig] = None,
  187. ):
  188. super().__init__()
  189. self.ln_1 = norm_layer(d_model)
  190. self.ln_2 = norm_layer(d_model)
  191. mlp_width = int(d_model * mlp_ratio)
  192. self.attn = VisualAttention(d_model, n_head)
  193. self.mlp = QwenVMLP(
  194. hidden_size=d_model,
  195. intermediate_size=mlp_width,
  196. quant_config=quant_config,
  197. )
  198. def attention(
  199. self,
  200. x: torch.Tensor,
  201. attn_mask: Optional[torch.Tensor] = None,
  202. ) -> torch.Tensor:
  203. attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
  204. return self.attn(x, attn_mask=attn_mask)
  205. def forward(
  206. self,
  207. x: torch.Tensor,
  208. attn_mask: Optional[torch.Tensor] = None,
  209. ) -> torch.Tensor:
  210. x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
  211. x = x + self.mlp(self.ln_2(x))
  212. return x
  213. class TransformerBlock(nn.Module):
  214. def __init__(
  215. self,
  216. width: int,
  217. layers: int,
  218. heads: int,
  219. mlp_ratio: float = 4.0,
  220. norm_layer: Callable = nn.LayerNorm,
  221. quant_config: Optional[QuantizationConfig] = None,
  222. ):
  223. super().__init__()
  224. self.width = width
  225. self.layers = layers
  226. self.resblocks = nn.ModuleList([
  227. VisualAttentionBlock(width,
  228. heads,
  229. mlp_ratio,
  230. norm_layer=norm_layer,
  231. quant_config=quant_config)
  232. for _ in range(layers)
  233. ])
  234. def get_cast_dtype(self) -> torch.dtype:
  235. return self.resblocks[0].mlp.c_fc.weight.dtype
  236. def get_cast_device(self) -> torch.device:
  237. return self.resblocks[0].mlp.c_fc.weight.device
  238. def forward(self,
  239. x: torch.Tensor,
  240. attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  241. for r in self.resblocks:
  242. x = r(x, attn_mask=attn_mask)
  243. return x
  244. class VisionTransformer(nn.Module):
  245. def __init__(self,
  246. image_size: int,
  247. patch_size: int,
  248. width: int,
  249. layers: int,
  250. heads: int,
  251. mlp_ratio: float,
  252. n_queries: int = 256,
  253. output_dim: int = 512,
  254. image_start_id: int = 151857,
  255. quant_config: Optional[QuantizationConfig] = None,
  256. **kwargs):
  257. super().__init__()
  258. image_height, image_width = self.image_size = (image_size, image_size)
  259. patch_height, patch_width = self.patch_size = (patch_size, patch_size)
  260. self.grid_size = (image_height // patch_height,
  261. image_width // patch_width)
  262. self.output_dim = output_dim
  263. self.conv1 = nn.Conv2d(in_channels=3,
  264. out_channels=width,
  265. kernel_size=patch_size,
  266. stride=patch_size,
  267. bias=False)
  268. # class embeddings and positional embeddings
  269. scale = width**-0.5
  270. self.positional_embedding = nn.Parameter(scale *
  271. torch.randn(256, width))
  272. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  273. self.ln_pre = norm_layer(width)
  274. self.transformer = TransformerBlock(width,
  275. layers,
  276. heads,
  277. mlp_ratio,
  278. norm_layer=norm_layer,
  279. quant_config=quant_config)
  280. self.attn_pool = Resampler2(
  281. grid_size=int(math.sqrt(n_queries)),
  282. embed_dim=output_dim,
  283. num_heads=output_dim // 128,
  284. kv_dim=width,
  285. norm_layer=norm_layer,
  286. adaptive=False,
  287. do_post_projection=False,
  288. ).to(
  289. device=self.positional_embedding.device,
  290. dtype=self.positional_embedding.dtype,
  291. )
  292. self.ln_post = norm_layer(output_dim)
  293. self.proj = nn.Parameter(
  294. (output_dim**-0.5) * torch.randn(output_dim, output_dim))
  295. self.image_start_id = image_start_id
  296. self.image_end_id = image_start_id + 1
  297. def forward(self, x: torch.Tensor) -> torch.Tensor:
  298. x = x.to(
  299. dtype=self.transformer.get_cast_dtype(),
  300. device=self.transformer.get_cast_device(),
  301. )
  302. # to patches
  303. x = self.conv1(x) # shape = [*, width, grid, grid]
  304. x = x.reshape(x.shape[0], x.shape[1],
  305. -1) # shape = [*, width, grid ** 2]
  306. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  307. x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
  308. x.size(1))))
  309. x = self.ln_pre(x)
  310. x = x.permute(1, 0, 2) # NLD -> LND
  311. x = self.transformer(x)
  312. x = x.permute(1, 0, 2) # LND -> NLD
  313. x = self.attn_pool(x)
  314. x = self.ln_post(x)
  315. x = x @ self.proj
  316. return x
  317. def get_image_positions(self,
  318. input_ids: torch.Tensor) -> Optional[torch.Tensor]:
  319. """Given the input IDs, extracts start/stop points corresponding to
  320. images.
  321. args:
  322. Returns:
  323. Optional torch tensor corresponding to start/stop pairs of images.
  324. """
  325. if torch.any(input_ids == self.image_start_id):
  326. bos_pos = torch.where(input_ids == self.image_start_id)
  327. eos_pos = torch.where(input_ids == self.image_end_id)
  328. return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
  329. return None
  330. class QWenMLP(nn.Module):
  331. """MLP for the language component of the Qwen model, which contains a
  332. MergedColumnParallelLinear merging 2 outputs via silu activation."""
  333. def __init__(
  334. self,
  335. hidden_size: int,
  336. intermediate_size: int,
  337. hidden_act: str = "silu",
  338. quant_config: Optional[QuantizationConfig] = None,
  339. ):
  340. super().__init__()
  341. self.gate_up_proj = MergedColumnParallelLinear(
  342. hidden_size, [intermediate_size] * 2,
  343. bias=False,
  344. quant_config=quant_config)
  345. self.c_proj = RowParallelLinear(intermediate_size,
  346. hidden_size,
  347. bias=False,
  348. quant_config=quant_config)
  349. if hidden_act != "silu":
  350. raise ValueError(f"Unsupported activation: {hidden_act}. "
  351. "Only silu is supported for now.")
  352. self.act_fn = SiluAndMul()
  353. def forward(self, x: torch.Tensor) -> torch.Tensor:
  354. gate_up, _ = self.gate_up_proj(x)
  355. x = self.act_fn(gate_up)
  356. x, _ = self.c_proj(x)
  357. return x
  358. class QWenAttention(nn.Module):
  359. def __init__(
  360. self,
  361. hidden_size: int,
  362. num_heads: int,
  363. max_position_embeddings: int,
  364. rope_theta: float = 10000,
  365. rope_scaling: Optional[Dict[str, Any]] = None,
  366. cache_config: Optional[CacheConfig] = None,
  367. quant_config: Optional[QuantizationConfig] = None,
  368. ):
  369. super().__init__()
  370. self.hidden_size = hidden_size
  371. tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
  372. )
  373. self.total_num_heads = num_heads
  374. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  375. self.num_heads = (self.total_num_heads //
  376. tensor_model_parallel_world_size)
  377. self.head_dim = hidden_size // self.total_num_heads
  378. self.c_attn = QKVParallelLinear(
  379. hidden_size,
  380. self.head_dim,
  381. self.total_num_heads,
  382. bias=True,
  383. quant_config=quant_config,
  384. )
  385. self.c_proj = RowParallelLinear(
  386. self.total_num_heads * self.head_dim,
  387. hidden_size,
  388. bias=False,
  389. quant_config=quant_config,
  390. )
  391. self.scaling = self.head_dim**-0.5
  392. self.rotary_emb = get_rope(
  393. self.head_dim,
  394. rotary_dim=self.head_dim,
  395. max_position=max_position_embeddings,
  396. base=rope_theta,
  397. rope_scaling=rope_scaling,
  398. )
  399. self.attn = Attention(self.num_heads,
  400. self.head_dim,
  401. self.scaling,
  402. cache_config=cache_config,
  403. quant_config=quant_config)
  404. def forward(
  405. self,
  406. positions: torch.Tensor,
  407. hidden_states: torch.Tensor,
  408. kv_cache: torch.Tensor,
  409. attn_metadata: AttentionMetadata,
  410. ) -> torch.Tensor:
  411. qkv, _ = self.c_attn(hidden_states)
  412. q, k, v = qkv.chunk(chunks=3, dim=-1)
  413. q, k = self.rotary_emb(positions, q, k)
  414. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  415. output, _ = self.c_proj(attn_output)
  416. return output
  417. class QWenBlock(nn.Module):
  418. def __init__(
  419. self,
  420. config: PretrainedConfig,
  421. cache_config: Optional[CacheConfig] = None,
  422. quant_config: Optional[QuantizationConfig] = None,
  423. ):
  424. super().__init__()
  425. self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  426. rope_theta = getattr(config, "rope_theta", 10000)
  427. rope_scaling = getattr(config, "rope_scaling", None)
  428. self.attn = QWenAttention(config.hidden_size,
  429. config.num_attention_heads,
  430. config.max_position_embeddings,
  431. rope_theta=rope_theta,
  432. rope_scaling=rope_scaling,
  433. cache_config=cache_config,
  434. quant_config=quant_config)
  435. self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  436. self.mlp = QWenMLP(config.hidden_size,
  437. config.intermediate_size // 2,
  438. quant_config=quant_config)
  439. def forward(
  440. self,
  441. positions: torch.Tensor,
  442. hidden_states: torch.Tensor,
  443. kv_cache: torch.Tensor,
  444. attn_metadata: AttentionMetadata,
  445. residual: Optional[torch.Tensor],
  446. ) -> Tuple[torch.Tensor, torch.Tensor]:
  447. # Self Attention
  448. if residual is None:
  449. residual = hidden_states
  450. hidden_states = self.ln_1(hidden_states)
  451. else:
  452. hidden_states, residual = self.ln_1(hidden_states, residual)
  453. hidden_states = self.attn(
  454. positions=positions,
  455. hidden_states=hidden_states,
  456. kv_cache=kv_cache,
  457. attn_metadata=attn_metadata,
  458. )
  459. # Fully Connected
  460. hidden_states, residual = self.ln_2(hidden_states, residual)
  461. hidden_states = self.mlp(hidden_states)
  462. return hidden_states, residual
  463. class QWenModel(nn.Module):
  464. def __init__(
  465. self,
  466. config: PretrainedConfig,
  467. cache_config: Optional[CacheConfig] = None,
  468. quant_config: Optional[QuantizationConfig] = None,
  469. prefix: str = "",
  470. ):
  471. super().__init__()
  472. self.config = config
  473. self.vocab_size = config.vocab_size
  474. self.wte = VocabParallelEmbedding(
  475. config.vocab_size,
  476. config.hidden_size,
  477. )
  478. self.start_layer, self.end_layer, self.h = make_layers(
  479. config.num_hidden_layers,
  480. lambda prefix: QWenBlock(config, cache_config, quant_config),
  481. prefix=f"{prefix}.h")
  482. self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  483. self.visual = VisionTransformer(**config.visual,
  484. quant_config=quant_config) if hasattr(
  485. config, "visual") else None
  486. def forward(
  487. self,
  488. input_ids: torch.Tensor,
  489. positions: torch.Tensor,
  490. kv_caches: List[torch.Tensor],
  491. attn_metadata: AttentionMetadata,
  492. intermediate_tensors: Optional[IntermediateTensors],
  493. pixel_values: Optional[QwenImageInputs],
  494. ) -> torch.Tensor:
  495. img_pos = None
  496. # If pixel / visual embeddings are provided, this is a visual model
  497. if pixel_values is not None and self.visual is not None:
  498. if pixel_values["type"] != "image_embeds":
  499. image_embeds = self.visual(pixel_values["data"])
  500. else:
  501. image_embeds = pixel_values["data"]
  502. # features should be of shape (# images, 256, hidden_dim)
  503. img_pos = self.visual.get_image_positions(input_ids)
  504. if isinstance(
  505. img_pos,
  506. np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
  507. raise ValueError(
  508. f"Number of placeholders: {img_pos.shape[0]} "
  509. f"does not match number of images {image_embeds.shape[0]}."
  510. )
  511. if get_pp_group().is_first_rank:
  512. hidden_states = self.wte(input_ids)
  513. # Merge the image embeddings into the hidden states if actually have
  514. # visual features and the corresponding image tokens
  515. if img_pos is not None:
  516. for idx, (img_bos, img_eos) in enumerate(img_pos):
  517. hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
  518. residual = None
  519. else:
  520. assert intermediate_tensors is not None
  521. hidden_states = intermediate_tensors["hidden_states"]
  522. residual = intermediate_tensors["residual"]
  523. for i in range(self.start_layer, self.end_layer):
  524. layer = self.h[i]
  525. hidden_states, residual = layer(
  526. positions,
  527. hidden_states,
  528. kv_caches[i - self.start_layer],
  529. attn_metadata,
  530. residual,
  531. )
  532. if not get_pp_group().is_last_rank:
  533. return IntermediateTensors({
  534. "hidden_states": hidden_states,
  535. "residual": residual
  536. })
  537. hidden_states, _ = self.ln_f(hidden_states, residual)
  538. return hidden_states
  539. def get_image_text(image_num: int, padding: bool) -> str:
  540. """Retrieves a placeholder text that when tokenized, will be expanded with
  541. image pads.
  542. Args:
  543. image_num: The number of the image that we want a text prompt for.
  544. Images should be indexed starting at 1.
  545. padding: Whether or not padding should be manually added.
  546. Returns:
  547. Text placeholder prompt for the image being considered.
  548. """
  549. image_start = f"Picture {image_num}: {IMG_START}"
  550. image_end = f"{IMG_END}\n"
  551. if not padding:
  552. return f"{image_start}{image_end}"
  553. return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
  554. def input_processor_for_qwen(ctx: InputContext,
  555. llm_inputs: LLMInputs) -> LLMInputs:
  556. """Processes the inputs, which may or may not be multimodal.
  557. Multimodal inputs will only be processed if the model has a "visual"
  558. component in its model config, otherwise they'll be ignored.
  559. Args:
  560. ctx: Context of the loaded model.
  561. llm_inputs: LLM inputs which may have a multi_modal_data attribute.
  562. Returns:
  563. If the model is language only or not multimodal inputs were provided,
  564. returns llm_inputs unmodified. Otherwise, processes the multimodal
  565. images / image embeddings and adds the fixed-length image placeholders.
  566. """
  567. multi_modal_data = llm_inputs.get("multi_modal_data")
  568. # Only process images if we have multimodal data and a visual config
  569. hf_config = ctx.get_hf_config()
  570. if (multi_modal_data is None or "image" not in multi_modal_data
  571. or not hasattr(hf_config, "visual")):
  572. return llm_inputs
  573. prompt = llm_inputs.get("prompt")
  574. prompt_token_ids = llm_inputs["prompt_token_ids"]
  575. model_config = ctx.model_config
  576. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  577. trust_remote_code=True)
  578. image_data = multi_modal_data["image"]
  579. if isinstance(image_data, torch.Tensor):
  580. num_dims = len(image_data.shape)
  581. if num_dims < 2 or num_dims > 3:
  582. raise ValueError(
  583. f"Expected img embeds to be have 3 dimensions, got {num_dims}")
  584. num_images = 1 if num_dims == 2 else image_data.shape[0]
  585. elif isinstance(image_data, Image.Image):
  586. num_images = 1
  587. elif is_list_of(image_data, Image.Image):
  588. num_images = len(image_data)
  589. else:
  590. raise TypeError(f"Invalid image type: {type(image_data)}")
  591. if prompt is None:
  592. prompt = tokenizer.decode(prompt_token_ids)
  593. # Drops anything between <img>/</img> tags; encoding with the tokenizer
  594. # will automatically add the image pads for the context.
  595. new_prompt, num_matched_images = re.subn(
  596. r"(Picture \d*: <img>).*?(<\/img>\n)",
  597. r"\1\2",
  598. prompt,
  599. )
  600. if num_matched_images != num_images:
  601. logger.warning(
  602. f"Number of matched image placeholders {num_matched_images} "
  603. f"doesn't match the number of expected images {num_images}; "
  604. "check your placeholder formatting."
  605. )
  606. new_prompt_token_ids = tokenizer.encode(new_prompt)
  607. return LLMInputs(prompt=new_prompt,
  608. prompt_token_ids=new_prompt_token_ids,
  609. multi_modal_data=multi_modal_data)
  610. def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
  611. """Maps the input data to its MultiModalInputs (if any).
  612. Args:
  613. ctx: Context of the loaded model.
  614. data: data potentially containing image/image embeddings to be mapped
  615. to pixel_values in .forward() for a visual QWenLMHeadModel model.
  616. Returns:
  617. MultiModalInputs containing the stacked normalized images tensor or
  618. image embeddings.
  619. """
  620. # Early exit if we have provided an image to a language only Qwen model
  621. hf_config = ctx.get_hf_config()
  622. if not hasattr(hf_config, "visual"):
  623. logger.warning(
  624. "Images were provided but this model has no visual config; "
  625. "multimodal inputs will not be forwarded to the model.")
  626. return MultiModalInputs()
  627. model_config = ctx.model_config
  628. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  629. trust_remote_code=True)
  630. image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
  631. add_special_tokens=False,
  632. return_tensors="pt").squeeze()
  633. image_start_id = image_pair_tok[0]
  634. image_end_id = image_pair_tok[-1]
  635. if (image_start_id + 1) != image_end_id:
  636. raise ValueError(
  637. f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
  638. if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
  639. raise ValueError(
  640. f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
  641. f"but got {image_pair_tok - 2}")
  642. hf_config = ctx.get_hf_config()
  643. image_size = hf_config.visual["image_size"]
  644. img_emb_size = hf_config.visual["output_dim"]
  645. if isinstance(data, torch.Tensor):
  646. # It's expected that our values have already been processed
  647. # by the visual transformer; shape is expected to be:
  648. # (# images, 256, hidden_size)
  649. if len(data.shape) == 2:
  650. # Assume only one image embed was provided; unsqueeze the extra dim
  651. data = data.unsqueeze(0)
  652. if len(data.shape) != 3 or data.shape[
  653. 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
  654. raise ValueError(
  655. "Expected image embeds to be a tensor of shape"
  656. f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
  657. f"received shape [{data.shape}]")
  658. pixel_values = data
  659. else:
  660. transform = build_normalization_transform(image_size)
  661. if not isinstance(data, (list, tuple)):
  662. data = [data]
  663. transformed_images = [transform(datum) for datum in data]
  664. pixel_values = torch.stack(transformed_images, dim=0)
  665. return MultiModalInputs({"pixel_values": pixel_values})
  666. def build_normalization_transform(image_size: int) -> transforms.Compose:
  667. """Builds a normalization transform which can be applied to one or
  668. more input images from which we want to extract visual features.
  669. Args:
  670. image_size: size of the image to be processed for visual embeddings.
  671. Returns:
  672. Callable transform for normalizing and resizing one RGB image.
  673. """
  674. return transforms.Compose([
  675. transforms.Resize((image_size, image_size),
  676. interpolation=InterpolationMode.BICUBIC),
  677. transforms.ToTensor(),
  678. transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
  679. ])
  def dummy_data_for_qwen(
      ctx: InputContext,
      seq_len: int,
      mm_counts: Mapping[str, int],
  ) -> Tuple[SequenceData, Optional[Dict]]:
      """Create the placeholder inputs used to warm up a Qwen model.

      If the HF config exposes no ``visual`` attribute the model is treated as
      text-only: the sequence is built from ``seq_len`` zero tokens and no
      multimodal data is returned. Otherwise a prompt holding one image
      placeholder per requested image is tokenized, zero-padded up to
      ``seq_len`` and paired with blank RGB images.

      Args:
          ctx: Context of the loaded model.
          seq_len: Minimum number of tokens the dummy sequence must contain.
          mm_counts: Multimodal data counts; the ``"image"`` entry is read.

      Returns:
          Tuple of the dummy sequence data and the multimodal data dict
          (``None`` for text-only models).

      Raises:
          ValueError: If the tokenized prompt does not contain exactly
              ``MAX_QWEN_IMG_TOKENS`` image pad tokens per image.
      """
      # No visual config -> warm up as a plain LLM.
      if not hasattr(ctx.get_hf_config(), "visual"):
          return SequenceData.from_token_counts((0, seq_len)), None

      # Visual component present: warm up with images.
      num_images = mm_counts["image"]
      tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer,
                                       trust_remote_code=True)

      # The prompt is built without image pads; the tokenizer is expected to
      # insert them, which is verified right below.
      image_prompt = "".join(
          get_image_text(image_idx, False)
          for image_idx in range(1, num_images + 1))
      toks = tokenizer.encode(image_prompt, add_special_tokens=False)

      # Every image must expand to the fixed number of pad tokens.
      pad_token_id = tokenizer.encode(IMG_PAD)[0]
      num_pads = toks.count(pad_token_id)
      expected_pads = num_images * MAX_QWEN_IMG_TOKENS
      if num_pads != expected_pads:
          raise ValueError(
              f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
              f" per image, but got {num_pads} pads for {num_images} image(s)"
              " in total. Are you using a qwen tokenizer?")

      # Zero-pad so the sequence is at least ``seq_len`` tokens long.
      missing = seq_len - len(toks)
      if missing > 0:
          toks += [0] * missing
      seq_data = SequenceData.from_seqs(toks)

      # Image dimensions are irrelevant here: inputs get resized and the
      # number of tokens per image is constant.
      blank_image = Image.new("RGB", (224, 224), color=0)
      if num_images == 1:
          image_payload = blank_image
      else:
          image_payload = [blank_image] * num_images
      return seq_data, {"image": image_payload}
  @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
  @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
  class QWenLMHeadModel(nn.Module, SupportsMultiModal):
      """Qwen model with a language-model head.

      Combines a ``QWenModel`` transformer with a ``ParallelLMHead``, a
      ``LogitsProcessor`` and a ``Sampler``, and routes optional image inputs
      to the transformer.
      """

      def __init__(
          self,
          config: PretrainedConfig,
          multimodal_config: MultiModalConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          """Build the transformer, LM head, logits processor and sampler.

          Args:
              config: Model configuration.
              multimodal_config: Multimodal configuration, stored on the
                  instance.
              cache_config: Optional cache configuration passed to the
                  transformer.
              quant_config: Optional quantization configuration passed to the
                  transformer and the LM head.
          """
          super().__init__()
          self.config = config
          self.multimodal_config = multimodal_config
          self.quant_config = quant_config

          self.transformer = QWenModel(config, cache_config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size,
              config.hidden_size,
              quant_config=quant_config,
          )
          # With tied embeddings the head reuses the token embedding matrix.
          if self.config.tie_word_embeddings:
              self.lm_head.weight = self.transformer.wte.weight

          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def _get_image_input_type(
              self,
              pixel_values: Optional[torch.Tensor]
      ) -> Optional[QwenImageInputs]:
          """Classify ``pixel_values`` as pixel data or image embeddings.

          Args:
              pixel_values: Optional data to be processed into visual
                  embeddings.

          Returns:
              ``None`` when there is no input or no visual component;
              otherwise a ``QwenImageInputs`` value whose ``type`` tells the
              transformer whether the data still needs to go through the
              visual encoder.
          """
          if pixel_values is None or self.transformer.visual is None:
              return None

          pixel_values = flatten_bn(pixel_values)
          shape = pixel_values.shape
          # Embeddings are identified purely by their shape:
          # (*, MAX_QWEN_IMG_TOKENS, visual output_dim).
          is_embedding = (len(shape) == 3
                          and shape[1] == MAX_QWEN_IMG_TOKENS
                          and shape[2] == self.config.visual["output_dim"])
          if is_embedding:
              return QwenImageEmbeddingInputs(
                  type="image_embeds",
                  data=pixel_values,
              )
          # Any other shape is assumed to still require processing.
          return QwenImagePixelInputs(
              type="pixel_values",
              data=pixel_values,
          )

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          intermediate_tensors: Optional[IntermediateTensors] = None,
          pixel_values: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Run the transformer and return its output.

          ``pixel_values`` is first wrapped by ``_get_image_input_type`` and
          the wrapped value (possibly ``None``) is forwarded to the
          transformer as its last positional argument.
          """
          image_input = self._get_image_input_type(pixel_values)
          return self.transformer(
              input_ids,
              positions,
              kv_caches,
              attn_metadata,
              intermediate_tensors,
              image_input,
          )

      def make_empty_intermediate_tensors(
              self, batch_size: int, dtype: torch.dtype,
              device: torch.device) -> IntermediateTensors:
          """Return zero-filled ``hidden_states`` and ``residual`` tensors.

          Both tensors have shape ``(batch_size, hidden_size)`` and use the
          requested ``dtype`` and ``device``.
          """
          tensor_shape = (batch_size, self.config.hidden_size)
          return IntermediateTensors({
              key: torch.zeros(tensor_shape, dtype=dtype, device=device)
              for key in ("hidden_states", "residual")
          })

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Apply the logits processor to ``hidden_states`` via the LM head."""
          return self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Run the sampler on ``logits`` and return its output."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint tensors into this module's parameters.

          Checkpoint names containing ``w2`` / ``w1`` are redirected to shards
          0 / 1 of the fused ``gate_up_proj`` parameter; every other tensor is
          loaded through the parameter's own ``weight_loader`` when it has
          one, or ``default_weight_loader`` otherwise.

          Args:
              weights: Iterable of ``(name, tensor)`` pairs.
          """
          # (fused parameter name, checkpoint shard name, shard id)
          fused_shards = (
              ("gate_up_proj", "w2", 0),
              ("gate_up_proj", "w1", 1),
          )
          known_params = dict(self.named_parameters())

          for ckpt_name, tensor in weights:
              if "rotary_emb.inv_freq" in ckpt_name:
                  continue

              target = ckpt_name
              loaded_as_shard = False
              for fused_name, shard_name, shard_id in fused_shards:
                  if shard_name not in target:
                      continue
                  # NOTE: the rewritten name is kept even if this entry is
                  # skipped below, so later checks see the substituted name.
                  target = target.replace(shard_name, fused_name)
                  # Skip loading extra bias for GPTQ models.
                  if target.endswith(".bias") and target not in known_params:
                      continue
                  # Skip layers on other devices.
                  if is_pp_missing_parameter(target, self):
                      continue
                  fused_param = known_params[target]
                  fused_param.weight_loader(fused_param, tensor, shard_id)
                  loaded_as_shard = True
                  break
              if loaded_as_shard:
                  continue

              # No fused shard consumed the tensor: load it directly.
              # Skip loading extra bias for GPTQ models.
              if target.endswith(".bias") and target not in known_params:
                  continue
              # Skip layers on other devices.
              if is_pp_missing_parameter(target, self):
                  continue
              plain_param = known_params[target]
              loader = getattr(plain_param, "weight_loader",
                               default_weight_loader)
              loader(plain_param, tensor)