qwen2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2024 The Qwen team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  26. from typing import List, Optional, Tuple
  27. import torch
  28. from torch import nn
  29. from transformers import Qwen2Config
  30. from aphrodite.modeling.metadata import InputMetadata
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.attention import PagedAttention
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (
  35. LinearMethodBase,
  36. ColumnParallelLinear,
  37. MergedColumnParallelLinear,
  38. QKVParallelLinear,
  39. RowParallelLinear,
  40. )
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. VocabParallelEmbedding,
  45. ParallelLMHead,
  46. )
  47. from aphrodite.modeling.megatron.parallel_state import (
  48. get_tensor_model_parallel_world_size, )
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.modeling.hf_downloader import (
  51. default_weight_loader,
  52. hf_model_weights_iterator,
  53. )
  54. from aphrodite.common.sequence import SamplerOutput
# Per-layer KV cache: a (key_cache, value_cache) pair of tensors; the
# attention layer unpacks it as ``k_cache, v_cache = kv_cache``.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  56. class Qwen2MLP(nn.Module):
  57. def __init__(
  58. self,
  59. hidden_size: int,
  60. intermediate_size: int,
  61. hidden_act: str,
  62. linear_method: Optional[LinearMethodBase] = None,
  63. ) -> None:
  64. super().__init__()
  65. if (linear_method is not None
  66. and not linear_method.quant_config.merge_weight()):
  67. self.merge_weight = False
  68. self.gate_proj = ColumnParallelLinear(
  69. hidden_size,
  70. intermediate_size,
  71. bias=False,
  72. linear_method=linear_method,
  73. )
  74. self.up_proj = ColumnParallelLinear(
  75. hidden_size,
  76. intermediate_size,
  77. bias=False,
  78. linear_method=linear_method,
  79. )
  80. else:
  81. self.merge_weight = True
  82. self.gate_up_proj = MergedColumnParallelLinear(
  83. hidden_size,
  84. [intermediate_size] * 2,
  85. bias=False,
  86. linear_method=linear_method,
  87. )
  88. self.down_proj = RowParallelLinear(
  89. intermediate_size,
  90. hidden_size,
  91. bias=False,
  92. linear_method=linear_method,
  93. )
  94. if hidden_act != "silu":
  95. raise ValueError(f"Unsupported activation: {hidden_act}. "
  96. "Only silu is supported for now.")
  97. self.act_fn = SiluAndMul()
  98. def forward(self, x):
  99. if self.merge_weight:
  100. gate_up, _ = self.gate_up_proj(x)
  101. else:
  102. up, _ = self.up_proj(x)
  103. gate, _ = self.gate_proj(x)
  104. gate_up = torch.cat([gate, up], dim=-1)
  105. x = self.act_fn(gate_up)
  106. x, _ = self.down_proj(x)
  107. return x
  108. class Qwen2Attention(nn.Module):
  109. def __init__(
  110. self,
  111. hidden_size: int,
  112. num_heads: int,
  113. num_kv_heads: int,
  114. max_position: int = 4096 * 32,
  115. rope_theta: float = 10000,
  116. use_sliding_window: bool = False,
  117. linear_method: Optional[LinearMethodBase] = None,
  118. sliding_window: Optional[int] = None,
  119. ) -> None:
  120. super().__init__()
  121. self.hidden_size = hidden_size
  122. tp_size = get_tensor_model_parallel_world_size()
  123. self.total_num_heads = num_heads
  124. assert self.total_num_heads % tp_size == 0
  125. self.num_heads = self.total_num_heads // tp_size
  126. self.total_num_kv_heads = num_kv_heads
  127. if self.total_num_kv_heads >= tp_size:
  128. # Number of KV heads is greater than TP size, so we partition
  129. # the KV heads across multiple tensor parallel GPUs.
  130. assert self.total_num_kv_heads % tp_size == 0
  131. else:
  132. # Number of KV heads is less than TP size, so we replicate
  133. # the KV heads across multiple tensor parallel GPUs.
  134. assert tp_size % self.total_num_kv_heads == 0
  135. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  136. self.head_dim = hidden_size // self.total_num_heads
  137. self.q_size = self.num_heads * self.head_dim
  138. self.kv_size = self.num_kv_heads * self.head_dim
  139. self.scaling = self.head_dim**-0.5
  140. self.rope_theta = rope_theta
  141. self.sliding_window = sliding_window if use_sliding_window else None
  142. if (linear_method is not None
  143. and not linear_method.quant_config.merge_weight()):
  144. self.merge_weight = False
  145. self.q_proj = ColumnParallelLinear(hidden_size,
  146. self.q_size,
  147. bias=True,
  148. linear_method=linear_method)
  149. self.k_proj = ColumnParallelLinear(
  150. hidden_size,
  151. self.kv_size,
  152. bias=True,
  153. linear_method=linear_method,
  154. )
  155. self.v_proj = ColumnParallelLinear(
  156. hidden_size,
  157. self.kv_size,
  158. bias=True,
  159. linear_method=linear_method,
  160. )
  161. else:
  162. self.merge_weight = True
  163. self.qkv_proj = QKVParallelLinear(
  164. hidden_size,
  165. self.head_dim,
  166. self.total_num_heads,
  167. self.total_num_kv_heads,
  168. bias=True,
  169. linear_method=linear_method,
  170. )
  171. self.o_proj = RowParallelLinear(
  172. self.total_num_heads * self.head_dim,
  173. hidden_size,
  174. bias=False,
  175. linear_method=linear_method,
  176. )
  177. self.rotary_emb = get_rope(
  178. self.head_dim,
  179. rotary_dim=self.head_dim,
  180. max_position=max_position,
  181. base=self.rope_theta,
  182. )
  183. self.attn = PagedAttention(
  184. self.num_heads,
  185. self.head_dim,
  186. self.scaling,
  187. num_kv_heads=self.num_kv_heads,
  188. sliding_window=self.sliding_window,
  189. )
  190. def forward(
  191. self,
  192. positions: torch.Tensor,
  193. hidden_states: torch.Tensor,
  194. kv_cache: KVCache,
  195. input_metadata: InputMetadata,
  196. ) -> torch.Tensor:
  197. if self.merge_weight:
  198. qkv, _ = self.qkv_proj(hidden_states)
  199. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  200. dim=-1)
  201. else:
  202. q, _ = self.q_proj(hidden_states)
  203. k, _ = self.k_proj(hidden_states)
  204. v, _ = self.v_proj(hidden_states)
  205. q, k = self.rotary_emb(positions, q, k)
  206. k_cache, v_cache = kv_cache
  207. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  208. output, _ = self.o_proj(attn_output)
  209. return output
  210. class Qwen2DecoderLayer(nn.Module):
  211. def __init__(
  212. self,
  213. config: Qwen2Config,
  214. layer_idx: int,
  215. linear_method: Optional[LinearMethodBase] = None,
  216. ) -> None:
  217. super().__init__()
  218. self.hidden_size = config.hidden_size
  219. # Requires transformers > 4.32.0
  220. rope_theta = getattr(config, "rope_theta", 1000000)
  221. use_sliding_window = (config.use_sliding_window
  222. and layer_idx < config.max_window_layers)
  223. self.self_attn = Qwen2Attention(
  224. hidden_size=self.hidden_size,
  225. num_heads=config.num_attention_heads,
  226. max_position=config.max_position_embeddings,
  227. num_kv_heads=config.num_key_value_heads,
  228. rope_theta=rope_theta,
  229. use_sliding_window=use_sliding_window,
  230. linear_method=linear_method,
  231. sliding_window=config.sliding_window,
  232. )
  233. self.mlp = Qwen2MLP(
  234. hidden_size=self.hidden_size,
  235. intermediate_size=config.intermediate_size,
  236. hidden_act=config.hidden_act,
  237. linear_method=linear_method,
  238. )
  239. self.input_layernorm = RMSNorm(config.hidden_size,
  240. eps=config.rms_norm_eps)
  241. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  242. eps=config.rms_norm_eps)
  243. def forward(
  244. self,
  245. positions: torch.Tensor,
  246. hidden_states: torch.Tensor,
  247. kv_cache: KVCache,
  248. input_metadata: InputMetadata,
  249. residual: Optional[torch.Tensor],
  250. ) -> Tuple[torch.Tensor, torch.Tensor]:
  251. # Self Attention
  252. if residual is None:
  253. residual = hidden_states
  254. hidden_states = self.input_layernorm(hidden_states)
  255. else:
  256. hidden_states, residual = self.input_layernorm(
  257. hidden_states, residual)
  258. hidden_states = self.self_attn(
  259. positions=positions,
  260. hidden_states=hidden_states,
  261. kv_cache=kv_cache,
  262. input_metadata=input_metadata,
  263. )
  264. # Fully Connected
  265. hidden_states, residual = self.post_attention_layernorm(
  266. hidden_states, residual)
  267. hidden_states = self.mlp(hidden_states)
  268. return hidden_states, residual
  269. class Qwen2Model(nn.Module):
  270. def __init__(
  271. self,
  272. config: Qwen2Config,
  273. linear_method: Optional[LinearMethodBase] = None,
  274. ) -> None:
  275. super().__init__()
  276. self.config = config
  277. self.padding_idx = config.pad_token_id
  278. self.vocab_size = config.vocab_size
  279. self.embed_tokens = VocabParallelEmbedding(
  280. config.vocab_size,
  281. config.hidden_size,
  282. linear_method=linear_method,
  283. )
  284. self.layers = nn.ModuleList([
  285. Qwen2DecoderLayer(config, layer_idx, linear_method)
  286. for layer_idx in range(config.num_hidden_layers)
  287. ])
  288. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  289. def forward(
  290. self,
  291. input_ids: torch.Tensor,
  292. positions: torch.Tensor,
  293. kv_caches: List[KVCache],
  294. input_metadata: InputMetadata,
  295. ) -> torch.Tensor:
  296. hidden_states = self.embed_tokens(input_ids)
  297. residual = None
  298. for i in range(len(self.layers)):
  299. layer = self.layers[i]
  300. hidden_states, residual = layer(
  301. positions,
  302. hidden_states,
  303. kv_caches[i],
  304. input_metadata,
  305. residual,
  306. )
  307. hidden_states, _ = self.norm(hidden_states, residual)
  308. return hidden_states
  309. class Qwen2ForCausalLM(nn.Module):
  310. def __init__(
  311. self,
  312. config: Qwen2Config,
  313. linear_method: Optional[LinearMethodBase] = None,
  314. ) -> None:
  315. super().__init__()
  316. self.config = config
  317. self.linear_method = linear_method
  318. self.model = Qwen2Model(config, linear_method)
  319. self.lm_head = ParallelLMHead(
  320. config.vocab_size,
  321. config.hidden_size,
  322. linear_method=linear_method,
  323. )
  324. self.sampler = Sampler(config.vocab_size)
  325. self.quant_sampler = QuantSampler(config.vocab_size)
  326. def forward(
  327. self,
  328. input_ids: torch.Tensor,
  329. positions: torch.Tensor,
  330. kv_caches: List[KVCache],
  331. input_metadata: InputMetadata,
  332. ) -> torch.Tensor:
  333. hidden_states = self.model(input_ids, positions, kv_caches,
  334. input_metadata)
  335. return hidden_states
  336. def sample(
  337. self,
  338. hidden_states: torch.Tensor,
  339. sampling_metadata: SamplingMetadata,
  340. ) -> Optional[SamplerOutput]:
  341. if (self.linear_method is not None
  342. and not self.linear_method.quant_config.merge_weight()):
  343. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  344. sampling_metadata)
  345. else:
  346. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  347. sampling_metadata)
  348. return next_tokens
  349. def load_weights(
  350. self,
  351. model_name_or_path: str,
  352. cache_dir: Optional[str] = None,
  353. load_format: str = "auto",
  354. revision: Optional[str] = None,
  355. ):
  356. stacked_params_mapping = [
  357. # (param_name, shard_name, shard_id)
  358. ("qkv_proj", "q_proj", "q"),
  359. ("qkv_proj", "k_proj", "k"),
  360. ("qkv_proj", "v_proj", "v"),
  361. ("gate_up_proj", "gate_proj", 0),
  362. ("gate_up_proj", "up_proj", 1),
  363. ]
  364. if (self.linear_method is not None
  365. and not self.linear_method.quant_config.merge_weight()):
  366. stacked_params_mapping = []
  367. params_dict = dict(self.named_parameters())
  368. for name, loaded_weight in hf_model_weights_iterator(
  369. model_name_or_path, cache_dir, load_format, revision):
  370. if "rotary_emb.inv_freq" in name:
  371. continue
  372. for param_name, weight_name, shard_id in stacked_params_mapping:
  373. if weight_name not in name:
  374. continue
  375. name = name.replace(weight_name, param_name)
  376. # Skip loading extra bias for GPTQ models.
  377. if name.endswith(".bias") and name not in params_dict:
  378. continue
  379. param = params_dict[name]
  380. weight_loader = param.weight_loader
  381. weight_loader(param, loaded_weight, shard_id)
  382. break
  383. else:
  384. # Skip loading extra bias for GPTQ models.
  385. if name.endswith(".bias") and name not in params_dict:
  386. continue
  387. param = params_dict[name]
  388. weight_loader = getattr(param, "weight_loader",
  389. default_weight_loader)
  390. weight_loader(param, loaded_weight)