qwen2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2024 The Qwen team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  26. from typing import List, Optional, Tuple
  27. import torch
  28. from torch import nn
  29. from transformers import Qwen2Config
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (
  34. LinearMethodBase,
  35. ColumnParallelLinear,
  36. MergedColumnParallelLinear,
  37. QKVParallelLinear,
  38. RowParallelLinear,
  39. )
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.sampler import Sampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. VocabParallelEmbedding,
  45. ParallelLMHead,
  46. )
  47. from aphrodite.distributed import (
  48. get_tensor_model_parallel_world_size, )
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.modeling.hf_downloader import (
  51. default_weight_loader,
  52. hf_model_weights_iterator,
  53. )
  54. from aphrodite.common.sequence import SamplerOutput
  55. class Qwen2MLP(nn.Module):
  56. def __init__(
  57. self,
  58. hidden_size: int,
  59. intermediate_size: int,
  60. hidden_act: str,
  61. linear_method: Optional[LinearMethodBase] = None,
  62. ) -> None:
  63. super().__init__()
  64. if (linear_method is not None
  65. and not linear_method.quant_config.merge_weight()):
  66. self.merge_weight = False
  67. self.gate_proj = ColumnParallelLinear(
  68. hidden_size,
  69. intermediate_size,
  70. bias=False,
  71. linear_method=linear_method,
  72. )
  73. self.up_proj = ColumnParallelLinear(
  74. hidden_size,
  75. intermediate_size,
  76. bias=False,
  77. linear_method=linear_method,
  78. )
  79. else:
  80. self.merge_weight = True
  81. self.gate_up_proj = MergedColumnParallelLinear(
  82. hidden_size,
  83. [intermediate_size] * 2,
  84. bias=False,
  85. linear_method=linear_method,
  86. )
  87. self.down_proj = RowParallelLinear(
  88. intermediate_size,
  89. hidden_size,
  90. bias=False,
  91. linear_method=linear_method,
  92. )
  93. if hidden_act != "silu":
  94. raise ValueError(f"Unsupported activation: {hidden_act}. "
  95. "Only silu is supported for now.")
  96. self.act_fn = SiluAndMul()
  97. def forward(self, x):
  98. if self.merge_weight:
  99. gate_up, _ = self.gate_up_proj(x)
  100. else:
  101. up, _ = self.up_proj(x)
  102. gate, _ = self.gate_proj(x)
  103. gate_up = torch.cat([gate, up], dim=-1)
  104. x = self.act_fn(gate_up)
  105. x, _ = self.down_proj(x)
  106. return x
  107. class Qwen2Attention(nn.Module):
  108. def __init__(
  109. self,
  110. hidden_size: int,
  111. num_heads: int,
  112. num_kv_heads: int,
  113. max_position: int = 4096 * 32,
  114. rope_theta: float = 10000,
  115. use_sliding_window: bool = False,
  116. linear_method: Optional[LinearMethodBase] = None,
  117. sliding_window: Optional[int] = None,
  118. ) -> None:
  119. super().__init__()
  120. self.hidden_size = hidden_size
  121. tp_size = get_tensor_model_parallel_world_size()
  122. self.total_num_heads = num_heads
  123. assert self.total_num_heads % tp_size == 0
  124. self.num_heads = self.total_num_heads // tp_size
  125. self.total_num_kv_heads = num_kv_heads
  126. if self.total_num_kv_heads >= tp_size:
  127. # Number of KV heads is greater than TP size, so we partition
  128. # the KV heads across multiple tensor parallel GPUs.
  129. assert self.total_num_kv_heads % tp_size == 0
  130. else:
  131. # Number of KV heads is less than TP size, so we replicate
  132. # the KV heads across multiple tensor parallel GPUs.
  133. assert tp_size % self.total_num_kv_heads == 0
  134. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  135. self.head_dim = hidden_size // self.total_num_heads
  136. self.q_size = self.num_heads * self.head_dim
  137. self.kv_size = self.num_kv_heads * self.head_dim
  138. self.scaling = self.head_dim**-0.5
  139. self.rope_theta = rope_theta
  140. self.sliding_window = sliding_window if use_sliding_window else None
  141. if (linear_method is not None
  142. and not linear_method.quant_config.merge_weight()):
  143. self.merge_weight = False
  144. self.q_proj = ColumnParallelLinear(hidden_size,
  145. self.q_size,
  146. bias=True,
  147. linear_method=linear_method)
  148. self.k_proj = ColumnParallelLinear(
  149. hidden_size,
  150. self.kv_size,
  151. bias=True,
  152. linear_method=linear_method,
  153. )
  154. self.v_proj = ColumnParallelLinear(
  155. hidden_size,
  156. self.kv_size,
  157. bias=True,
  158. linear_method=linear_method,
  159. )
  160. else:
  161. self.merge_weight = True
  162. self.qkv_proj = QKVParallelLinear(
  163. hidden_size,
  164. self.head_dim,
  165. self.total_num_heads,
  166. self.total_num_kv_heads,
  167. bias=True,
  168. linear_method=linear_method,
  169. )
  170. self.o_proj = RowParallelLinear(
  171. self.total_num_heads * self.head_dim,
  172. hidden_size,
  173. bias=False,
  174. linear_method=linear_method,
  175. )
  176. self.rotary_emb = get_rope(
  177. self.head_dim,
  178. rotary_dim=self.head_dim,
  179. max_position=max_position,
  180. base=self.rope_theta,
  181. )
  182. self.attn = Attention(
  183. self.num_heads,
  184. self.head_dim,
  185. self.scaling,
  186. num_kv_heads=self.num_kv_heads,
  187. sliding_window=self.sliding_window,
  188. )
  189. def forward(
  190. self,
  191. positions: torch.Tensor,
  192. hidden_states: torch.Tensor,
  193. kv_cache: torch.Tensor,
  194. attn_metadata: AttentionMetadata,
  195. ) -> torch.Tensor:
  196. if self.merge_weight:
  197. qkv, _ = self.qkv_proj(hidden_states)
  198. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  199. dim=-1)
  200. else:
  201. q, _ = self.q_proj(hidden_states)
  202. k, _ = self.k_proj(hidden_states)
  203. v, _ = self.v_proj(hidden_states)
  204. q, k = self.rotary_emb(positions, q, k)
  205. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  206. output, _ = self.o_proj(attn_output)
  207. return output
  208. class Qwen2DecoderLayer(nn.Module):
  209. def __init__(
  210. self,
  211. config: Qwen2Config,
  212. layer_idx: int,
  213. linear_method: Optional[LinearMethodBase] = None,
  214. ) -> None:
  215. super().__init__()
  216. self.hidden_size = config.hidden_size
  217. # Requires transformers > 4.32.0
  218. rope_theta = getattr(config, "rope_theta", 1000000)
  219. use_sliding_window = (config.use_sliding_window
  220. and layer_idx < config.max_window_layers)
  221. self.self_attn = Qwen2Attention(
  222. hidden_size=self.hidden_size,
  223. num_heads=config.num_attention_heads,
  224. max_position=config.max_position_embeddings,
  225. num_kv_heads=config.num_key_value_heads,
  226. rope_theta=rope_theta,
  227. use_sliding_window=use_sliding_window,
  228. linear_method=linear_method,
  229. sliding_window=config.sliding_window,
  230. )
  231. self.mlp = Qwen2MLP(
  232. hidden_size=self.hidden_size,
  233. intermediate_size=config.intermediate_size,
  234. hidden_act=config.hidden_act,
  235. linear_method=linear_method,
  236. )
  237. self.input_layernorm = RMSNorm(config.hidden_size,
  238. eps=config.rms_norm_eps)
  239. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  240. eps=config.rms_norm_eps)
  241. def forward(
  242. self,
  243. positions: torch.Tensor,
  244. hidden_states: torch.Tensor,
  245. kv_cache: torch.Tensor,
  246. attn_metadata: AttentionMetadata,
  247. residual: Optional[torch.Tensor],
  248. ) -> Tuple[torch.Tensor, torch.Tensor]:
  249. # Self Attention
  250. if residual is None:
  251. residual = hidden_states
  252. hidden_states = self.input_layernorm(hidden_states)
  253. else:
  254. hidden_states, residual = self.input_layernorm(
  255. hidden_states, residual)
  256. hidden_states = self.self_attn(
  257. positions=positions,
  258. hidden_states=hidden_states,
  259. kv_cache=kv_cache,
  260. attn_metadata=attn_metadata,
  261. )
  262. # Fully Connected
  263. hidden_states, residual = self.post_attention_layernorm(
  264. hidden_states, residual)
  265. hidden_states = self.mlp(hidden_states)
  266. return hidden_states, residual
  267. class Qwen2Model(nn.Module):
  268. def __init__(
  269. self,
  270. config: Qwen2Config,
  271. linear_method: Optional[LinearMethodBase] = None,
  272. ) -> None:
  273. super().__init__()
  274. self.config = config
  275. self.padding_idx = config.pad_token_id
  276. self.vocab_size = config.vocab_size
  277. self.embed_tokens = VocabParallelEmbedding(
  278. config.vocab_size,
  279. config.hidden_size,
  280. linear_method=linear_method,
  281. )
  282. self.layers = nn.ModuleList([
  283. Qwen2DecoderLayer(config, layer_idx, linear_method)
  284. for layer_idx in range(config.num_hidden_layers)
  285. ])
  286. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  287. def forward(
  288. self,
  289. input_ids: torch.Tensor,
  290. positions: torch.Tensor,
  291. kv_caches: List[torch.Tensor],
  292. attn_metadata: AttentionMetadata,
  293. ) -> torch.Tensor:
  294. hidden_states = self.embed_tokens(input_ids)
  295. residual = None
  296. for i in range(len(self.layers)):
  297. layer = self.layers[i]
  298. hidden_states, residual = layer(
  299. positions,
  300. hidden_states,
  301. kv_caches[i],
  302. attn_metadata,
  303. residual,
  304. )
  305. hidden_states, _ = self.norm(hidden_states, residual)
  306. return hidden_states
  307. class Qwen2ForCausalLM(nn.Module):
  308. def __init__(
  309. self,
  310. config: Qwen2Config,
  311. linear_method: Optional[LinearMethodBase] = None,
  312. ) -> None:
  313. super().__init__()
  314. self.config = config
  315. self.linear_method = linear_method
  316. self.model = Qwen2Model(config, linear_method)
  317. self.lm_head = ParallelLMHead(
  318. config.vocab_size,
  319. config.hidden_size,
  320. linear_method=linear_method,
  321. )
  322. self.logits_processor = LogitsProcessor(config.vocab_size)
  323. self.sampler = Sampler()
  324. def forward(
  325. self,
  326. input_ids: torch.Tensor,
  327. positions: torch.Tensor,
  328. kv_caches: List[torch.Tensor],
  329. attn_metadata: AttentionMetadata,
  330. ) -> torch.Tensor:
  331. hidden_states = self.model(input_ids, positions, kv_caches,
  332. attn_metadata)
  333. return hidden_states
  334. def compute_logits(self, hidden_states: torch.Tensor,
  335. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  336. logits = self.logits_processor(self.lm_head, hidden_states,
  337. sampling_metadata)
  338. return logits
  339. def sample(
  340. self,
  341. logits: torch.Tensor,
  342. sampling_metadata: SamplingMetadata,
  343. ) -> Optional[SamplerOutput]:
  344. next_tokens = self.sampler(logits, sampling_metadata)
  345. return next_tokens
  346. def load_weights(
  347. self,
  348. model_name_or_path: str,
  349. cache_dir: Optional[str] = None,
  350. load_format: str = "auto",
  351. revision: Optional[str] = None,
  352. ):
  353. stacked_params_mapping = [
  354. # (param_name, shard_name, shard_id)
  355. ("qkv_proj", "q_proj", "q"),
  356. ("qkv_proj", "k_proj", "k"),
  357. ("qkv_proj", "v_proj", "v"),
  358. ("gate_up_proj", "gate_proj", 0),
  359. ("gate_up_proj", "up_proj", 1),
  360. ]
  361. if (self.linear_method is not None
  362. and not self.linear_method.quant_config.merge_weight()):
  363. stacked_params_mapping = []
  364. params_dict = dict(self.named_parameters(remove_duplicate=False))
  365. for name, loaded_weight in hf_model_weights_iterator(
  366. model_name_or_path, cache_dir, load_format, revision):
  367. if "rotary_emb.inv_freq" in name:
  368. continue
  369. for param_name, weight_name, shard_id in stacked_params_mapping:
  370. if weight_name not in name:
  371. continue
  372. name = name.replace(weight_name, param_name)
  373. # Skip loading extra bias for GPTQ models.
  374. if name.endswith(".bias") and name not in params_dict:
  375. continue
  376. param = params_dict[name]
  377. weight_loader = param.weight_loader
  378. weight_loader(param, loaded_weight, shard_id)
  379. break
  380. else:
  381. # Skip loading extra bias for GPTQ models.
  382. if name.endswith(".bias") and name not in params_dict:
  383. continue
  384. param = params_dict[name]
  385. weight_loader = getattr(param, "weight_loader",
  386. default_weight_loader)
  387. weight_loader(param, loaded_weight)