qwen2.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import Qwen2Config
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig, LoRAConfig
  31. from aphrodite.common.sequence import IntermediateTensors
  32. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  33. get_pp_group,
  34. get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size)
  36. from aphrodite.modeling.layers.activation import SiluAndMul
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import (
  47. default_weight_loader, maybe_remap_kv_scale_name)
  48. from aphrodite.modeling.models.interfaces import SupportsLoRA
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.quantization.base_config import QuantizationConfig
  51. from .utils import is_pp_missing_parameter, make_layers
  52. class Qwen2MLP(nn.Module):
  53. def __init__(
  54. self,
  55. hidden_size: int,
  56. intermediate_size: int,
  57. hidden_act: str,
  58. quant_config: Optional[QuantizationConfig] = None,
  59. ) -> None:
  60. super().__init__()
  61. self.gate_up_proj = MergedColumnParallelLinear(
  62. hidden_size, [intermediate_size] * 2,
  63. bias=False,
  64. quant_config=quant_config)
  65. self.down_proj = RowParallelLinear(intermediate_size,
  66. hidden_size,
  67. bias=False,
  68. quant_config=quant_config)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class Qwen2Attention(nn.Module):
  79. def __init__(self,
  80. hidden_size: int,
  81. num_heads: int,
  82. num_kv_heads: int,
  83. max_position: int = 4096 * 32,
  84. rope_theta: float = 10000,
  85. cache_config: Optional[CacheConfig] = None,
  86. quant_config: Optional[QuantizationConfig] = None,
  87. rope_scaling: Optional[Tuple] = None) -> None:
  88. super().__init__()
  89. self.hidden_size = hidden_size
  90. tp_size = get_tensor_model_parallel_world_size()
  91. tp_rank = get_tensor_model_parallel_rank()
  92. self.total_num_heads = num_heads
  93. self.total_num_kv_heads = num_kv_heads
  94. self.num_kv_heads = max(
  95. 1,
  96. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  97. tp_rank, tp_size))
  98. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  99. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  100. self.head_dim = hidden_size // self.total_num_heads
  101. self.q_size = self.num_heads * self.head_dim
  102. self.kv_size = self.num_kv_heads * self.head_dim
  103. self.scaling = self.head_dim**-0.5
  104. self.rope_theta = rope_theta
  105. self.qkv_proj = QKVParallelLinear(
  106. hidden_size,
  107. self.head_dim,
  108. self.total_num_heads,
  109. self.total_num_kv_heads,
  110. bias=True,
  111. quant_config=quant_config,
  112. )
  113. self.o_proj = RowParallelLinear(
  114. self.total_num_heads * self.head_dim,
  115. hidden_size,
  116. bias=False,
  117. quant_config=quant_config,
  118. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  119. )
  120. self.rotary_emb = get_rope(
  121. self.head_dim,
  122. rotary_dim=self.head_dim,
  123. max_position=max_position,
  124. base=self.rope_theta,
  125. rope_scaling=rope_scaling,
  126. )
  127. self.attn = Attention(self.num_heads,
  128. self.head_dim,
  129. self.scaling,
  130. num_kv_heads=self.num_kv_heads,
  131. cache_config=cache_config,
  132. quant_config=quant_config)
  133. def forward(
  134. self,
  135. positions: torch.Tensor,
  136. hidden_states: torch.Tensor,
  137. kv_cache: torch.Tensor,
  138. attn_metadata: AttentionMetadata,
  139. ) -> torch.Tensor:
  140. qkv, _ = self.qkv_proj(hidden_states)
  141. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  142. q, k = self.rotary_emb(positions, q, k)
  143. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  144. output, _ = self.o_proj(attn_output)
  145. return output
  146. class Qwen2DecoderLayer(nn.Module):
  147. def __init__(
  148. self,
  149. config: Qwen2Config,
  150. cache_config: Optional[CacheConfig] = None,
  151. quant_config: Optional[QuantizationConfig] = None,
  152. ) -> None:
  153. super().__init__()
  154. self.hidden_size = config.hidden_size
  155. # Requires transformers > 4.32.0
  156. rope_theta = getattr(config, "rope_theta", 1000000)
  157. rope_scaling = getattr(config, "rope_scaling", None)
  158. self.self_attn = Qwen2Attention(
  159. hidden_size=self.hidden_size,
  160. num_heads=config.num_attention_heads,
  161. max_position=config.max_position_embeddings,
  162. num_kv_heads=config.num_key_value_heads,
  163. rope_theta=rope_theta,
  164. cache_config=cache_config,
  165. quant_config=quant_config,
  166. rope_scaling=rope_scaling)
  167. self.mlp = Qwen2MLP(
  168. hidden_size=self.hidden_size,
  169. intermediate_size=config.intermediate_size,
  170. hidden_act=config.hidden_act,
  171. quant_config=quant_config,
  172. )
  173. self.input_layernorm = RMSNorm(config.hidden_size,
  174. eps=config.rms_norm_eps)
  175. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  176. eps=config.rms_norm_eps)
  177. def forward(
  178. self,
  179. positions: torch.Tensor,
  180. hidden_states: torch.Tensor,
  181. kv_cache: torch.Tensor,
  182. attn_metadata: AttentionMetadata,
  183. residual: Optional[torch.Tensor],
  184. ) -> Tuple[torch.Tensor, torch.Tensor]:
  185. # Self Attention
  186. if residual is None:
  187. residual = hidden_states
  188. hidden_states = self.input_layernorm(hidden_states)
  189. else:
  190. hidden_states, residual = self.input_layernorm(
  191. hidden_states, residual)
  192. hidden_states = self.self_attn(
  193. positions=positions,
  194. hidden_states=hidden_states,
  195. kv_cache=kv_cache,
  196. attn_metadata=attn_metadata,
  197. )
  198. # Fully Connected
  199. hidden_states, residual = self.post_attention_layernorm(
  200. hidden_states, residual)
  201. hidden_states = self.mlp(hidden_states)
  202. return hidden_states, residual
  203. class Qwen2Model(nn.Module):
  204. def __init__(
  205. self,
  206. config: Qwen2Config,
  207. cache_config: Optional[CacheConfig] = None,
  208. quant_config: Optional[QuantizationConfig] = None,
  209. prefix: str = "",
  210. ) -> None:
  211. super().__init__()
  212. self.config = config
  213. self.padding_idx = config.pad_token_id
  214. self.vocab_size = config.vocab_size
  215. self.embed_tokens = VocabParallelEmbedding(
  216. config.vocab_size,
  217. config.hidden_size,
  218. )
  219. self.start_layer, self.end_layer, self.layers = make_layers(
  220. config.num_hidden_layers,
  221. lambda prefix: Qwen2DecoderLayer(config=config,
  222. cache_config=cache_config,
  223. quant_config=quant_config),
  224. prefix=f"{prefix}.layers",
  225. )
  226. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  227. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  228. return self.embed_tokens(input_ids)
  229. def forward(
  230. self,
  231. input_ids: torch.Tensor,
  232. positions: torch.Tensor,
  233. kv_caches: List[torch.Tensor],
  234. attn_metadata: AttentionMetadata,
  235. intermediate_tensors: Optional[IntermediateTensors] = None,
  236. inputs_embeds: Optional[torch.Tensor] = None,
  237. ) -> torch.Tensor:
  238. if get_pp_group().is_first_rank:
  239. if inputs_embeds is not None:
  240. hidden_states = inputs_embeds
  241. else:
  242. hidden_states = self.embed_tokens(input_ids)
  243. residual = None
  244. else:
  245. assert intermediate_tensors is not None
  246. hidden_states = intermediate_tensors["hidden_states"]
  247. residual = intermediate_tensors["residual"]
  248. for i in range(self.start_layer, self.end_layer):
  249. layer = self.layers[i]
  250. hidden_states, residual = layer(
  251. positions,
  252. hidden_states,
  253. kv_caches[i - self.start_layer],
  254. attn_metadata,
  255. residual,
  256. )
  257. if not get_pp_group().is_last_rank:
  258. return IntermediateTensors({
  259. "hidden_states": hidden_states,
  260. "residual": residual
  261. })
  262. hidden_states, _ = self.norm(hidden_states, residual)
  263. return hidden_states
  264. class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
  265. packed_modules_mapping = {
  266. "qkv_proj": [
  267. "q_proj",
  268. "k_proj",
  269. "v_proj",
  270. ],
  271. "gate_up_proj": [
  272. "gate_proj",
  273. "up_proj",
  274. ],
  275. }
  276. # LoRA specific attributes
  277. supported_lora_modules = [
  278. "qkv_proj",
  279. "o_proj",
  280. "gate_up_proj",
  281. "down_proj",
  282. ]
  283. embedding_modules = {}
  284. embedding_padding_modules = []
  285. def __init__(
  286. self,
  287. config: Qwen2Config,
  288. cache_config: Optional[CacheConfig] = None,
  289. quant_config: Optional[QuantizationConfig] = None,
  290. lora_config: Optional[LoRAConfig] = None,
  291. ) -> None:
  292. # TODO (: see if this can be moved out
  293. if (cache_config.sliding_window is not None
  294. and hasattr(config, "max_window_layers")):
  295. raise ValueError(
  296. "Sliding window for some but all layers is not "
  297. "supported. This model uses sliding window "
  298. "but `max_window_layers` = "
  299. f"{config.max_window_layers} is less than "
  300. "`num_hidden_layers` = "
  301. f"{config.num_hidden_layers}. Please open an issue"
  302. " to discuss this feature.")
  303. super().__init__()
  304. self.config = config
  305. self.lora_config = lora_config
  306. self.quant_config = quant_config
  307. self.model = Qwen2Model(config, cache_config, quant_config)
  308. if config.tie_word_embeddings:
  309. self.lm_head = self.model.embed_tokens
  310. else:
  311. self.lm_head = ParallelLMHead(config.vocab_size,
  312. config.hidden_size,
  313. quant_config=quant_config)
  314. self.logits_processor = LogitsProcessor(config.vocab_size)
  315. self.sampler = Sampler()
  316. def forward(
  317. self,
  318. input_ids: torch.Tensor,
  319. positions: torch.Tensor,
  320. kv_caches: List[torch.Tensor],
  321. attn_metadata: AttentionMetadata,
  322. intermediate_tensors: Optional[IntermediateTensors] = None,
  323. ) -> torch.Tensor:
  324. hidden_states = self.model(input_ids, positions, kv_caches,
  325. attn_metadata, intermediate_tensors)
  326. return hidden_states
  327. def compute_logits(
  328. self,
  329. hidden_states: torch.Tensor,
  330. sampling_metadata: SamplingMetadata,
  331. ) -> Optional[torch.Tensor]:
  332. logits = self.logits_processor(self.lm_head, hidden_states,
  333. sampling_metadata)
  334. return logits
  335. def make_empty_intermediate_tensors(
  336. self, batch_size: int, dtype: torch.dtype,
  337. device: torch.device) -> IntermediateTensors:
  338. return IntermediateTensors({
  339. "hidden_states":
  340. torch.zeros((batch_size, self.config.hidden_size),
  341. dtype=dtype,
  342. device=device),
  343. "residual":
  344. torch.zeros((batch_size, self.config.hidden_size),
  345. dtype=dtype,
  346. device=device),
  347. })
  348. def sample(
  349. self,
  350. logits: torch.Tensor,
  351. sampling_metadata: SamplingMetadata,
  352. ) -> Optional[SamplerOutput]:
  353. next_tokens = self.sampler(logits, sampling_metadata)
  354. return next_tokens
  355. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  356. stacked_params_mapping = [
  357. # (param_name, shard_name, shard_id)
  358. ("qkv_proj", "q_proj", "q"),
  359. ("qkv_proj", "k_proj", "k"),
  360. ("qkv_proj", "v_proj", "v"),
  361. ("gate_up_proj", "gate_proj", 0),
  362. ("gate_up_proj", "up_proj", 1),
  363. ]
  364. params_dict = dict(self.named_parameters(remove_duplicate=False))
  365. for name, loaded_weight in weights:
  366. if "rotary_emb.inv_freq" in name:
  367. continue
  368. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  369. continue
  370. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  371. if weight_name not in name:
  372. continue
  373. name = name.replace(weight_name, param_name)
  374. # Skip loading extra bias for GPTQ models.
  375. if name.endswith(".bias") and name not in params_dict:
  376. continue
  377. if is_pp_missing_parameter(name, self):
  378. continue
  379. param = params_dict[name]
  380. weight_loader = param.weight_loader
  381. weight_loader(param, loaded_weight, shard_id)
  382. break
  383. else:
  384. # Skip loading extra bias for GPTQ models.
  385. if name.endswith(".bias") and name not in params_dict:
  386. continue
  387. # Remapping the name of FP8 kv-scale.
  388. name = maybe_remap_kv_scale_name(name, params_dict)
  389. if name is None:
  390. continue
  391. if is_pp_missing_parameter(name, self):
  392. continue
  393. param = params_dict[name]
  394. weight_loader = getattr(param, "weight_loader",
  395. default_weight_loader)
  396. weight_loader(param, loaded_weight)