# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x
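
# A brief sketch of what the MLP above computes (a hedged reading based on the
# stacked-parameter mapping in load_weights below and SiluAndMul semantics):
# gate_up_proj fuses the checkpoint's w2 (gate) and w1 (up) projections into a
# single column-parallel GEMM, and SiluAndMul then evaluates silu(gate) * up
# on the two halves of that output, so each token needs only two matmuls
# (gate_up_proj and c_proj) instead of three separate ones.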


class QWenAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output
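
# Shape sketch for the attention path above, assuming Qwen-7B-style settings
# (hidden_size=4096, num_attention_heads=32) and tensor parallel size 1:
# head_dim = 4096 // 32 = 128, c_attn emits [num_tokens, 3 * 32 * 128], and
# chunk(chunks=3, dim=-1) splits it into q, k, v of [num_tokens, 4096] each
# before the rotary embedding and the attention backend are applied.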


class QWenBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  quant_config=quant_config)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
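
# Note on the residual handling above: Aphrodite's RMSNorm returns both the
# normalized activations and an updated residual when a residual tensor is
# passed in, fusing the residual add into the norm kernel. The first block
# receives residual=None and seeds the residual stream from its own input;
# QWenModel.ln_f below folds the final residual back in. The MLP is built with
# config.intermediate_size // 2 because the HF Qwen config counts the gate and
# up projections together (e.g. 22016 -> 11008 per branch for Qwen-7B).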


class QWenModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
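
# The model above keeps one kv_cache entry per QWenBlock, indexed in lockstep
# with self.h, and threads a shared residual tensor through every layer before
# ln_f consumes the final residual.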


class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
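
# Illustration of how load_weights rewrites checkpoint names (the layer index
# 0 is only an example):
#
#   "transformer.h.0.mlp.w2.weight" -> "transformer.h.0.mlp.gate_up_proj.weight", shard 0
#   "transformer.h.0.mlp.w1.weight" -> "transformer.h.0.mlp.gate_up_proj.weight", shard 1
#
# c_attn weights are not listed in stacked_params_mapping, so they fall
# through to the else branch and are handled by the parameter's own
# weight_loader (for QKVParallelLinear this is expected to accept the
# checkpoint's packed QKV tensor as-is).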