qwen2.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import Qwen2Config
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig, LoRAConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.common.utils import progress_bar
  33. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  34. get_pp_group,
  35. get_tensor_model_parallel_rank,
  36. get_tensor_model_parallel_world_size)
  37. from aphrodite.modeling.layers.activation import SiluAndMul
  38. from aphrodite.modeling.layers.layernorm import RMSNorm
  39. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  40. QKVParallelLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import (
  48. default_weight_loader, maybe_remap_kv_scale_name)
  49. from aphrodite.modeling.models.interfaces import SupportsLoRA
  50. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  51. from aphrodite.quantization.base_config import QuantizationConfig
  52. from .utils import is_pp_missing_parameter, make_layers
  53. class Qwen2MLP(nn.Module):
  54. def __init__(
  55. self,
  56. hidden_size: int,
  57. intermediate_size: int,
  58. hidden_act: str,
  59. quant_config: Optional[QuantizationConfig] = None,
  60. ) -> None:
  61. super().__init__()
  62. self.gate_up_proj = MergedColumnParallelLinear(
  63. hidden_size, [intermediate_size] * 2,
  64. bias=False,
  65. quant_config=quant_config)
  66. self.down_proj = RowParallelLinear(intermediate_size,
  67. hidden_size,
  68. bias=False,
  69. quant_config=quant_config)
  70. if hidden_act != "silu":
  71. raise ValueError(f"Unsupported activation: {hidden_act}. "
  72. "Only silu is supported for now.")
  73. self.act_fn = SiluAndMul()
  74. def forward(self, x):
  75. gate_up, _ = self.gate_up_proj(x)
  76. x = self.act_fn(gate_up)
  77. x, _ = self.down_proj(x)
  78. return x
  79. class Qwen2Attention(nn.Module):
  80. def __init__(self,
  81. hidden_size: int,
  82. num_heads: int,
  83. num_kv_heads: int,
  84. max_position: int = 4096 * 32,
  85. rope_theta: float = 10000,
  86. cache_config: Optional[CacheConfig] = None,
  87. quant_config: Optional[QuantizationConfig] = None,
  88. rope_scaling: Optional[Tuple] = None) -> None:
  89. super().__init__()
  90. self.hidden_size = hidden_size
  91. tp_size = get_tensor_model_parallel_world_size()
  92. tp_rank = get_tensor_model_parallel_rank()
  93. self.total_num_heads = num_heads
  94. self.total_num_kv_heads = num_kv_heads
  95. self.num_kv_heads = max(
  96. 1,
  97. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  98. tp_rank, tp_size))
  99. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  100. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  101. self.head_dim = hidden_size // self.total_num_heads
  102. self.q_size = self.num_heads * self.head_dim
  103. self.kv_size = self.num_kv_heads * self.head_dim
  104. self.scaling = self.head_dim**-0.5
  105. self.rope_theta = rope_theta
  106. self.qkv_proj = QKVParallelLinear(
  107. hidden_size,
  108. self.head_dim,
  109. self.total_num_heads,
  110. self.total_num_kv_heads,
  111. bias=True,
  112. quant_config=quant_config,
  113. )
  114. self.o_proj = RowParallelLinear(
  115. self.total_num_heads * self.head_dim,
  116. hidden_size,
  117. bias=False,
  118. quant_config=quant_config,
  119. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  120. )
  121. self.rotary_emb = get_rope(
  122. self.head_dim,
  123. rotary_dim=self.head_dim,
  124. max_position=max_position,
  125. base=self.rope_theta,
  126. rope_scaling=rope_scaling,
  127. )
  128. self.attn = Attention(self.num_heads,
  129. self.head_dim,
  130. self.scaling,
  131. num_kv_heads=self.num_kv_heads,
  132. cache_config=cache_config,
  133. quant_config=quant_config)
  134. def forward(
  135. self,
  136. positions: torch.Tensor,
  137. hidden_states: torch.Tensor,
  138. kv_cache: torch.Tensor,
  139. attn_metadata: AttentionMetadata,
  140. ) -> torch.Tensor:
  141. qkv, _ = self.qkv_proj(hidden_states)
  142. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  143. q, k = self.rotary_emb(positions, q, k)
  144. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  145. output, _ = self.o_proj(attn_output)
  146. return output
  147. class Qwen2DecoderLayer(nn.Module):
  148. def __init__(
  149. self,
  150. config: Qwen2Config,
  151. cache_config: Optional[CacheConfig] = None,
  152. quant_config: Optional[QuantizationConfig] = None,
  153. ) -> None:
  154. super().__init__()
  155. self.hidden_size = config.hidden_size
  156. # Requires transformers > 4.32.0
  157. rope_theta = getattr(config, "rope_theta", 1000000)
  158. rope_scaling = getattr(config, "rope_scaling", None)
  159. self.self_attn = Qwen2Attention(
  160. hidden_size=self.hidden_size,
  161. num_heads=config.num_attention_heads,
  162. max_position=config.max_position_embeddings,
  163. num_kv_heads=config.num_key_value_heads,
  164. rope_theta=rope_theta,
  165. cache_config=cache_config,
  166. quant_config=quant_config,
  167. rope_scaling=rope_scaling)
  168. self.mlp = Qwen2MLP(
  169. hidden_size=self.hidden_size,
  170. intermediate_size=config.intermediate_size,
  171. hidden_act=config.hidden_act,
  172. quant_config=quant_config,
  173. )
  174. self.input_layernorm = RMSNorm(config.hidden_size,
  175. eps=config.rms_norm_eps)
  176. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  177. eps=config.rms_norm_eps)
  178. def forward(
  179. self,
  180. positions: torch.Tensor,
  181. hidden_states: torch.Tensor,
  182. kv_cache: torch.Tensor,
  183. attn_metadata: AttentionMetadata,
  184. residual: Optional[torch.Tensor],
  185. ) -> Tuple[torch.Tensor, torch.Tensor]:
  186. # Self Attention
  187. if residual is None:
  188. residual = hidden_states
  189. hidden_states = self.input_layernorm(hidden_states)
  190. else:
  191. hidden_states, residual = self.input_layernorm(
  192. hidden_states, residual)
  193. hidden_states = self.self_attn(
  194. positions=positions,
  195. hidden_states=hidden_states,
  196. kv_cache=kv_cache,
  197. attn_metadata=attn_metadata,
  198. )
  199. # Fully Connected
  200. hidden_states, residual = self.post_attention_layernorm(
  201. hidden_states, residual)
  202. hidden_states = self.mlp(hidden_states)
  203. return hidden_states, residual
  204. class Qwen2Model(nn.Module):
  205. def __init__(
  206. self,
  207. config: Qwen2Config,
  208. cache_config: Optional[CacheConfig] = None,
  209. quant_config: Optional[QuantizationConfig] = None,
  210. prefix: str = "",
  211. ) -> None:
  212. super().__init__()
  213. self.config = config
  214. self.padding_idx = config.pad_token_id
  215. self.vocab_size = config.vocab_size
  216. self.embed_tokens = VocabParallelEmbedding(
  217. config.vocab_size,
  218. config.hidden_size,
  219. )
  220. self.start_layer, self.end_layer, self.layers = make_layers(
  221. config.num_hidden_layers,
  222. lambda prefix: Qwen2DecoderLayer(config=config,
  223. cache_config=cache_config,
  224. quant_config=quant_config),
  225. prefix=f"{prefix}.layers",
  226. )
  227. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  228. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  229. return self.embed_tokens(input_ids)
  230. def forward(
  231. self,
  232. input_ids: torch.Tensor,
  233. positions: torch.Tensor,
  234. kv_caches: List[torch.Tensor],
  235. attn_metadata: AttentionMetadata,
  236. intermediate_tensors: Optional[IntermediateTensors] = None,
  237. inputs_embeds: Optional[torch.Tensor] = None,
  238. ) -> torch.Tensor:
  239. if get_pp_group().is_first_rank:
  240. if inputs_embeds is not None:
  241. hidden_states = inputs_embeds
  242. else:
  243. hidden_states = self.embed_tokens(input_ids)
  244. residual = None
  245. else:
  246. assert intermediate_tensors is not None
  247. hidden_states = intermediate_tensors["hidden_states"]
  248. residual = intermediate_tensors["residual"]
  249. for i in range(self.start_layer, self.end_layer):
  250. layer = self.layers[i]
  251. hidden_states, residual = layer(
  252. positions,
  253. hidden_states,
  254. kv_caches[i - self.start_layer],
  255. attn_metadata,
  256. residual,
  257. )
  258. if not get_pp_group().is_last_rank:
  259. return IntermediateTensors({
  260. "hidden_states": hidden_states,
  261. "residual": residual
  262. })
  263. hidden_states, _ = self.norm(hidden_states, residual)
  264. return hidden_states
  265. class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
  266. packed_modules_mapping = {
  267. "qkv_proj": [
  268. "q_proj",
  269. "k_proj",
  270. "v_proj",
  271. ],
  272. "gate_up_proj": [
  273. "gate_proj",
  274. "up_proj",
  275. ],
  276. }
  277. # LoRA specific attributes
  278. supported_lora_modules = [
  279. "qkv_proj",
  280. "o_proj",
  281. "gate_up_proj",
  282. "down_proj",
  283. ]
  284. embedding_modules = {}
  285. embedding_padding_modules = []
  286. def __init__(
  287. self,
  288. config: Qwen2Config,
  289. cache_config: Optional[CacheConfig] = None,
  290. quant_config: Optional[QuantizationConfig] = None,
  291. lora_config: Optional[LoRAConfig] = None,
  292. ) -> None:
  293. # TODO (: see if this can be moved out
  294. if (cache_config.sliding_window is not None
  295. and hasattr(config, "max_window_layers")):
  296. raise ValueError(
  297. "Sliding window for some but all layers is not "
  298. "supported. This model uses sliding window "
  299. "but `max_window_layers` = "
  300. f"{config.max_window_layers} is less than "
  301. "`num_hidden_layers` = "
  302. f"{config.num_hidden_layers}. Please open an issue"
  303. " to discuss this feature.")
  304. super().__init__()
  305. self.config = config
  306. self.lora_config = lora_config
  307. self.quant_config = quant_config
  308. self.model = Qwen2Model(config, cache_config, quant_config)
  309. if config.tie_word_embeddings:
  310. self.lm_head = self.model.embed_tokens
  311. else:
  312. self.lm_head = ParallelLMHead(config.vocab_size,
  313. config.hidden_size,
  314. quant_config=quant_config)
  315. self.logits_processor = LogitsProcessor(config.vocab_size)
  316. self.sampler = Sampler()
  317. def forward(
  318. self,
  319. input_ids: torch.Tensor,
  320. positions: torch.Tensor,
  321. kv_caches: List[torch.Tensor],
  322. attn_metadata: AttentionMetadata,
  323. intermediate_tensors: Optional[IntermediateTensors] = None,
  324. ) -> torch.Tensor:
  325. hidden_states = self.model(input_ids, positions, kv_caches,
  326. attn_metadata, intermediate_tensors)
  327. return hidden_states
  328. def compute_logits(
  329. self,
  330. hidden_states: torch.Tensor,
  331. sampling_metadata: SamplingMetadata,
  332. ) -> Optional[torch.Tensor]:
  333. logits = self.logits_processor(self.lm_head, hidden_states,
  334. sampling_metadata)
  335. return logits
  336. def make_empty_intermediate_tensors(
  337. self, batch_size: int, dtype: torch.dtype,
  338. device: torch.device) -> IntermediateTensors:
  339. return IntermediateTensors({
  340. "hidden_states":
  341. torch.zeros((batch_size, self.config.hidden_size),
  342. dtype=dtype,
  343. device=device),
  344. "residual":
  345. torch.zeros((batch_size, self.config.hidden_size),
  346. dtype=dtype,
  347. device=device),
  348. })
  349. def sample(
  350. self,
  351. logits: torch.Tensor,
  352. sampling_metadata: SamplingMetadata,
  353. ) -> Optional[SamplerOutput]:
  354. next_tokens = self.sampler(logits, sampling_metadata)
  355. return next_tokens
  356. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  357. stacked_params_mapping = [
  358. # (param_name, shard_name, shard_id)
  359. ("qkv_proj", "q_proj", "q"),
  360. ("qkv_proj", "k_proj", "k"),
  361. ("qkv_proj", "v_proj", "v"),
  362. ("gate_up_proj", "gate_proj", 0),
  363. ("gate_up_proj", "up_proj", 1),
  364. ]
  365. params_dict = dict(self.named_parameters(remove_duplicate=False))
  366. weights_list = list(weights)
  367. for name, loaded_weight in progress_bar(weights_list,
  368. desc="Loading modules..."):
  369. if "rotary_emb.inv_freq" in name:
  370. continue
  371. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  372. continue
  373. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  374. if weight_name not in name:
  375. continue
  376. name = name.replace(weight_name, param_name)
  377. # Skip loading extra bias for GPTQ models.
  378. if name.endswith(".bias") and name not in params_dict:
  379. continue
  380. if is_pp_missing_parameter(name, self):
  381. continue
  382. param = params_dict[name]
  383. weight_loader = param.weight_loader
  384. weight_loader(param, loaded_weight, shard_id)
  385. break
  386. else:
  387. # Skip loading extra bias for GPTQ models.
  388. if name.endswith(".bias") and name not in params_dict:
  389. continue
  390. # Remapping the name of FP8 kv-scale.
  391. name = maybe_remap_kv_scale_name(name, params_dict)
  392. if name is None:
  393. continue
  394. if is_pp_missing_parameter(name, self):
  395. continue
  396. param = params_dict[name]
  397. weight_loader = getattr(param, "weight_loader",
  398. default_weight_loader)
  399. weight_loader(param, loaded_weight)