deepseek.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Deepseek model."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import PretrainedConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig
  30. from aphrodite.common.sequence import IntermediateTensors
  31. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  32. get_tensor_model_parallel_world_size,
  33. tensor_model_parallel_all_reduce)
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.fused_moe import fused_moe
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  38. QKVParallelLinear,
  39. ReplicatedLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.quantization.base_config import QuantizationConfig
  49. class DeepseekMLP(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. intermediate_size: int,
  54. hidden_act: str,
  55. quant_config: Optional[QuantizationConfig] = None,
  56. reduce_results: bool = True,
  57. ) -> None:
  58. super().__init__()
  59. self.gate_up_proj = MergedColumnParallelLinear(
  60. hidden_size, [intermediate_size] * 2,
  61. bias=False,
  62. quant_config=quant_config)
  63. self.down_proj = RowParallelLinear(intermediate_size,
  64. hidden_size,
  65. bias=False,
  66. quant_config=quant_config,
  67. reduce_results=reduce_results)
  68. if hidden_act != "silu":
  69. raise ValueError(f"Unsupported activation: {hidden_act}. "
  70. "Only silu is supported for now.")
  71. self.act_fn = SiluAndMul()
  72. def forward(self, x):
  73. gate_up, _ = self.gate_up_proj(x)
  74. x = self.act_fn(gate_up)
  75. x, _ = self.down_proj(x)
  76. return x
  77. class DeepseekMoE(nn.Module):
  78. def __init__(
  79. self,
  80. config: PretrainedConfig,
  81. quant_config: Optional[QuantizationConfig] = None,
  82. ):
  83. super().__init__()
  84. self.config = config
  85. self.rank = get_tensor_model_parallel_rank()
  86. self.tp_size = get_tensor_model_parallel_world_size()
  87. self.n_routed_experts = config.n_routed_experts
  88. self.top_k = config.num_experts_per_tok
  89. if self.tp_size > self.n_routed_experts:
  90. raise ValueError(
  91. f"Tensor parallel size {self.tp_size} is greater than "
  92. f"the number of experts {self.n_routed_experts}.")
  93. self.experts = nn.ModuleList([
  94. DeepseekMLP(hidden_size=config.hidden_size,
  95. intermediate_size=config.moe_intermediate_size,
  96. hidden_act=config.hidden_act,
  97. quant_config=quant_config,
  98. reduce_results=False)
  99. for idx in range(self.n_routed_experts)
  100. ])
  101. self.pack_params()
  102. self.gate = ReplicatedLinear(config.hidden_size,
  103. self.n_routed_experts,
  104. bias=False,
  105. quant_config=None)
  106. if config.n_shared_experts is not None:
  107. intermediate_size = (config.moe_intermediate_size *
  108. config.n_shared_experts)
  109. self.shared_experts = DeepseekMLP(
  110. hidden_size=config.hidden_size,
  111. intermediate_size=intermediate_size,
  112. hidden_act=config.hidden_act,
  113. quant_config=quant_config,
  114. reduce_results=False,
  115. )
  116. def pack_params(self):
  117. w1 = []
  118. w2 = []
  119. for expert in self.experts:
  120. w1.append(expert.gate_up_proj.weight)
  121. w2.append(expert.down_proj.weight)
  122. self.w1 = torch._utils._flatten_dense_tensors(w1)
  123. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  124. for data, param in zip(w1s, w1):
  125. param.data = data
  126. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  127. self.w2 = torch._utils._flatten_dense_tensors(w2)
  128. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  129. for data, param in zip(w2s, w2):
  130. param.data = data
  131. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  132. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  133. num_tokens, hidden_dim = hidden_states.shape
  134. hidden_states = hidden_states.view(-1, hidden_dim)
  135. if self.config.n_shared_experts is not None:
  136. shared_output = self.shared_experts(hidden_states)
  137. # router_logits: (num_tokens, n_experts)
  138. router_logits, _ = self.gate(hidden_states)
  139. final_hidden_states = fused_moe(hidden_states,
  140. self.w1,
  141. self.w2,
  142. router_logits,
  143. self.top_k,
  144. renormalize=self.config.norm_topk_prob,
  145. inplace=True)
  146. if self.config.n_shared_experts is not None:
  147. final_hidden_states = final_hidden_states + shared_output
  148. final_hidden_states = tensor_model_parallel_all_reduce(
  149. final_hidden_states)
  150. return final_hidden_states.view(num_tokens, hidden_dim)
  151. class DeepseekAttention(nn.Module):
  152. def __init__(
  153. self,
  154. hidden_size: int,
  155. num_heads: int,
  156. num_kv_heads: int,
  157. rope_theta: float = 10000,
  158. rope_scaling: Optional[Dict[str, Any]] = None,
  159. max_position_embeddings: int = 8192,
  160. cache_config: Optional[CacheConfig] = None,
  161. quant_config: Optional[QuantizationConfig] = None,
  162. ) -> None:
  163. super().__init__()
  164. self.hidden_size = hidden_size
  165. tp_size = get_tensor_model_parallel_world_size()
  166. self.total_num_heads = num_heads
  167. assert self.total_num_heads % tp_size == 0
  168. self.num_heads = self.total_num_heads // tp_size
  169. self.total_num_kv_heads = num_kv_heads
  170. if self.total_num_kv_heads >= tp_size:
  171. # Number of KV heads is greater than TP size, so we partition
  172. # the KV heads across multiple tensor parallel GPUs.
  173. assert self.total_num_kv_heads % tp_size == 0
  174. else:
  175. # Number of KV heads is less than TP size, so we replicate
  176. # the KV heads across multiple tensor parallel GPUs.
  177. assert tp_size % self.total_num_kv_heads == 0
  178. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  179. self.head_dim = hidden_size // self.total_num_heads
  180. self.q_size = self.num_heads * self.head_dim
  181. self.kv_size = self.num_kv_heads * self.head_dim
  182. self.scaling = self.head_dim**-0.5
  183. self.rope_theta = rope_theta
  184. self.max_position_embeddings = max_position_embeddings
  185. self.qkv_proj = QKVParallelLinear(
  186. hidden_size,
  187. self.head_dim,
  188. self.total_num_heads,
  189. self.total_num_kv_heads,
  190. bias=False,
  191. quant_config=quant_config,
  192. )
  193. self.o_proj = RowParallelLinear(
  194. self.total_num_heads * self.head_dim,
  195. hidden_size,
  196. bias=False,
  197. quant_config=quant_config,
  198. )
  199. self.rotary_emb = get_rope(
  200. self.head_dim,
  201. rotary_dim=self.head_dim,
  202. max_position=max_position_embeddings,
  203. base=rope_theta,
  204. rope_scaling=rope_scaling,
  205. )
  206. self.attn = Attention(self.num_heads,
  207. self.head_dim,
  208. self.scaling,
  209. num_kv_heads=self.num_kv_heads,
  210. cache_config=cache_config,
  211. quant_config=quant_config)
  212. def forward(
  213. self,
  214. positions: torch.Tensor,
  215. hidden_states: torch.Tensor,
  216. kv_cache: torch.Tensor,
  217. attn_metadata: AttentionMetadata,
  218. ) -> torch.Tensor:
  219. qkv, _ = self.qkv_proj(hidden_states)
  220. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  221. q, k = self.rotary_emb(positions, q, k)
  222. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  223. output, _ = self.o_proj(attn_output)
  224. return output
  225. class DeepseekDecoderLayer(nn.Module):
  226. def __init__(
  227. self,
  228. config: PretrainedConfig,
  229. layer_idx: int,
  230. cache_config: Optional[CacheConfig] = None,
  231. quant_config: Optional[QuantizationConfig] = None,
  232. ) -> None:
  233. super().__init__()
  234. self.hidden_size = config.hidden_size
  235. rope_theta = getattr(config, "rope_theta", 10000)
  236. rope_scaling = getattr(config, "rope_scaling", None)
  237. max_position_embeddings = getattr(config, "max_position_embeddings",
  238. 8192)
  239. self.self_attn = DeepseekAttention(
  240. hidden_size=self.hidden_size,
  241. num_heads=config.num_attention_heads,
  242. num_kv_heads=config.num_key_value_heads,
  243. rope_theta=rope_theta,
  244. rope_scaling=rope_scaling,
  245. max_position_embeddings=max_position_embeddings,
  246. cache_config=cache_config,
  247. quant_config=quant_config,
  248. )
  249. if (config.n_routed_experts is not None
  250. and layer_idx >= config.first_k_dense_replace
  251. and layer_idx % config.moe_layer_freq == 0):
  252. self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
  253. else:
  254. self.mlp = DeepseekMLP(
  255. hidden_size=config.hidden_size,
  256. intermediate_size=config.intermediate_size,
  257. hidden_act=config.hidden_act,
  258. quant_config=quant_config,
  259. )
  260. self.input_layernorm = RMSNorm(config.hidden_size,
  261. eps=config.rms_norm_eps)
  262. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  263. eps=config.rms_norm_eps)
  264. def forward(
  265. self,
  266. positions: torch.Tensor,
  267. hidden_states: torch.Tensor,
  268. kv_cache: torch.Tensor,
  269. attn_metadata: AttentionMetadata,
  270. residual: Optional[torch.Tensor],
  271. ) -> torch.Tensor:
  272. # Self Attention
  273. if residual is None:
  274. residual = hidden_states
  275. hidden_states = self.input_layernorm(hidden_states)
  276. else:
  277. hidden_states, residual = self.input_layernorm(
  278. hidden_states, residual)
  279. hidden_states = self.self_attn(
  280. positions=positions,
  281. hidden_states=hidden_states,
  282. kv_cache=kv_cache,
  283. attn_metadata=attn_metadata,
  284. )
  285. # Fully Connected
  286. hidden_states, residual = self.post_attention_layernorm(
  287. hidden_states, residual)
  288. hidden_states = self.mlp(hidden_states)
  289. return hidden_states, residual
  290. class DeepseekModel(nn.Module):
  291. fall_back_to_pt_during_load = False
  292. def __init__(
  293. self,
  294. config: PretrainedConfig,
  295. cache_config: Optional[CacheConfig] = None,
  296. quant_config: Optional[QuantizationConfig] = None,
  297. ) -> None:
  298. super().__init__()
  299. self.padding_idx = config.pad_token_id
  300. self.vocab_size = config.vocab_size
  301. self.embed_tokens = VocabParallelEmbedding(
  302. config.vocab_size,
  303. config.hidden_size,
  304. )
  305. self.layers = nn.ModuleList([
  306. DeepseekDecoderLayer(config,
  307. layer_idx,
  308. cache_config,
  309. quant_config=quant_config)
  310. for layer_idx in range(config.num_hidden_layers)
  311. ])
  312. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  313. def forward(
  314. self,
  315. input_ids: torch.Tensor,
  316. positions: torch.Tensor,
  317. kv_caches: List[torch.Tensor],
  318. attn_metadata: AttentionMetadata,
  319. ) -> torch.Tensor:
  320. hidden_states = self.embed_tokens(input_ids)
  321. residual = None
  322. for i in range(len(self.layers)):
  323. layer = self.layers[i]
  324. hidden_states, residual = layer(positions, hidden_states,
  325. kv_caches[i], attn_metadata,
  326. residual)
  327. hidden_states, _ = self.norm(hidden_states, residual)
  328. return hidden_states
  329. class DeepseekForCausalLM(nn.Module):
  330. def __init__(
  331. self,
  332. config: PretrainedConfig,
  333. cache_config: Optional[CacheConfig] = None,
  334. quant_config: Optional[QuantizationConfig] = None,
  335. ) -> None:
  336. super().__init__()
  337. self.config = config
  338. self.quant_config = quant_config
  339. self.model = DeepseekModel(config, cache_config, quant_config)
  340. self.lm_head = ParallelLMHead(config.vocab_size,
  341. config.hidden_size,
  342. quant_config=quant_config)
  343. self.logits_processor = LogitsProcessor(config.vocab_size)
  344. self.sampler = Sampler()
  345. def forward(
  346. self,
  347. input_ids: torch.Tensor,
  348. positions: torch.Tensor,
  349. kv_caches: List[torch.Tensor],
  350. attn_metadata: AttentionMetadata,
  351. intermediate_tensors: Optional[IntermediateTensors] = None,
  352. ) -> torch.Tensor:
  353. hidden_states = self.model(input_ids, positions, kv_caches,
  354. attn_metadata)
  355. return hidden_states
  356. def compute_logits(
  357. self,
  358. hidden_states: torch.Tensor,
  359. sampling_metadata: SamplingMetadata,
  360. ) -> Optional[torch.Tensor]:
  361. logits = self.logits_processor(self.lm_head, hidden_states,
  362. sampling_metadata)
  363. return logits
  364. def sample(
  365. self,
  366. logits: Optional[torch.Tensor],
  367. sampling_metadata: SamplingMetadata,
  368. ) -> Optional[SamplerOutput]:
  369. next_tokens = self.sampler(logits, sampling_metadata)
  370. return next_tokens
  371. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  372. stacked_params_mapping = [
  373. # (param_name, shard_name, shard_id)
  374. ("qkv_proj", "q_proj", "q"),
  375. ("qkv_proj", "k_proj", "k"),
  376. ("qkv_proj", "v_proj", "v"),
  377. ("gate_up_proj", "gate_proj", 0),
  378. ("gate_up_proj", "up_proj", 1),
  379. ]
  380. params_dict = dict(self.named_parameters())
  381. for name, loaded_weight in weights:
  382. if "rotary_emb.inv_freq" in name:
  383. continue
  384. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  385. if weight_name not in name:
  386. continue
  387. name = name.replace(weight_name, param_name)
  388. # Skip loading extra bias for GPTQ models.
  389. if name.endswith(".bias") and name not in params_dict:
  390. continue
  391. # Skip experts that are not assigned to this worker.
  392. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  393. and name not in params_dict):
  394. continue
  395. param = params_dict[name]
  396. weight_loader = param.weight_loader
  397. weight_loader(param, loaded_weight, shard_id)
  398. break
  399. else:
  400. # Skip loading extra bias for GPTQ models.
  401. if name.endswith(".bias") and name not in params_dict:
  402. continue
  403. # Skip experts that are not assigned to this worker.
  404. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  405. and name not in params_dict):
  406. continue
  407. param = params_dict[name]
  408. weight_loader = getattr(param, "weight_loader",
  409. default_weight_loader)
  410. weight_loader(param, loaded_weight)