deepseek.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Deepseek model."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import PretrainedConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import progress_bar
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import fused_moe
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class DeepseekMLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. reduce_results: bool = True,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. hidden_size, [intermediate_size] * 2,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(intermediate_size,
  65. hidden_size,
  66. bias=False,
  67. quant_config=quant_config,
  68. reduce_results=reduce_results)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class DeepseekMoE(nn.Module):
  79. def __init__(
  80. self,
  81. config: PretrainedConfig,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.config = config
  86. self.rank = get_tensor_model_parallel_rank()
  87. self.tp_size = get_tensor_model_parallel_world_size()
  88. self.n_routed_experts = config.n_routed_experts
  89. self.top_k = config.num_experts_per_tok
  90. if self.tp_size > self.n_routed_experts:
  91. raise ValueError(
  92. f"Tensor parallel size {self.tp_size} is greater than "
  93. f"the number of experts {self.n_routed_experts}.")
  94. self.experts = nn.ModuleList([
  95. DeepseekMLP(hidden_size=config.hidden_size,
  96. intermediate_size=config.moe_intermediate_size,
  97. hidden_act=config.hidden_act,
  98. quant_config=quant_config,
  99. reduce_results=False)
  100. for idx in range(self.n_routed_experts)
  101. ])
  102. self.pack_params()
  103. self.gate = ReplicatedLinear(config.hidden_size,
  104. self.n_routed_experts,
  105. bias=False,
  106. quant_config=None)
  107. if config.n_shared_experts is not None:
  108. intermediate_size = (config.moe_intermediate_size *
  109. config.n_shared_experts)
  110. self.shared_experts = DeepseekMLP(
  111. hidden_size=config.hidden_size,
  112. intermediate_size=intermediate_size,
  113. hidden_act=config.hidden_act,
  114. quant_config=quant_config,
  115. reduce_results=False,
  116. )
  117. def pack_params(self):
  118. w1 = []
  119. w2 = []
  120. for expert in self.experts:
  121. w1.append(expert.gate_up_proj.weight)
  122. w2.append(expert.down_proj.weight)
  123. self.w1 = torch._utils._flatten_dense_tensors(w1)
  124. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  125. for data, param in zip(w1s, w1):
  126. param.data = data
  127. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  128. self.w2 = torch._utils._flatten_dense_tensors(w2)
  129. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  130. for data, param in zip(w2s, w2):
  131. param.data = data
  132. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  133. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  134. num_tokens, hidden_dim = hidden_states.shape
  135. hidden_states = hidden_states.view(-1, hidden_dim)
  136. if self.config.n_shared_experts is not None:
  137. shared_output = self.shared_experts(hidden_states)
  138. # router_logits: (num_tokens, n_experts)
  139. router_logits, _ = self.gate(hidden_states)
  140. final_hidden_states = fused_moe(hidden_states,
  141. self.w1,
  142. self.w2,
  143. router_logits,
  144. self.top_k,
  145. renormalize=self.config.norm_topk_prob,
  146. inplace=True)
  147. if self.config.n_shared_experts is not None:
  148. final_hidden_states = final_hidden_states + shared_output
  149. final_hidden_states = tensor_model_parallel_all_reduce(
  150. final_hidden_states)
  151. return final_hidden_states.view(num_tokens, hidden_dim)
  152. class DeepseekAttention(nn.Module):
  153. def __init__(
  154. self,
  155. hidden_size: int,
  156. num_heads: int,
  157. num_kv_heads: int,
  158. rope_theta: float = 10000,
  159. rope_scaling: Optional[Dict[str, Any]] = None,
  160. max_position_embeddings: int = 8192,
  161. cache_config: Optional[CacheConfig] = None,
  162. quant_config: Optional[QuantizationConfig] = None,
  163. ) -> None:
  164. super().__init__()
  165. self.hidden_size = hidden_size
  166. tp_size = get_tensor_model_parallel_world_size()
  167. self.total_num_heads = num_heads
  168. assert self.total_num_heads % tp_size == 0
  169. self.num_heads = self.total_num_heads // tp_size
  170. self.total_num_kv_heads = num_kv_heads
  171. if self.total_num_kv_heads >= tp_size:
  172. # Number of KV heads is greater than TP size, so we partition
  173. # the KV heads across multiple tensor parallel GPUs.
  174. assert self.total_num_kv_heads % tp_size == 0
  175. else:
  176. # Number of KV heads is less than TP size, so we replicate
  177. # the KV heads across multiple tensor parallel GPUs.
  178. assert tp_size % self.total_num_kv_heads == 0
  179. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  180. self.head_dim = hidden_size // self.total_num_heads
  181. self.q_size = self.num_heads * self.head_dim
  182. self.kv_size = self.num_kv_heads * self.head_dim
  183. self.scaling = self.head_dim**-0.5
  184. self.rope_theta = rope_theta
  185. self.max_position_embeddings = max_position_embeddings
  186. self.qkv_proj = QKVParallelLinear(
  187. hidden_size,
  188. self.head_dim,
  189. self.total_num_heads,
  190. self.total_num_kv_heads,
  191. bias=False,
  192. quant_config=quant_config,
  193. )
  194. self.o_proj = RowParallelLinear(
  195. self.total_num_heads * self.head_dim,
  196. hidden_size,
  197. bias=False,
  198. quant_config=quant_config,
  199. )
  200. self.rotary_emb = get_rope(
  201. self.head_dim,
  202. rotary_dim=self.head_dim,
  203. max_position=max_position_embeddings,
  204. base=rope_theta,
  205. rope_scaling=rope_scaling,
  206. )
  207. self.attn = Attention(self.num_heads,
  208. self.head_dim,
  209. self.scaling,
  210. num_kv_heads=self.num_kv_heads,
  211. cache_config=cache_config,
  212. quant_config=quant_config)
  213. def forward(
  214. self,
  215. positions: torch.Tensor,
  216. hidden_states: torch.Tensor,
  217. kv_cache: torch.Tensor,
  218. attn_metadata: AttentionMetadata,
  219. ) -> torch.Tensor:
  220. qkv, _ = self.qkv_proj(hidden_states)
  221. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  222. q, k = self.rotary_emb(positions, q, k)
  223. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  224. output, _ = self.o_proj(attn_output)
  225. return output
  226. class DeepseekDecoderLayer(nn.Module):
  227. def __init__(
  228. self,
  229. config: PretrainedConfig,
  230. layer_idx: int,
  231. cache_config: Optional[CacheConfig] = None,
  232. quant_config: Optional[QuantizationConfig] = None,
  233. ) -> None:
  234. super().__init__()
  235. self.hidden_size = config.hidden_size
  236. rope_theta = getattr(config, "rope_theta", 10000)
  237. rope_scaling = getattr(config, "rope_scaling", None)
  238. max_position_embeddings = getattr(config, "max_position_embeddings",
  239. 8192)
  240. self.self_attn = DeepseekAttention(
  241. hidden_size=self.hidden_size,
  242. num_heads=config.num_attention_heads,
  243. num_kv_heads=config.num_key_value_heads,
  244. rope_theta=rope_theta,
  245. rope_scaling=rope_scaling,
  246. max_position_embeddings=max_position_embeddings,
  247. cache_config=cache_config,
  248. quant_config=quant_config,
  249. )
  250. if (config.n_routed_experts is not None
  251. and layer_idx >= config.first_k_dense_replace
  252. and layer_idx % config.moe_layer_freq == 0):
  253. self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
  254. else:
  255. self.mlp = DeepseekMLP(
  256. hidden_size=config.hidden_size,
  257. intermediate_size=config.intermediate_size,
  258. hidden_act=config.hidden_act,
  259. quant_config=quant_config,
  260. )
  261. self.input_layernorm = RMSNorm(config.hidden_size,
  262. eps=config.rms_norm_eps)
  263. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  264. eps=config.rms_norm_eps)
  265. def forward(
  266. self,
  267. positions: torch.Tensor,
  268. hidden_states: torch.Tensor,
  269. kv_cache: torch.Tensor,
  270. attn_metadata: AttentionMetadata,
  271. residual: Optional[torch.Tensor],
  272. ) -> torch.Tensor:
  273. # Self Attention
  274. if residual is None:
  275. residual = hidden_states
  276. hidden_states = self.input_layernorm(hidden_states)
  277. else:
  278. hidden_states, residual = self.input_layernorm(
  279. hidden_states, residual)
  280. hidden_states = self.self_attn(
  281. positions=positions,
  282. hidden_states=hidden_states,
  283. kv_cache=kv_cache,
  284. attn_metadata=attn_metadata,
  285. )
  286. # Fully Connected
  287. hidden_states, residual = self.post_attention_layernorm(
  288. hidden_states, residual)
  289. hidden_states = self.mlp(hidden_states)
  290. return hidden_states, residual
  291. class DeepseekModel(nn.Module):
  292. fall_back_to_pt_during_load = False
  293. def __init__(
  294. self,
  295. config: PretrainedConfig,
  296. cache_config: Optional[CacheConfig] = None,
  297. quant_config: Optional[QuantizationConfig] = None,
  298. ) -> None:
  299. super().__init__()
  300. self.padding_idx = config.pad_token_id
  301. self.vocab_size = config.vocab_size
  302. self.embed_tokens = VocabParallelEmbedding(
  303. config.vocab_size,
  304. config.hidden_size,
  305. )
  306. self.layers = nn.ModuleList([
  307. DeepseekDecoderLayer(config,
  308. layer_idx,
  309. cache_config,
  310. quant_config=quant_config)
  311. for layer_idx in range(config.num_hidden_layers)
  312. ])
  313. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  314. def forward(
  315. self,
  316. input_ids: torch.Tensor,
  317. positions: torch.Tensor,
  318. kv_caches: List[torch.Tensor],
  319. attn_metadata: AttentionMetadata,
  320. ) -> torch.Tensor:
  321. hidden_states = self.embed_tokens(input_ids)
  322. residual = None
  323. for i in range(len(self.layers)):
  324. layer = self.layers[i]
  325. hidden_states, residual = layer(positions, hidden_states,
  326. kv_caches[i], attn_metadata,
  327. residual)
  328. hidden_states, _ = self.norm(hidden_states, residual)
  329. return hidden_states
  330. class DeepseekForCausalLM(nn.Module):
  331. def __init__(
  332. self,
  333. config: PretrainedConfig,
  334. cache_config: Optional[CacheConfig] = None,
  335. quant_config: Optional[QuantizationConfig] = None,
  336. ) -> None:
  337. super().__init__()
  338. self.config = config
  339. self.quant_config = quant_config
  340. self.model = DeepseekModel(config, cache_config, quant_config)
  341. self.lm_head = ParallelLMHead(config.vocab_size,
  342. config.hidden_size,
  343. quant_config=quant_config)
  344. self.logits_processor = LogitsProcessor(config.vocab_size)
  345. self.sampler = Sampler()
  346. def forward(
  347. self,
  348. input_ids: torch.Tensor,
  349. positions: torch.Tensor,
  350. kv_caches: List[torch.Tensor],
  351. attn_metadata: AttentionMetadata,
  352. intermediate_tensors: Optional[IntermediateTensors] = None,
  353. ) -> torch.Tensor:
  354. hidden_states = self.model(input_ids, positions, kv_caches,
  355. attn_metadata)
  356. return hidden_states
  357. def compute_logits(
  358. self,
  359. hidden_states: torch.Tensor,
  360. sampling_metadata: SamplingMetadata,
  361. ) -> Optional[torch.Tensor]:
  362. logits = self.logits_processor(self.lm_head, hidden_states,
  363. sampling_metadata)
  364. return logits
  365. def sample(
  366. self,
  367. logits: Optional[torch.Tensor],
  368. sampling_metadata: SamplingMetadata,
  369. ) -> Optional[SamplerOutput]:
  370. next_tokens = self.sampler(logits, sampling_metadata)
  371. return next_tokens
  372. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  373. stacked_params_mapping = [
  374. # (param_name, shard_name, shard_id)
  375. ("qkv_proj", "q_proj", "q"),
  376. ("qkv_proj", "k_proj", "k"),
  377. ("qkv_proj", "v_proj", "v"),
  378. ("gate_up_proj", "gate_proj", 0),
  379. ("gate_up_proj", "up_proj", 1),
  380. ]
  381. params_dict = dict(self.named_parameters())
  382. weights_list = list(weights)
  383. for name, loaded_weight in progress_bar(weights_list,
  384. desc="Loading modules..."):
  385. if "rotary_emb.inv_freq" in name:
  386. continue
  387. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  388. if weight_name not in name:
  389. continue
  390. name = name.replace(weight_name, param_name)
  391. # Skip loading extra bias for GPTQ models.
  392. if name.endswith(".bias") and name not in params_dict:
  393. continue
  394. # Skip experts that are not assigned to this worker.
  395. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  396. and name not in params_dict):
  397. continue
  398. param = params_dict[name]
  399. weight_loader = param.weight_loader
  400. weight_loader(param, loaded_weight, shard_id)
  401. break
  402. else:
  403. # Skip loading extra bias for GPTQ models.
  404. if name.endswith(".bias") and name not in params_dict:
  405. continue
  406. # Skip experts that are not assigned to this worker.
  407. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  408. and name not in params_dict):
  409. continue
  410. param = params_dict[name]
  411. weight_loader = getattr(param, "weight_loader",
  412. default_weight_loader)
  413. weight_loader(param, loaded_weight)