# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
    ColumnParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs.qwen import QWenConfig


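# Feed-forward (SwiGLU-style) block. The gate and up projections are fused
# into a single MergedColumnParallelLinear ("gate_up_proj") unless the
# quantization method requires unmerged weights, in which case separate
# w1 (up) and w2 (gate) projections are used and concatenated at runtime.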
class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.w2 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.w1 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            up, _ = self.w1(x)
            gate, _ = self.w2(x)
            gate_up = torch.cat([gate, up], dim=-1)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


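# Multi-head self-attention with a fused QKV projection ("c_attn"), rotary
# position embeddings, and an output projection ("c_proj"). Attention heads
# are split evenly across the tensor-parallel ranks.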
class QWenAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5
        is_neox_style = (True if linear_method is None
                         or linear_method.quant_config.rope_style() is None
                         else linear_method.quant_config.rope_style())
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


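# A single transformer layer: pre-norm self-attention followed by a pre-norm
# MLP, with the residual stream threaded through the fused add+RMSNorm calls.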
class QWenBlock(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            linear_method=linear_method,
        )

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            linear_method=linear_method,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


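# The transformer backbone: token embedding ("wte"), a stack of QWenBlocks
# ("h"), and a final RMSNorm ("ln_f").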
class QWenModel(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.hidden_size,
                                          linear_method=linear_method)
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


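# Top-level causal LM: wraps QWenModel with a parallel LM head, logits
# processing, sampling, and HuggingFace-checkpoint weight loading.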
class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.tokenizer_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
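
# Note on load_weights: with merged weights enabled, checkpoint tensors whose
# names contain "w2" are loaded into shard 0 (gate) of "gate_up_proj", and
# tensors whose names contain "w1" into shard 1 (up). For example, a tensor
# named (illustratively) "transformer.h.0.mlp.w1.weight" would land in
# "transformer.h.0.mlp.gate_up_proj.weight" as shard 1. When the quantization
# config disables weight merging, the mapping is empty and w1/w2 load directly.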