qwen2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import Qwen2Config
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import LoRAConfig
  31. from aphrodite.common.sequence import SamplerOutput
  32. from aphrodite.distributed import get_tensor_model_parallel_world_size
  33. from aphrodite.modeling.layers.activation import SiluAndMul
  34. from aphrodite.modeling.layers.layernorm import RMSNorm
  35. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  36. QKVParallelLinear,
  37. RowParallelLinear)
  38. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. ParallelLMHead, VocabParallelEmbedding)
  43. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. class Qwen2MLP(nn.Module):
  47. def __init__(
  48. self,
  49. hidden_size: int,
  50. intermediate_size: int,
  51. hidden_act: str,
  52. quant_config: Optional[QuantizationConfig] = None,
  53. ) -> None:
  54. super().__init__()
  55. self.gate_up_proj = MergedColumnParallelLinear(
  56. hidden_size, [intermediate_size] * 2,
  57. bias=False,
  58. quant_config=quant_config)
  59. self.down_proj = RowParallelLinear(intermediate_size,
  60. hidden_size,
  61. bias=False,
  62. quant_config=quant_config)
  63. if hidden_act != "silu":
  64. raise ValueError(f"Unsupported activation: {hidden_act}. "
  65. "Only silu is supported for now.")
  66. self.act_fn = SiluAndMul()
  67. def forward(self, x):
  68. gate_up, _ = self.gate_up_proj(x)
  69. x = self.act_fn(gate_up)
  70. x, _ = self.down_proj(x)
  71. return x
  72. class Qwen2Attention(nn.Module):
  73. def __init__(self,
  74. hidden_size: int,
  75. num_heads: int,
  76. num_kv_heads: int,
  77. max_position: int = 4096 * 32,
  78. rope_theta: float = 10000,
  79. use_sliding_window: bool = False,
  80. quant_config: Optional[QuantizationConfig] = None,
  81. sliding_window: Optional[int] = None) -> None:
  82. super().__init__()
  83. self.hidden_size = hidden_size
  84. tp_size = get_tensor_model_parallel_world_size()
  85. self.total_num_heads = num_heads
  86. assert self.total_num_heads % tp_size == 0
  87. self.num_heads = self.total_num_heads // tp_size
  88. self.total_num_kv_heads = num_kv_heads
  89. if self.total_num_kv_heads >= tp_size:
  90. # Number of KV heads is greater than TP size, so we partition
  91. # the KV heads across multiple tensor parallel GPUs.
  92. assert self.total_num_kv_heads % tp_size == 0
  93. else:
  94. # Number of KV heads is less than TP size, so we replicate
  95. # the KV heads across multiple tensor parallel GPUs.
  96. assert tp_size % self.total_num_kv_heads == 0
  97. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  98. self.head_dim = hidden_size // self.total_num_heads
  99. self.q_size = self.num_heads * self.head_dim
  100. self.kv_size = self.num_kv_heads * self.head_dim
  101. self.scaling = self.head_dim**-0.5
  102. self.rope_theta = rope_theta
  103. self.sliding_window = sliding_window if use_sliding_window else None
  104. self.qkv_proj = QKVParallelLinear(
  105. hidden_size,
  106. self.head_dim,
  107. self.total_num_heads,
  108. self.total_num_kv_heads,
  109. bias=True,
  110. quant_config=quant_config,
  111. )
  112. self.o_proj = RowParallelLinear(
  113. self.total_num_heads * self.head_dim,
  114. hidden_size,
  115. bias=False,
  116. quant_config=quant_config,
  117. )
  118. self.rotary_emb = get_rope(
  119. self.head_dim,
  120. rotary_dim=self.head_dim,
  121. max_position=max_position,
  122. base=self.rope_theta,
  123. )
  124. self.attn = Attention(self.num_heads,
  125. self.head_dim,
  126. self.scaling,
  127. num_kv_heads=self.num_kv_heads,
  128. sliding_window=self.sliding_window)
  129. def forward(
  130. self,
  131. positions: torch.Tensor,
  132. hidden_states: torch.Tensor,
  133. kv_cache: torch.Tensor,
  134. attn_metadata: AttentionMetadata,
  135. ) -> torch.Tensor:
  136. qkv, _ = self.qkv_proj(hidden_states)
  137. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  138. q, k = self.rotary_emb(positions, q, k)
  139. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  140. output, _ = self.o_proj(attn_output)
  141. return output
  142. class Qwen2DecoderLayer(nn.Module):
  143. def __init__(
  144. self,
  145. config: Qwen2Config,
  146. layer_idx: int,
  147. quant_config: Optional[QuantizationConfig] = None,
  148. ) -> None:
  149. super().__init__()
  150. self.hidden_size = config.hidden_size
  151. # Requires transformers > 4.32.0
  152. rope_theta = getattr(config, "rope_theta", 1000000)
  153. use_sliding_window = (config.use_sliding_window
  154. and layer_idx < config.max_window_layers)
  155. self.self_attn = Qwen2Attention(
  156. hidden_size=self.hidden_size,
  157. num_heads=config.num_attention_heads,
  158. max_position=config.max_position_embeddings,
  159. num_kv_heads=config.num_key_value_heads,
  160. rope_theta=rope_theta,
  161. use_sliding_window=use_sliding_window,
  162. quant_config=quant_config,
  163. sliding_window=config.sliding_window)
  164. self.mlp = Qwen2MLP(
  165. hidden_size=self.hidden_size,
  166. intermediate_size=config.intermediate_size,
  167. hidden_act=config.hidden_act,
  168. quant_config=quant_config,
  169. )
  170. self.input_layernorm = RMSNorm(config.hidden_size,
  171. eps=config.rms_norm_eps)
  172. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  173. eps=config.rms_norm_eps)
  174. def forward(
  175. self,
  176. positions: torch.Tensor,
  177. hidden_states: torch.Tensor,
  178. kv_cache: torch.Tensor,
  179. attn_metadata: AttentionMetadata,
  180. residual: Optional[torch.Tensor],
  181. ) -> Tuple[torch.Tensor, torch.Tensor]:
  182. # Self Attention
  183. if residual is None:
  184. residual = hidden_states
  185. hidden_states = self.input_layernorm(hidden_states)
  186. else:
  187. hidden_states, residual = self.input_layernorm(
  188. hidden_states, residual)
  189. hidden_states = self.self_attn(
  190. positions=positions,
  191. hidden_states=hidden_states,
  192. kv_cache=kv_cache,
  193. attn_metadata=attn_metadata,
  194. )
  195. # Fully Connected
  196. hidden_states, residual = self.post_attention_layernorm(
  197. hidden_states, residual)
  198. hidden_states = self.mlp(hidden_states)
  199. return hidden_states, residual
  200. class Qwen2Model(nn.Module):
  201. def __init__(
  202. self,
  203. config: Qwen2Config,
  204. quant_config: Optional[QuantizationConfig] = None,
  205. ) -> None:
  206. super().__init__()
  207. self.config = config
  208. self.padding_idx = config.pad_token_id
  209. self.vocab_size = config.vocab_size
  210. self.embed_tokens = VocabParallelEmbedding(
  211. config.vocab_size,
  212. config.hidden_size,
  213. )
  214. self.layers = nn.ModuleList([
  215. Qwen2DecoderLayer(config, layer_idx, quant_config)
  216. for layer_idx in range(config.num_hidden_layers)
  217. ])
  218. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  219. def forward(
  220. self,
  221. input_ids: torch.Tensor,
  222. positions: torch.Tensor,
  223. kv_caches: List[torch.Tensor],
  224. attn_metadata: AttentionMetadata,
  225. ) -> torch.Tensor:
  226. hidden_states = self.embed_tokens(input_ids)
  227. residual = None
  228. for i in range(len(self.layers)):
  229. layer = self.layers[i]
  230. hidden_states, residual = layer(
  231. positions,
  232. hidden_states,
  233. kv_caches[i],
  234. attn_metadata,
  235. residual,
  236. )
  237. hidden_states, _ = self.norm(hidden_states, residual)
  238. return hidden_states
  239. class Qwen2ForCausalLM(nn.Module):
  240. packed_modules_mapping = {
  241. "qkv_proj": [
  242. "q_proj",
  243. "k_proj",
  244. "v_proj",
  245. ],
  246. "gate_up_proj": [
  247. "gate_proj",
  248. "up_proj",
  249. ],
  250. }
  251. # LoRA specific attributes
  252. supported_lora_modules = [
  253. "qkv_proj",
  254. "o_proj",
  255. "gate_up_proj",
  256. "down_proj",
  257. ]
  258. embedding_modules = {}
  259. embedding_padding_modules = []
  260. def __init__(
  261. self,
  262. config: Qwen2Config,
  263. quant_config: Optional[QuantizationConfig] = None,
  264. lora_config: Optional[LoRAConfig] = None,
  265. ) -> None:
  266. del lora_config
  267. super().__init__()
  268. self.config = config
  269. self.quant_config = quant_config
  270. self.model = Qwen2Model(config, quant_config)
  271. if config.tie_word_embeddings:
  272. self.lm_head_weight = self.model.embed_tokens.weight
  273. else:
  274. self.lm_head = ParallelLMHead(config.vocab_size,
  275. config.hidden_size)
  276. self.lm_head_weight = self.lm_head.weight
  277. self.logits_processor = LogitsProcessor(config.vocab_size)
  278. self.sampler = Sampler()
  279. def forward(
  280. self,
  281. input_ids: torch.Tensor,
  282. positions: torch.Tensor,
  283. kv_caches: List[torch.Tensor],
  284. attn_metadata: AttentionMetadata,
  285. ) -> torch.Tensor:
  286. hidden_states = self.model(input_ids, positions, kv_caches,
  287. attn_metadata)
  288. return hidden_states
  289. def compute_logits(self, hidden_states: torch.Tensor,
  290. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  291. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  292. sampling_metadata)
  293. return logits
  294. def sample(
  295. self,
  296. logits: torch.Tensor,
  297. sampling_metadata: SamplingMetadata,
  298. ) -> Optional[SamplerOutput]:
  299. next_tokens = self.sampler(logits, sampling_metadata)
  300. return next_tokens
  301. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  302. stacked_params_mapping = [
  303. # (param_name, shard_name, shard_id)
  304. ("qkv_proj", "q_proj", "q"),
  305. ("qkv_proj", "k_proj", "k"),
  306. ("qkv_proj", "v_proj", "v"),
  307. ("gate_up_proj", "gate_proj", 0),
  308. ("gate_up_proj", "up_proj", 1),
  309. ]
  310. params_dict = dict(self.named_parameters(remove_duplicate=False))
  311. for name, loaded_weight in weights:
  312. if "rotary_emb.inv_freq" in name:
  313. continue
  314. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  315. continue
  316. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  317. if weight_name not in name:
  318. continue
  319. name = name.replace(weight_name, param_name)
  320. # Skip loading extra bias for GPTQ models.
  321. if name.endswith(".bias") and name not in params_dict:
  322. continue
  323. param = params_dict[name]
  324. weight_loader = param.weight_loader
  325. weight_loader(param, loaded_weight, shard_id)
  326. break
  327. else:
  328. # Skip loading extra bias for GPTQ models.
  329. if name.endswith(".bias") and name not in params_dict:
  330. continue
  331. param = params_dict[name]
  332. weight_loader = getattr(param, "weight_loader",
  333. default_weight_loader)
  334. weight_loader(param, loaded_weight)