qwen2.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2024 The Qwen team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import Qwen2Config
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig, LoRAConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  33. get_tensor_model_parallel_world_size,
  34. get_tensor_model_parallel_rank)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  38. QKVParallelLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. ParallelLMHead, VocabParallelEmbedding)
  45. from aphrodite.modeling.model_loader.weight_utils import (
  46. default_weight_loader, maybe_remap_kv_scale_name)
  47. from aphrodite.modeling.models.interfaces import SupportsLoRA
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class Qwen2MLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. ) -> None:
  58. super().__init__()
  59. self.gate_up_proj = MergedColumnParallelLinear(
  60. hidden_size, [intermediate_size] * 2,
  61. bias=False,
  62. quant_config=quant_config)
  63. self.down_proj = RowParallelLinear(intermediate_size,
  64. hidden_size,
  65. bias=False,
  66. quant_config=quant_config)
  67. if hidden_act != "silu":
  68. raise ValueError(f"Unsupported activation: {hidden_act}. "
  69. "Only silu is supported for now.")
  70. self.act_fn = SiluAndMul()
  71. def forward(self, x):
  72. gate_up, _ = self.gate_up_proj(x)
  73. x = self.act_fn(gate_up)
  74. x, _ = self.down_proj(x)
  75. return x
  76. class Qwen2Attention(nn.Module):
  77. def __init__(self,
  78. hidden_size: int,
  79. num_heads: int,
  80. num_kv_heads: int,
  81. max_position: int = 4096 * 32,
  82. rope_theta: float = 10000,
  83. cache_config: Optional[CacheConfig] = None,
  84. quant_config: Optional[QuantizationConfig] = None,
  85. rope_scaling: Optional[Tuple] = None) -> None:
  86. super().__init__()
  87. self.hidden_size = hidden_size
  88. tp_size = get_tensor_model_parallel_world_size()
  89. tp_rank = get_tensor_model_parallel_rank()
  90. self.total_num_heads = num_heads
  91. self.total_num_kv_heads = num_kv_heads
  92. self.num_kv_heads = max(
  93. 1,
  94. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  95. tp_rank, tp_size))
  96. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  97. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  98. self.head_dim = hidden_size // self.total_num_heads
  99. self.q_size = self.num_heads * self.head_dim
  100. self.kv_size = self.num_kv_heads * self.head_dim
  101. self.scaling = self.head_dim**-0.5
  102. self.rope_theta = rope_theta
  103. self.qkv_proj = QKVParallelLinear(
  104. hidden_size,
  105. self.head_dim,
  106. self.total_num_heads,
  107. self.total_num_kv_heads,
  108. bias=True,
  109. quant_config=quant_config,
  110. )
  111. self.o_proj = RowParallelLinear(
  112. self.total_num_heads * self.head_dim,
  113. hidden_size,
  114. bias=False,
  115. quant_config=quant_config,
  116. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  117. )
  118. self.rotary_emb = get_rope(
  119. self.head_dim,
  120. rotary_dim=self.head_dim,
  121. max_position=max_position,
  122. base=self.rope_theta,
  123. rope_scaling=rope_scaling,
  124. )
  125. self.attn = Attention(self.num_heads,
  126. self.head_dim,
  127. self.scaling,
  128. num_kv_heads=self.num_kv_heads,
  129. cache_config=cache_config,
  130. quant_config=quant_config)
  131. def forward(
  132. self,
  133. positions: torch.Tensor,
  134. hidden_states: torch.Tensor,
  135. kv_cache: torch.Tensor,
  136. attn_metadata: AttentionMetadata,
  137. ) -> torch.Tensor:
  138. qkv, _ = self.qkv_proj(hidden_states)
  139. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  140. q, k = self.rotary_emb(positions, q, k)
  141. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  142. output, _ = self.o_proj(attn_output)
  143. return output
  144. class Qwen2DecoderLayer(nn.Module):
  145. def __init__(
  146. self,
  147. config: Qwen2Config,
  148. cache_config: Optional[CacheConfig] = None,
  149. quant_config: Optional[QuantizationConfig] = None,
  150. ) -> None:
  151. super().__init__()
  152. self.hidden_size = config.hidden_size
  153. # Requires transformers > 4.32.0
  154. rope_theta = getattr(config, "rope_theta", 1000000)
  155. rope_scaling = getattr(config, "rope_scaling", None)
  156. self.self_attn = Qwen2Attention(
  157. hidden_size=self.hidden_size,
  158. num_heads=config.num_attention_heads,
  159. max_position=config.max_position_embeddings,
  160. num_kv_heads=config.num_key_value_heads,
  161. rope_theta=rope_theta,
  162. cache_config=cache_config,
  163. quant_config=quant_config,
  164. rope_scaling=rope_scaling)
  165. self.mlp = Qwen2MLP(
  166. hidden_size=self.hidden_size,
  167. intermediate_size=config.intermediate_size,
  168. hidden_act=config.hidden_act,
  169. quant_config=quant_config,
  170. )
  171. self.input_layernorm = RMSNorm(config.hidden_size,
  172. eps=config.rms_norm_eps)
  173. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  174. eps=config.rms_norm_eps)
  175. def forward(
  176. self,
  177. positions: torch.Tensor,
  178. hidden_states: torch.Tensor,
  179. kv_cache: torch.Tensor,
  180. attn_metadata: AttentionMetadata,
  181. residual: Optional[torch.Tensor],
  182. ) -> Tuple[torch.Tensor, torch.Tensor]:
  183. # Self Attention
  184. if residual is None:
  185. residual = hidden_states
  186. hidden_states = self.input_layernorm(hidden_states)
  187. else:
  188. hidden_states, residual = self.input_layernorm(
  189. hidden_states, residual)
  190. hidden_states = self.self_attn(
  191. positions=positions,
  192. hidden_states=hidden_states,
  193. kv_cache=kv_cache,
  194. attn_metadata=attn_metadata,
  195. )
  196. # Fully Connected
  197. hidden_states, residual = self.post_attention_layernorm(
  198. hidden_states, residual)
  199. hidden_states = self.mlp(hidden_states)
  200. return hidden_states, residual
  201. class Qwen2Model(nn.Module):
  202. def __init__(
  203. self,
  204. config: Qwen2Config,
  205. cache_config: Optional[CacheConfig] = None,
  206. quant_config: Optional[QuantizationConfig] = None,
  207. ) -> None:
  208. super().__init__()
  209. self.config = config
  210. self.padding_idx = config.pad_token_id
  211. self.vocab_size = config.vocab_size
  212. self.embed_tokens = VocabParallelEmbedding(
  213. config.vocab_size,
  214. config.hidden_size,
  215. )
  216. self.layers = nn.ModuleList([
  217. Qwen2DecoderLayer(config, cache_config, quant_config)
  218. for _ in range(config.num_hidden_layers)
  219. ])
  220. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  221. def forward(
  222. self,
  223. input_ids: torch.Tensor,
  224. positions: torch.Tensor,
  225. kv_caches: List[torch.Tensor],
  226. attn_metadata: AttentionMetadata,
  227. ) -> torch.Tensor:
  228. hidden_states = self.embed_tokens(input_ids)
  229. residual = None
  230. for i in range(len(self.layers)):
  231. layer = self.layers[i]
  232. hidden_states, residual = layer(
  233. positions,
  234. hidden_states,
  235. kv_caches[i],
  236. attn_metadata,
  237. residual,
  238. )
  239. hidden_states, _ = self.norm(hidden_states, residual)
  240. return hidden_states
  241. class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
  242. packed_modules_mapping = {
  243. "qkv_proj": [
  244. "q_proj",
  245. "k_proj",
  246. "v_proj",
  247. ],
  248. "gate_up_proj": [
  249. "gate_proj",
  250. "up_proj",
  251. ],
  252. }
  253. # LoRA specific attributes
  254. supported_lora_modules = [
  255. "qkv_proj",
  256. "o_proj",
  257. "gate_up_proj",
  258. "down_proj",
  259. ]
  260. embedding_modules = {}
  261. embedding_padding_modules = []
  262. def __init__(
  263. self,
  264. config: Qwen2Config,
  265. cache_config: Optional[CacheConfig] = None,
  266. quant_config: Optional[QuantizationConfig] = None,
  267. lora_config: Optional[LoRAConfig] = None,
  268. ) -> None:
  269. # TODO (: see if this can be moved out
  270. if (cache_config.sliding_window is not None
  271. and hasattr(config, "max_window_layers")):
  272. raise ValueError(
  273. "Sliding window for some but all layers is not "
  274. "supported. This model uses sliding window "
  275. "but `max_window_layers` = "
  276. f"{config.max_window_layers} is less than "
  277. "`num_hidden_layers` = "
  278. f"{config.num_hidden_layers}. Please open an issue"
  279. " to discuss this feature.")
  280. super().__init__()
  281. self.config = config
  282. self.lora_config = lora_config
  283. self.quant_config = quant_config
  284. self.model = Qwen2Model(config, cache_config, quant_config)
  285. if config.tie_word_embeddings:
  286. self.lm_head = self.model.embed_tokens
  287. else:
  288. self.lm_head = ParallelLMHead(config.vocab_size,
  289. config.hidden_size,
  290. quant_config=quant_config)
  291. self.logits_processor = LogitsProcessor(config.vocab_size)
  292. self.sampler = Sampler()
  293. def forward(
  294. self,
  295. input_ids: torch.Tensor,
  296. positions: torch.Tensor,
  297. kv_caches: List[torch.Tensor],
  298. attn_metadata: AttentionMetadata,
  299. intermediate_tensors: Optional[IntermediateTensors] = None,
  300. ) -> torch.Tensor:
  301. hidden_states = self.model(input_ids, positions, kv_caches,
  302. attn_metadata)
  303. return hidden_states
  304. def compute_logits(self, hidden_states: torch.Tensor,
  305. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  306. logits = self.logits_processor(self.lm_head, hidden_states,
  307. sampling_metadata)
  308. return logits
  309. def sample(
  310. self,
  311. logits: torch.Tensor,
  312. sampling_metadata: SamplingMetadata,
  313. ) -> Optional[SamplerOutput]:
  314. next_tokens = self.sampler(logits, sampling_metadata)
  315. return next_tokens
  316. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  317. stacked_params_mapping = [
  318. # (param_name, shard_name, shard_id)
  319. ("qkv_proj", "q_proj", "q"),
  320. ("qkv_proj", "k_proj", "k"),
  321. ("qkv_proj", "v_proj", "v"),
  322. ("gate_up_proj", "gate_proj", 0),
  323. ("gate_up_proj", "up_proj", 1),
  324. ]
  325. params_dict = dict(self.named_parameters(remove_duplicate=False))
  326. for name, loaded_weight in weights:
  327. if "rotary_emb.inv_freq" in name:
  328. continue
  329. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  330. continue
  331. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  332. if weight_name not in name:
  333. continue
  334. name = name.replace(weight_name, param_name)
  335. # Skip loading extra bias for GPTQ models.
  336. if name.endswith(".bias") and name not in params_dict:
  337. continue
  338. param = params_dict[name]
  339. weight_loader = param.weight_loader
  340. weight_loader(param, loaded_weight, shard_id)
  341. break
  342. else:
  343. # Skip loading extra bias for GPTQ models.
  344. if name.endswith(".bias") and name not in params_dict:
  345. continue
  346. # Remapping the name of FP8 kv-scale.
  347. name = maybe_remap_kv_scale_name(name, params_dict)
  348. if name is None:
  349. continue
  350. param = params_dict[name]
  351. weight_loader = getattr(param, "weight_loader",
  352. default_weight_loader)
  353. weight_loader(param, loaded_weight)