qwen.py

# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Gate and up projections are fused into one column-parallel GEMM;
        # SiluAndMul below applies silu to the first half (gate) and
        # multiplies it by the second half (up).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        # c_attn packs Q, K, and V into a single projection; split it evenly.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Qwen's config reports intermediate_size for the concatenated
        # gate/up (w2/w1) pair, so each projection gets half of it.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # The checkpoint stores the MLP gate (w2) and up (w1) projections
        # separately; they are loaded as shards of the merged gate_up_proj.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                # Rotary inv_freq is recomputed at runtime; skip the cached buffer.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
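

# Usage sketch (not part of the original file): this module is normally picked
# up by Aphrodite's model registry rather than imported directly. Assuming the
# engine exposes a vLLM-style top-level `LLM`/`SamplingParams` API, a Qwen
# checkpoint that resolves to QWenLMHeadModel could be exercised roughly as
# follows (model name and sampling values are illustrative only):
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
#     params = SamplingParams(temperature=0.8, max_tokens=32)
#     outputs = llm.generate(["Briefly introduce large language models."],
#                            params)
#     print(outputs[0].outputs[0].text)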