qwen.py

# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
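        # SiluAndMul splits the fused [gate, up] projection in half along the
        # last dimension and returns silu(gate) * up.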
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
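        # c_attn packs Q, K and V into a single projection; the model uses the
        # same number of heads for all three, so equal-sized chunks split them.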
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
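        # config.intermediate_size counts both MLP branches combined, so each
        # branch of the merged gate/up projection gets half of it.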
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        residual = None
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            # Qwen checkpoints store the MLP as two projections, "w2" and
            # "w1"; they are loaded as shards 0 and 1 of the merged
            # gate_up_proj.
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
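

# Minimal usage sketch. It assumes the engine exposes ``LLM`` and
# ``SamplingParams`` at the package root (as in vLLM, from which Aphrodite is
# derived); adjust the imports and the model name to your installation.
if __name__ == "__main__":
    from aphrodite import LLM, SamplingParams

    # Qwen checkpoints typically require trust_remote_code for their
    # custom tokenizer.
    llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
    params = SamplingParams(temperature=0.8, max_tokens=64)
    for out in llm.generate(["The capital of France is"], params):
        print(out.outputs[0].text)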