qwen2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import Qwen2Config
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig, LoRAConfig
  31. from aphrodite.common.sequence import SamplerOutput
  32. from aphrodite.common.utils import print_warning_once
  33. from aphrodite.distributed import get_tensor_model_parallel_world_size
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.layernorm import RMSNorm
  36. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  37. QKVParallelLinear,
  38. RowParallelLinear)
  39. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.sampler import Sampler
  42. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  43. ParallelLMHead, VocabParallelEmbedding)
  44. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  45. from aphrodite.modeling.models.interfaces import SupportsLoRA
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.quantization.base_config import QuantizationConfig
  48. class Qwen2MLP(nn.Module):
  49. def __init__(
  50. self,
  51. hidden_size: int,
  52. intermediate_size: int,
  53. hidden_act: str,
  54. quant_config: Optional[QuantizationConfig] = None,
  55. ) -> None:
  56. super().__init__()
  57. self.gate_up_proj = MergedColumnParallelLinear(
  58. hidden_size, [intermediate_size] * 2,
  59. bias=False,
  60. quant_config=quant_config)
  61. self.down_proj = RowParallelLinear(intermediate_size,
  62. hidden_size,
  63. bias=False,
  64. quant_config=quant_config)
  65. if hidden_act != "silu":
  66. raise ValueError(f"Unsupported activation: {hidden_act}. "
  67. "Only silu is supported for now.")
  68. self.act_fn = SiluAndMul()
  69. def forward(self, x):
  70. gate_up, _ = self.gate_up_proj(x)
  71. x = self.act_fn(gate_up)
  72. x, _ = self.down_proj(x)
  73. return x
  74. class Qwen2Attention(nn.Module):
  75. def __init__(self,
  76. hidden_size: int,
  77. num_heads: int,
  78. num_kv_heads: int,
  79. max_position: int = 4096 * 32,
  80. rope_theta: float = 10000,
  81. cache_config: Optional[CacheConfig] = None,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. rope_scaling: Optional[Tuple] = None) -> None:
  84. super().__init__()
  85. self.hidden_size = hidden_size
  86. tp_size = get_tensor_model_parallel_world_size()
  87. self.total_num_heads = num_heads
  88. assert self.total_num_heads % tp_size == 0
  89. self.num_heads = self.total_num_heads // tp_size
  90. self.total_num_kv_heads = num_kv_heads
  91. if self.total_num_kv_heads >= tp_size:
  92. # Number of KV heads is greater than TP size, so we partition
  93. # the KV heads across multiple tensor parallel GPUs.
  94. assert self.total_num_kv_heads % tp_size == 0
  95. else:
  96. # Number of KV heads is less than TP size, so we replicate
  97. # the KV heads across multiple tensor parallel GPUs.
  98. assert tp_size % self.total_num_kv_heads == 0
  99. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  100. self.head_dim = hidden_size // self.total_num_heads
  101. self.q_size = self.num_heads * self.head_dim
  102. self.kv_size = self.num_kv_heads * self.head_dim
  103. self.scaling = self.head_dim**-0.5
  104. self.rope_theta = rope_theta
  105. self.qkv_proj = QKVParallelLinear(
  106. hidden_size,
  107. self.head_dim,
  108. self.total_num_heads,
  109. self.total_num_kv_heads,
  110. bias=True,
  111. quant_config=quant_config,
  112. )
  113. self.o_proj = RowParallelLinear(
  114. self.total_num_heads * self.head_dim,
  115. hidden_size,
  116. bias=False,
  117. quant_config=quant_config,
  118. )
  119. self.rotary_emb = get_rope(
  120. self.head_dim,
  121. rotary_dim=self.head_dim,
  122. max_position=max_position,
  123. base=self.rope_theta,
  124. rope_scaling=rope_scaling,
  125. )
  126. self.attn = Attention(self.num_heads,
  127. self.head_dim,
  128. self.scaling,
  129. num_kv_heads=self.num_kv_heads,
  130. cache_config=cache_config,
  131. quant_config=quant_config)
  132. def forward(
  133. self,
  134. positions: torch.Tensor,
  135. hidden_states: torch.Tensor,
  136. kv_cache: torch.Tensor,
  137. attn_metadata: AttentionMetadata,
  138. ) -> torch.Tensor:
  139. qkv, _ = self.qkv_proj(hidden_states)
  140. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  141. q, k = self.rotary_emb(positions, q, k)
  142. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  143. output, _ = self.o_proj(attn_output)
  144. return output
  145. class Qwen2DecoderLayer(nn.Module):
  146. def __init__(
  147. self,
  148. config: Qwen2Config,
  149. cache_config: Optional[CacheConfig] = None,
  150. quant_config: Optional[QuantizationConfig] = None,
  151. ) -> None:
  152. super().__init__()
  153. self.hidden_size = config.hidden_size
  154. # Requires transformers > 4.32.0
  155. rope_theta = getattr(config, "rope_theta", 1000000)
  156. rope_scaling = getattr(config, "rope_scaling", None)
  157. self.self_attn = Qwen2Attention(
  158. hidden_size=self.hidden_size,
  159. num_heads=config.num_attention_heads,
  160. max_position=config.max_position_embeddings,
  161. num_kv_heads=config.num_key_value_heads,
  162. rope_theta=rope_theta,
  163. cache_config=cache_config,
  164. quant_config=quant_config,
  165. rope_scaling=rope_scaling)
  166. self.mlp = Qwen2MLP(
  167. hidden_size=self.hidden_size,
  168. intermediate_size=config.intermediate_size,
  169. hidden_act=config.hidden_act,
  170. quant_config=quant_config,
  171. )
  172. self.input_layernorm = RMSNorm(config.hidden_size,
  173. eps=config.rms_norm_eps)
  174. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  175. eps=config.rms_norm_eps)
  176. def forward(
  177. self,
  178. positions: torch.Tensor,
  179. hidden_states: torch.Tensor,
  180. kv_cache: torch.Tensor,
  181. attn_metadata: AttentionMetadata,
  182. residual: Optional[torch.Tensor],
  183. ) -> Tuple[torch.Tensor, torch.Tensor]:
  184. # Self Attention
  185. if residual is None:
  186. residual = hidden_states
  187. hidden_states = self.input_layernorm(hidden_states)
  188. else:
  189. hidden_states, residual = self.input_layernorm(
  190. hidden_states, residual)
  191. hidden_states = self.self_attn(
  192. positions=positions,
  193. hidden_states=hidden_states,
  194. kv_cache=kv_cache,
  195. attn_metadata=attn_metadata,
  196. )
  197. # Fully Connected
  198. hidden_states, residual = self.post_attention_layernorm(
  199. hidden_states, residual)
  200. hidden_states = self.mlp(hidden_states)
  201. return hidden_states, residual
  202. class Qwen2Model(nn.Module):
  203. def __init__(
  204. self,
  205. config: Qwen2Config,
  206. cache_config: Optional[CacheConfig] = None,
  207. quant_config: Optional[QuantizationConfig] = None,
  208. ) -> None:
  209. super().__init__()
  210. self.config = config
  211. self.padding_idx = config.pad_token_id
  212. self.vocab_size = config.vocab_size
  213. self.embed_tokens = VocabParallelEmbedding(
  214. config.vocab_size,
  215. config.hidden_size,
  216. )
  217. self.layers = nn.ModuleList([
  218. Qwen2DecoderLayer(config, cache_config, quant_config)
  219. for _ in range(config.num_hidden_layers)
  220. ])
  221. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  222. def forward(
  223. self,
  224. input_ids: torch.Tensor,
  225. positions: torch.Tensor,
  226. kv_caches: List[torch.Tensor],
  227. attn_metadata: AttentionMetadata,
  228. ) -> torch.Tensor:
  229. hidden_states = self.embed_tokens(input_ids)
  230. residual = None
  231. for i in range(len(self.layers)):
  232. layer = self.layers[i]
  233. hidden_states, residual = layer(
  234. positions,
  235. hidden_states,
  236. kv_caches[i],
  237. attn_metadata,
  238. residual,
  239. )
  240. hidden_states, _ = self.norm(hidden_states, residual)
  241. return hidden_states
  242. class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
  243. packed_modules_mapping = {
  244. "qkv_proj": [
  245. "q_proj",
  246. "k_proj",
  247. "v_proj",
  248. ],
  249. "gate_up_proj": [
  250. "gate_proj",
  251. "up_proj",
  252. ],
  253. }
  254. # LoRA specific attributes
  255. supported_lora_modules = [
  256. "qkv_proj",
  257. "o_proj",
  258. "gate_up_proj",
  259. "down_proj",
  260. ]
  261. embedding_modules = {}
  262. embedding_padding_modules = []
  263. def __init__(
  264. self,
  265. config: Qwen2Config,
  266. cache_config: Optional[CacheConfig] = None,
  267. quant_config: Optional[QuantizationConfig] = None,
  268. lora_config: Optional[LoRAConfig] = None,
  269. ) -> None:
  270. # TODO (: see if this can be moved out
  271. if (cache_config.sliding_window is not None
  272. and hasattr(config, "max_window_layers")):
  273. raise ValueError(
  274. "Sliding window for some but all layers is not "
  275. "supported. This model uses sliding window "
  276. "but `max_window_layers` = "
  277. f"{config.max_window_layers} is less than "
  278. "`num_hidden_layers` = "
  279. f"{config.num_hidden_layers}. Please open an issue"
  280. " to discuss this feature.")
  281. super().__init__()
  282. self.config = config
  283. self.lora_config = lora_config
  284. self.quant_config = quant_config
  285. self.model = Qwen2Model(config, cache_config, quant_config)
  286. if config.tie_word_embeddings:
  287. self.lm_head_weight = self.model.embed_tokens.weight
  288. else:
  289. self.lm_head = ParallelLMHead(config.vocab_size,
  290. config.hidden_size)
  291. self.lm_head_weight = self.lm_head.weight
  292. self.logits_processor = LogitsProcessor(config.vocab_size)
  293. self.sampler = Sampler()
  294. def forward(
  295. self,
  296. input_ids: torch.Tensor,
  297. positions: torch.Tensor,
  298. kv_caches: List[torch.Tensor],
  299. attn_metadata: AttentionMetadata,
  300. ) -> torch.Tensor:
  301. hidden_states = self.model(input_ids, positions, kv_caches,
  302. attn_metadata)
  303. return hidden_states
  304. def compute_logits(self, hidden_states: torch.Tensor,
  305. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  306. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  307. sampling_metadata)
  308. return logits
  309. def sample(
  310. self,
  311. logits: torch.Tensor,
  312. sampling_metadata: SamplingMetadata,
  313. ) -> Optional[SamplerOutput]:
  314. next_tokens = self.sampler(logits, sampling_metadata)
  315. return next_tokens
  316. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  317. stacked_params_mapping = [
  318. # (param_name, shard_name, shard_id)
  319. ("qkv_proj", "q_proj", "q"),
  320. ("qkv_proj", "k_proj", "k"),
  321. ("qkv_proj", "v_proj", "v"),
  322. ("gate_up_proj", "gate_proj", 0),
  323. ("gate_up_proj", "up_proj", 1),
  324. ]
  325. params_dict = dict(self.named_parameters(remove_duplicate=False))
  326. for name, loaded_weight in weights:
  327. if "rotary_emb.inv_freq" in name:
  328. continue
  329. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  330. continue
  331. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  332. if weight_name not in name:
  333. continue
  334. name = name.replace(weight_name, param_name)
  335. # Skip loading extra bias for GPTQ models.
  336. if name.endswith(".bias") and name not in params_dict:
  337. continue
  338. param = params_dict[name]
  339. weight_loader = param.weight_loader
  340. weight_loader(param, loaded_weight, shard_id)
  341. break
  342. else:
  343. # Skip loading extra bias for GPTQ models.
  344. if name.endswith(".bias") and name not in params_dict:
  345. continue
  346. # Remapping the name of FP8 kv-scale.
  347. if name.endswith("kv_scale"):
  348. remapped_kv_scale_name = name.replace(
  349. ".kv_scale", ".attn.kv_scale")
  350. if remapped_kv_scale_name not in params_dict:
  351. print_warning_once(
  352. f"Found kv scale in the checkpoint (e.g. {name}), "
  353. "but not found the expected name in the model "
  354. f"(e.g. {remapped_kv_scale_name}). kv-scale is "
  355. "not loaded.")
  356. continue
  357. else:
  358. name = remapped_kv_scale_name
  359. param = params_dict[name]
  360. weight_loader = getattr(param, "weight_loader",
  361. default_weight_loader)
  362. weight_loader(param, loaded_weight)