# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
    ColumnParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs.qwen import QWenConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
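# Each KVCache is a (key_cache, value_cache) tensor pair; one pair is handed
# to the PagedAttention module of each decoder layer on every forward pass.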


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            # Quantization schemes that cannot merge weights keep the gate
            # (w2) and up (w1) projections as separate parallel layers.
            self.merge_weight = False
            self.w2 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.w1 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            up, _ = self.w1(x)
            gate, _ = self.w2(x)
            gate_up = torch.cat([gate, up], dim=-1)
        # SiluAndMul splits the last dim in half and computes silu(gate) * up.
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x
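

# Multi-head attention with tensor-parallel head sharding: a fused c_attn
# projection produces Q, K and V, rotary position embeddings are applied,
# and the PagedAttention kernel attends over the paged KV cache.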
class QWenAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5
        # Quantized checkpoints may dictate the RoPE layout; default to the
        # NeoX style when no quantization config says otherwise.
        is_neox_style = (True if linear_method is None
                         or linear_method.quant_config.rope_style() is None
                         else linear_method.quant_config.rope_style())
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # QWen uses as many KV heads as query heads, so the fused projection
        # splits into three equal chunks.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.c_proj(attn_output)
        return output
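

# A pre-norm decoder layer. The RMSNorm layers here accept an optional
# residual tensor and return (normalized_states, updated_residual), fusing
# the residual add with the normalization.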
class QWenBlock(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            linear_method=linear_method,
        )
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # QWen's config reports twice the actual FFN width, hence the // 2.
        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
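

# The decoder stack: token embedding (wte), num_hidden_layers QWenBlocks (h),
# and a final RMSNorm (ln_f), matching the HF checkpoint's module names.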
class QWenModel(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.hidden_size,
                                          linear_method=linear_method)
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
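

# The engine-facing wrapper: adds the parallel LM head and the samplers on
# top of QWenModel, and implements weight loading from HF checkpoints.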
class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            # Unmerged quantized weights: run the LM head as a layer and
            # sample from its output logits.
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            # Weights stay unmerged for this quantization scheme, so no
            # checkpoint shards need to be stacked.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-stacked parameters load directly, using the param's own
                # weight_loader when one is attached.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
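

# Usage sketch (assumption: this aphrodite build exposes the vLLM-style
# top-level API; names may differ between versions):
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
#     outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))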