deepseek.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Deepseek model."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import PretrainedConfig
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  32. get_tensor_model_parallel_world_size,
  33. tensor_model_parallel_all_reduce)
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.fused_moe import fused_moe
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  38. MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. class DeepseekMLP(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. intermediate_size: int,
  54. hidden_act: str,
  55. linear_method: Optional[LinearMethodBase] = None,
  56. reduce_results: bool = True,
  57. ) -> None:
  58. super().__init__()
  59. self.gate_up_proj = MergedColumnParallelLinear(
  60. hidden_size, [intermediate_size] * 2,
  61. bias=False,
  62. linear_method=linear_method)
  63. self.down_proj = RowParallelLinear(intermediate_size,
  64. hidden_size,
  65. bias=False,
  66. linear_method=linear_method,
  67. reduce_results=reduce_results)
  68. if hidden_act != "silu":
  69. raise ValueError(f"Unsupported activation: {hidden_act}. "
  70. "Only silu is supported for now.")
  71. self.act_fn = SiluAndMul()
  72. def forward(self, x):
  73. gate_up, _ = self.gate_up_proj(x)
  74. x = self.act_fn(gate_up)
  75. x, _ = self.down_proj(x)
  76. return x
  77. class DeepseekMoE(nn.Module):
  78. def __init__(
  79. self,
  80. config: PretrainedConfig,
  81. linear_method: Optional[LinearMethodBase] = None,
  82. ):
  83. super().__init__()
  84. self.config = config
  85. self.rank = get_tensor_model_parallel_rank()
  86. self.tp_size = get_tensor_model_parallel_world_size()
  87. self.n_routed_experts = config.n_routed_experts
  88. self.top_k = config.num_experts_per_tok
  89. if self.tp_size > self.n_routed_experts:
  90. raise ValueError(
  91. f"Tensor parallel size {self.tp_size} is greater than "
  92. f"the number of experts {self.n_routed_experts}.")
  93. self.experts = nn.ModuleList([
  94. DeepseekMLP(hidden_size=config.hidden_size,
  95. intermediate_size=config.moe_intermediate_size,
  96. hidden_act=config.hidden_act,
  97. linear_method=linear_method,
  98. reduce_results=False)
  99. for idx in range(self.n_routed_experts)
  100. ])
  101. self.pack_params()
  102. self.gate = ReplicatedLinear(config.hidden_size,
  103. self.n_routed_experts,
  104. bias=False,
  105. linear_method=None)
  106. if config.n_shared_experts is not None:
  107. intermediate_size = (config.moe_intermediate_size *
  108. config.n_shared_experts)
  109. self.shared_experts = DeepseekMLP(
  110. hidden_size=config.hidden_size,
  111. intermediate_size=intermediate_size,
  112. hidden_act=config.hidden_act,
  113. linear_method=linear_method,
  114. reduce_results=False,
  115. )
  116. def pack_params(self):
  117. w1 = []
  118. w2 = []
  119. for expert in self.experts:
  120. w1.append(expert.gate_up_proj.weight)
  121. w2.append(expert.down_proj.weight)
  122. self.w1 = torch._utils._flatten_dense_tensors(w1)
  123. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  124. for data, param in zip(w1s, w1):
  125. param.data = data
  126. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  127. self.w2 = torch._utils._flatten_dense_tensors(w2)
  128. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  129. for data, param in zip(w2s, w2):
  130. param.data = data
  131. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  132. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  133. num_tokens, hidden_dim = hidden_states.shape
  134. hidden_states = hidden_states.view(-1, hidden_dim)
  135. if self.config.n_shared_experts is not None:
  136. shared_output = self.shared_experts(hidden_states)
  137. # router_logits: (num_tokens, n_experts)
  138. router_logits, _ = self.gate(hidden_states)
  139. final_hidden_states = fused_moe(hidden_states,
  140. self.w1,
  141. self.w2,
  142. router_logits,
  143. self.top_k,
  144. renormalize=self.config.norm_topk_prob,
  145. inplace=True)
  146. if self.config.n_shared_experts is not None:
  147. final_hidden_states = final_hidden_states + shared_output
  148. final_hidden_states = tensor_model_parallel_all_reduce(
  149. final_hidden_states)
  150. return final_hidden_states.view(num_tokens, hidden_dim)
  151. class DeepseekAttention(nn.Module):
  152. def __init__(
  153. self,
  154. hidden_size: int,
  155. num_heads: int,
  156. num_kv_heads: int,
  157. rope_theta: float = 10000,
  158. rope_scaling: Optional[Dict[str, Any]] = None,
  159. max_position_embeddings: int = 8192,
  160. linear_method: Optional[LinearMethodBase] = None,
  161. ) -> None:
  162. super().__init__()
  163. self.hidden_size = hidden_size
  164. tp_size = get_tensor_model_parallel_world_size()
  165. self.total_num_heads = num_heads
  166. assert self.total_num_heads % tp_size == 0
  167. self.num_heads = self.total_num_heads // tp_size
  168. self.total_num_kv_heads = num_kv_heads
  169. if self.total_num_kv_heads >= tp_size:
  170. # Number of KV heads is greater than TP size, so we partition
  171. # the KV heads across multiple tensor parallel GPUs.
  172. assert self.total_num_kv_heads % tp_size == 0
  173. else:
  174. # Number of KV heads is less than TP size, so we replicate
  175. # the KV heads across multiple tensor parallel GPUs.
  176. assert tp_size % self.total_num_kv_heads == 0
  177. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  178. self.head_dim = hidden_size // self.total_num_heads
  179. self.q_size = self.num_heads * self.head_dim
  180. self.kv_size = self.num_kv_heads * self.head_dim
  181. self.scaling = self.head_dim**-0.5
  182. self.rope_theta = rope_theta
  183. self.max_position_embeddings = max_position_embeddings
  184. self.qkv_proj = QKVParallelLinear(
  185. hidden_size,
  186. self.head_dim,
  187. self.total_num_heads,
  188. self.total_num_kv_heads,
  189. bias=False,
  190. linear_method=linear_method,
  191. )
  192. self.o_proj = RowParallelLinear(
  193. self.total_num_heads * self.head_dim,
  194. hidden_size,
  195. bias=False,
  196. linear_method=linear_method,
  197. )
  198. self.rotary_emb = get_rope(
  199. self.head_dim,
  200. rotary_dim=self.head_dim,
  201. max_position=max_position_embeddings,
  202. base=rope_theta,
  203. rope_scaling=rope_scaling,
  204. )
  205. self.attn = Attention(self.num_heads,
  206. self.head_dim,
  207. self.scaling,
  208. num_kv_heads=self.num_kv_heads)
  209. def forward(
  210. self,
  211. positions: torch.Tensor,
  212. hidden_states: torch.Tensor,
  213. kv_cache: torch.Tensor,
  214. attn_metadata: AttentionMetadata,
  215. ) -> torch.Tensor:
  216. qkv, _ = self.qkv_proj(hidden_states)
  217. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  218. q, k = self.rotary_emb(positions, q, k)
  219. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  220. output, _ = self.o_proj(attn_output)
  221. return output
  222. class DeepseekDecoderLayer(nn.Module):
  223. def __init__(
  224. self,
  225. config: PretrainedConfig,
  226. layer_idx: int,
  227. linear_method: Optional[LinearMethodBase] = None,
  228. ) -> None:
  229. super().__init__()
  230. self.hidden_size = config.hidden_size
  231. rope_theta = getattr(config, "rope_theta", 10000)
  232. rope_scaling = getattr(config, "rope_scaling", None)
  233. max_position_embeddings = getattr(config, "max_position_embeddings",
  234. 8192)
  235. self.self_attn = DeepseekAttention(
  236. hidden_size=self.hidden_size,
  237. num_heads=config.num_attention_heads,
  238. num_kv_heads=config.num_key_value_heads,
  239. rope_theta=rope_theta,
  240. rope_scaling=rope_scaling,
  241. max_position_embeddings=max_position_embeddings,
  242. linear_method=linear_method,
  243. )
  244. if (config.n_routed_experts is not None
  245. and layer_idx >= config.first_k_dense_replace
  246. and layer_idx % config.moe_layer_freq == 0):
  247. self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
  248. else:
  249. self.mlp = DeepseekMLP(
  250. hidden_size=config.hidden_size,
  251. intermediate_size=config.intermediate_size,
  252. hidden_act=config.hidden_act,
  253. linear_method=linear_method,
  254. )
  255. self.input_layernorm = RMSNorm(config.hidden_size,
  256. eps=config.rms_norm_eps)
  257. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  258. eps=config.rms_norm_eps)
  259. def forward(
  260. self,
  261. positions: torch.Tensor,
  262. hidden_states: torch.Tensor,
  263. kv_cache: torch.Tensor,
  264. attn_metadata: AttentionMetadata,
  265. residual: Optional[torch.Tensor],
  266. ) -> torch.Tensor:
  267. # Self Attention
  268. if residual is None:
  269. residual = hidden_states
  270. hidden_states = self.input_layernorm(hidden_states)
  271. else:
  272. hidden_states, residual = self.input_layernorm(
  273. hidden_states, residual)
  274. hidden_states = self.self_attn(
  275. positions=positions,
  276. hidden_states=hidden_states,
  277. kv_cache=kv_cache,
  278. attn_metadata=attn_metadata,
  279. )
  280. # Fully Connected
  281. hidden_states, residual = self.post_attention_layernorm(
  282. hidden_states, residual)
  283. hidden_states = self.mlp(hidden_states)
  284. return hidden_states, residual
  285. class DeepseekModel(nn.Module):
  286. fall_back_to_pt_during_load = False
  287. def __init__(
  288. self,
  289. config: PretrainedConfig,
  290. linear_method: Optional[LinearMethodBase] = None,
  291. ) -> None:
  292. super().__init__()
  293. self.padding_idx = config.pad_token_id
  294. self.vocab_size = config.vocab_size
  295. self.embed_tokens = VocabParallelEmbedding(
  296. config.vocab_size,
  297. config.hidden_size,
  298. )
  299. self.layers = nn.ModuleList([
  300. DeepseekDecoderLayer(config,
  301. layer_idx,
  302. linear_method=linear_method)
  303. for layer_idx in range(config.num_hidden_layers)
  304. ])
  305. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  306. def forward(
  307. self,
  308. input_ids: torch.Tensor,
  309. positions: torch.Tensor,
  310. kv_caches: List[torch.Tensor],
  311. attn_metadata: AttentionMetadata,
  312. ) -> torch.Tensor:
  313. hidden_states = self.embed_tokens(input_ids)
  314. residual = None
  315. for i in range(len(self.layers)):
  316. layer = self.layers[i]
  317. hidden_states, residual = layer(positions, hidden_states,
  318. kv_caches[i], attn_metadata,
  319. residual)
  320. hidden_states, _ = self.norm(hidden_states, residual)
  321. return hidden_states
  322. class DeepseekForCausalLM(nn.Module):
  323. def __init__(
  324. self,
  325. config: PretrainedConfig,
  326. linear_method: Optional[LinearMethodBase] = None,
  327. ) -> None:
  328. super().__init__()
  329. self.config = config
  330. self.linear_method = linear_method
  331. self.model = DeepseekModel(config, linear_method)
  332. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  333. self.logits_processor = LogitsProcessor(config.vocab_size)
  334. self.sampler = Sampler()
  335. def forward(
  336. self,
  337. input_ids: torch.Tensor,
  338. positions: torch.Tensor,
  339. kv_caches: List[torch.Tensor],
  340. attn_metadata: AttentionMetadata,
  341. ) -> torch.Tensor:
  342. hidden_states = self.model(input_ids, positions, kv_caches,
  343. attn_metadata)
  344. return hidden_states
  345. def compute_logits(self, hidden_states: torch.Tensor,
  346. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  347. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  348. sampling_metadata)
  349. return logits
  350. def sample(
  351. self,
  352. logits: Optional[torch.Tensor],
  353. sampling_metadata: SamplingMetadata,
  354. ) -> Optional[SamplerOutput]:
  355. next_tokens = self.sampler(logits, sampling_metadata)
  356. return next_tokens
  357. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  358. stacked_params_mapping = [
  359. # (param_name, shard_name, shard_id)
  360. ("qkv_proj", "q_proj", "q"),
  361. ("qkv_proj", "k_proj", "k"),
  362. ("qkv_proj", "v_proj", "v"),
  363. ("gate_up_proj", "gate_proj", 0),
  364. ("gate_up_proj", "up_proj", 1),
  365. ]
  366. params_dict = dict(self.named_parameters())
  367. for name, loaded_weight in weights:
  368. if "rotary_emb.inv_freq" in name:
  369. continue
  370. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  371. if weight_name not in name:
  372. continue
  373. name = name.replace(weight_name, param_name)
  374. # Skip loading extra bias for GPTQ models.
  375. if name.endswith(".bias") and name not in params_dict:
  376. continue
  377. # Skip experts that are not assigned to this worker.
  378. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  379. and name not in params_dict):
  380. continue
  381. param = params_dict[name]
  382. weight_loader = param.weight_loader
  383. weight_loader(param, loaded_weight, shard_id)
  384. break
  385. else:
  386. # Skip loading extra bias for GPTQ models.
  387. if name.endswith(".bias") and name not in params_dict:
  388. continue
  389. # Skip experts that are not assigned to this worker.
  390. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  391. and name not in params_dict):
  392. continue
  393. param = params_dict[name]
  394. weight_loader = getattr(param, "weight_loader",
  395. default_weight_loader)
  396. weight_loader(param, loaded_weight)