deepseek.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Deepseek model."""
  25. from typing import Any, Dict, List, Optional
  26. import numpy as np
  27. import torch
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.fused_moe import fused_topk
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (
  35. LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear,
  36. ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. ParallelLMHead, VocabParallelEmbedding)
  42. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  43. get_tensor_model_parallel_world_size,
  44. tensor_model_parallel_all_reduce)
  45. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  46. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  47. hf_model_weights_iterator)
  48. from aphrodite.common.sequence import SamplerOutput
  49. class DeepseekMLP(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. intermediate_size: int,
  54. hidden_act: str,
  55. linear_method: Optional[LinearMethodBase] = None,
  56. reduce_results: bool = True,
  57. ) -> None:
  58. super().__init__()
  59. self.gate_up_proj = MergedColumnParallelLinear(
  60. hidden_size, [intermediate_size] * 2,
  61. bias=False,
  62. linear_method=linear_method)
  63. self.down_proj = RowParallelLinear(intermediate_size,
  64. hidden_size,
  65. bias=False,
  66. linear_method=linear_method,
  67. reduce_results=reduce_results)
  68. if hidden_act != "silu":
  69. raise ValueError(f"Unsupported activation: {hidden_act}. "
  70. "Only silu is supported for now.")
  71. self.act_fn = SiluAndMul()
  72. def forward(self, x):
  73. gate_up, _ = self.gate_up_proj(x)
  74. x = self.act_fn(gate_up)
  75. x, _ = self.down_proj(x)
  76. return x
  77. class DeepseekExpertMLP(nn.Module):
  78. def __init__(
  79. self,
  80. hidden_size: int,
  81. intermediate_size: int,
  82. hidden_act: str,
  83. linear_method: Optional[LinearMethodBase] = None,
  84. ) -> None:
  85. super().__init__()
  86. self.gate_proj = ReplicatedLinear(hidden_size,
  87. intermediate_size,
  88. bias=False,
  89. linear_method=linear_method)
  90. self.up_proj = ReplicatedLinear(hidden_size,
  91. intermediate_size,
  92. bias=False,
  93. linear_method=linear_method)
  94. self.down_proj = ReplicatedLinear(intermediate_size,
  95. hidden_size,
  96. bias=False,
  97. linear_method=linear_method)
  98. self.act_fn = nn.SiLU()
  99. def forward(self, hidden_states):
  100. gate_out, _ = self.gate_proj(hidden_states)
  101. gate_out = self.act_fn(gate_out)
  102. up_out, _ = self.up_proj(hidden_states)
  103. current_hidden_states = gate_out * up_out
  104. current_hidden_states, _ = self.down_proj(current_hidden_states)
  105. return current_hidden_states
  106. class DeepseekMoE(nn.Module):
  107. def __init__(
  108. self,
  109. config: PretrainedConfig,
  110. linear_method: Optional[LinearMethodBase] = None,
  111. ):
  112. super().__init__()
  113. self.config = config
  114. self.rank = get_tensor_model_parallel_rank()
  115. self.tp_size = get_tensor_model_parallel_world_size()
  116. self.n_routed_experts = config.n_routed_experts
  117. self.top_k = config.num_experts_per_tok
  118. self.linear_method = linear_method
  119. if self.linear_method is None:
  120. self.linear_method = UnquantizedLinearMethod()
  121. if not isinstance(
  122. self.linear_method, UnquantizedLinearMethod
  123. ) and not self.linear_method.quant_config.support_fused_moe():
  124. if self.tp_size > self.n_routed_experts:
  125. raise ValueError(
  126. f"Tensor parallel size {self.tp_size} is greater than "
  127. f"the number of experts {self.n_routed_experts}.")
  128. # Split experts equally between ranks
  129. self.expert_indicies = np.array_split(range(
  130. self.n_routed_experts), self.tp_size)[self.rank].tolist()
  131. if not self.expert_indicies:
  132. raise ValueError(
  133. f"Rank {self.rank} has no experts assigned to it.")
  134. self.experts = nn.ModuleList([
  135. DeepseekExpertMLP(
  136. hidden_size=config.hidden_size,
  137. intermediate_size=config.moe_intermediate_size,
  138. hidden_act=config.hidden_act,
  139. linear_method=linear_method,
  140. ) if idx in self.expert_indicies else None
  141. for idx in range(self.n_routed_experts)
  142. ])
  143. else:
  144. self.w1 = MergedColumnParallelLinear(
  145. config.hidden_size, [config.moe_intermediate_size] * 2,
  146. bias=False,
  147. linear_method=linear_method,
  148. num_experts=self.n_routed_experts)
  149. self.w2 = RowParallelLinear(config.moe_intermediate_size,
  150. config.hidden_size,
  151. bias=False,
  152. linear_method=linear_method,
  153. num_experts=self.n_routed_experts)
  154. self.gate = ReplicatedLinear(config.hidden_size,
  155. self.n_routed_experts,
  156. bias=False,
  157. linear_method=None)
  158. if config.n_shared_experts is not None:
  159. intermediate_size = (config.moe_intermediate_size *
  160. config.n_shared_experts)
  161. self.shared_experts = DeepseekMLP(
  162. hidden_size=config.hidden_size,
  163. intermediate_size=intermediate_size,
  164. hidden_act=config.hidden_act,
  165. linear_method=linear_method,
  166. reduce_results=False,
  167. )
  168. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  169. num_tokens, hidden_dim = hidden_states.shape
  170. hidden_states = hidden_states.view(-1, hidden_dim)
  171. if self.config.n_shared_experts is not None:
  172. shared_output = self.shared_experts(hidden_states)
  173. # router_logits: (num_tokens, n_experts)
  174. router_logits, _ = self.gate(hidden_states)
  175. if not isinstance(
  176. self.linear_method, UnquantizedLinearMethod
  177. ) and not self.linear_method.quant_config.support_fused_moe():
  178. routing_weights, selected_experts = fused_topk(
  179. router_logits,
  180. self.top_k,
  181. renormalize=self.config.norm_topk_prob)
  182. final_hidden_states = None
  183. for expert_idx in self.expert_indicies:
  184. expert_layer = self.experts[expert_idx]
  185. expert_mask = (selected_experts == expert_idx)
  186. expert_weights = (routing_weights * expert_mask).sum(
  187. dim=-1, keepdim=True)
  188. current_hidden_states = expert_layer(hidden_states).mul_(
  189. expert_weights)
  190. if final_hidden_states is None:
  191. final_hidden_states = current_hidden_states
  192. else:
  193. final_hidden_states.add_(current_hidden_states)
  194. else:
  195. final_hidden_states = self.linear_method.apply_moe_weights(
  196. self.w1.linear_weights,
  197. self.w2.linear_weights,
  198. hidden_states,
  199. router_logits,
  200. self.top_k,
  201. renormalize=self.config.norm_topk_prob,
  202. )
  203. if self.config.n_shared_experts is not None:
  204. final_hidden_states = final_hidden_states + shared_output
  205. final_hidden_states = tensor_model_parallel_all_reduce(
  206. final_hidden_states)
  207. return final_hidden_states.view(num_tokens, hidden_dim)
  208. class DeepseekAttention(nn.Module):
  209. def __init__(
  210. self,
  211. hidden_size: int,
  212. num_heads: int,
  213. num_kv_heads: int,
  214. rope_theta: float = 10000,
  215. rope_scaling: Optional[Dict[str, Any]] = None,
  216. max_position_embeddings: int = 8192,
  217. linear_method: Optional[LinearMethodBase] = None,
  218. ) -> None:
  219. super().__init__()
  220. self.hidden_size = hidden_size
  221. tp_size = get_tensor_model_parallel_world_size()
  222. self.total_num_heads = num_heads
  223. assert self.total_num_heads % tp_size == 0
  224. self.num_heads = self.total_num_heads // tp_size
  225. self.total_num_kv_heads = num_kv_heads
  226. if self.total_num_kv_heads >= tp_size:
  227. # Number of KV heads is greater than TP size, so we partition
  228. # the KV heads across multiple tensor parallel GPUs.
  229. assert self.total_num_kv_heads % tp_size == 0
  230. else:
  231. # Number of KV heads is less than TP size, so we replicate
  232. # the KV heads across multiple tensor parallel GPUs.
  233. assert tp_size % self.total_num_kv_heads == 0
  234. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  235. self.head_dim = hidden_size // self.total_num_heads
  236. self.q_size = self.num_heads * self.head_dim
  237. self.kv_size = self.num_kv_heads * self.head_dim
  238. self.scaling = self.head_dim**-0.5
  239. self.rope_theta = rope_theta
  240. self.max_position_embeddings = max_position_embeddings
  241. self.qkv_proj = QKVParallelLinear(
  242. hidden_size,
  243. self.head_dim,
  244. self.total_num_heads,
  245. self.total_num_kv_heads,
  246. bias=False,
  247. linear_method=linear_method,
  248. )
  249. self.o_proj = RowParallelLinear(
  250. self.total_num_heads * self.head_dim,
  251. hidden_size,
  252. bias=False,
  253. linear_method=linear_method,
  254. )
  255. self.rotary_emb = get_rope(
  256. self.head_dim,
  257. rotary_dim=self.head_dim,
  258. max_position=max_position_embeddings,
  259. base=rope_theta,
  260. rope_scaling=rope_scaling,
  261. )
  262. self.attn = Attention(self.num_heads,
  263. self.head_dim,
  264. self.scaling,
  265. num_kv_heads=self.num_kv_heads)
  266. def forward(
  267. self,
  268. positions: torch.Tensor,
  269. hidden_states: torch.Tensor,
  270. kv_cache: torch.Tensor,
  271. attn_metadata: AttentionMetadata,
  272. ) -> torch.Tensor:
  273. qkv, _ = self.qkv_proj(hidden_states)
  274. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  275. q, k = self.rotary_emb(positions, q, k)
  276. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  277. output, _ = self.o_proj(attn_output)
  278. return output
  279. class DeepseekDecoderLayer(nn.Module):
  280. def __init__(
  281. self,
  282. config: PretrainedConfig,
  283. layer_idx: int,
  284. linear_method: Optional[LinearMethodBase] = None,
  285. ) -> None:
  286. super().__init__()
  287. self.hidden_size = config.hidden_size
  288. rope_theta = getattr(config, "rope_theta", 10000)
  289. rope_scaling = getattr(config, "rope_scaling", None)
  290. max_position_embeddings = getattr(config, "max_position_embeddings",
  291. 8192)
  292. self.self_attn = DeepseekAttention(
  293. hidden_size=self.hidden_size,
  294. num_heads=config.num_attention_heads,
  295. num_kv_heads=config.num_key_value_heads,
  296. rope_theta=rope_theta,
  297. rope_scaling=rope_scaling,
  298. max_position_embeddings=max_position_embeddings,
  299. linear_method=linear_method,
  300. )
  301. if (config.n_routed_experts is not None
  302. and layer_idx >= config.first_k_dense_replace
  303. and layer_idx % config.moe_layer_freq == 0):
  304. self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
  305. else:
  306. self.mlp = DeepseekMLP(
  307. hidden_size=config.hidden_size,
  308. intermediate_size=config.intermediate_size,
  309. hidden_act=config.hidden_act,
  310. linear_method=linear_method,
  311. )
  312. self.input_layernorm = RMSNorm(config.hidden_size,
  313. eps=config.rms_norm_eps)
  314. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  315. eps=config.rms_norm_eps)
  316. def forward(
  317. self,
  318. positions: torch.Tensor,
  319. hidden_states: torch.Tensor,
  320. kv_cache: torch.Tensor,
  321. attn_metadata: AttentionMetadata,
  322. residual: Optional[torch.Tensor],
  323. ) -> torch.Tensor:
  324. # Self Attention
  325. if residual is None:
  326. residual = hidden_states
  327. hidden_states = self.input_layernorm(hidden_states)
  328. else:
  329. hidden_states, residual = self.input_layernorm(
  330. hidden_states, residual)
  331. hidden_states = self.self_attn(
  332. positions=positions,
  333. hidden_states=hidden_states,
  334. kv_cache=kv_cache,
  335. attn_metadata=attn_metadata,
  336. )
  337. # Fully Connected
  338. hidden_states, residual = self.post_attention_layernorm(
  339. hidden_states, residual)
  340. hidden_states = self.mlp(hidden_states)
  341. return hidden_states, residual
  342. class DeepseekModel(nn.Module):
  343. def __init__(
  344. self,
  345. config: PretrainedConfig,
  346. linear_method: Optional[LinearMethodBase] = None,
  347. ) -> None:
  348. super().__init__()
  349. self.padding_idx = config.pad_token_id
  350. self.vocab_size = config.vocab_size
  351. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  352. config.hidden_size,
  353. linear_method=linear_method)
  354. self.layers = nn.ModuleList([
  355. DeepseekDecoderLayer(config,
  356. layer_idx,
  357. linear_method=linear_method)
  358. for layer_idx in range(config.num_hidden_layers)
  359. ])
  360. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  361. def forward(
  362. self,
  363. input_ids: torch.Tensor,
  364. positions: torch.Tensor,
  365. kv_caches: List[torch.Tensor],
  366. attn_metadata: AttentionMetadata,
  367. ) -> torch.Tensor:
  368. hidden_states = self.embed_tokens(input_ids)
  369. residual = None
  370. for i in range(len(self.layers)):
  371. layer = self.layers[i]
  372. hidden_states, residual = layer(positions, hidden_states,
  373. kv_caches[i], attn_metadata,
  374. residual)
  375. hidden_states, _ = self.norm(hidden_states, residual)
  376. return hidden_states
  377. class DeepseekForCausalLM(nn.Module):
  378. def __init__(
  379. self,
  380. config: PretrainedConfig,
  381. linear_method: Optional[LinearMethodBase] = None,
  382. ) -> None:
  383. super().__init__()
  384. self.config = config
  385. self.linear_method = linear_method
  386. self.model = DeepseekModel(config, linear_method)
  387. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  388. self.logits_processor = LogitsProcessor(config.vocab_size)
  389. self.sampler = Sampler()
  390. def forward(
  391. self,
  392. input_ids: torch.Tensor,
  393. positions: torch.Tensor,
  394. kv_caches: List[torch.Tensor],
  395. attn_metadata: AttentionMetadata,
  396. ) -> torch.Tensor:
  397. hidden_states = self.model(input_ids, positions, kv_caches,
  398. attn_metadata)
  399. return hidden_states
  400. def compute_logits(self, hidden_states: torch.Tensor,
  401. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  402. logits = self.logits_processor(self.lm_head, hidden_states,
  403. sampling_metadata)
  404. return logits
  405. def sample(
  406. self,
  407. logits: Optional[torch.Tensor],
  408. sampling_metadata: SamplingMetadata,
  409. ) -> Optional[SamplerOutput]:
  410. next_tokens = self.sampler(logits, sampling_metadata)
  411. return next_tokens
  412. def load_weights(self,
  413. model_name_or_path: str,
  414. cache_dir: Optional[str] = None,
  415. load_format: str = "auto",
  416. revision: Optional[str] = None):
  417. stacked_params_mapping = [
  418. # (param_name, shard_name, shard_id)
  419. ("qkv_proj", "q_proj", "q"),
  420. ("qkv_proj", "k_proj", "k"),
  421. ("qkv_proj", "v_proj", "v"),
  422. ("mlp.gate_up_proj", "mlp.gate_proj", 0),
  423. ("mlp.gate_up_proj", "mlp.up_proj", 1),
  424. ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0),
  425. ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1),
  426. ]
  427. expert_params_mapping = [
  428. # (param_name, weight_name, shard_id, expert_id)
  429. ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2",
  430. f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
  431. for expert_id in range(self.config.n_routed_experts)
  432. for weight_name, shard_id in [("gate_proj",
  433. 0), ("up_proj",
  434. 1), ("down_proj", None)]
  435. ] if self.linear_method is None or (
  436. self.linear_method.quant_config.support_fused_moe()) else []
  437. params_dict = dict(self.named_parameters())
  438. for name, loaded_weight in hf_model_weights_iterator(
  439. model_name_or_path,
  440. cache_dir,
  441. load_format,
  442. revision,
  443. self.config,
  444. fall_back_to_pt=False):
  445. if "rotary_emb.inv_freq" in name:
  446. continue
  447. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  448. if weight_name not in name:
  449. continue
  450. name = name.replace(weight_name, param_name)
  451. # Skip loading extra bias for GPTQ models.
  452. if name.endswith(".bias") and name not in params_dict:
  453. continue
  454. # Skip experts that are not assigned to this worker.
  455. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  456. and name not in params_dict):
  457. continue
  458. param = params_dict[name]
  459. weight_loader = param.weight_loader
  460. weight_loader(param, loaded_weight, shard_id)
  461. break
  462. else:
  463. for (param_name, weight_name, shard_id,
  464. expert_id) in expert_params_mapping:
  465. if weight_name not in name:
  466. continue
  467. name = name.replace(weight_name, param_name)
  468. if name.endswith(".bias") and name not in params_dict:
  469. continue
  470. param = params_dict[name]
  471. weight_loader = param.weight_loader
  472. if shard_id is None:
  473. weight_loader(param,
  474. loaded_weight,
  475. expert_id=expert_id)
  476. else:
  477. weight_loader(param,
  478. loaded_weight,
  479. shard_id,
  480. expert_id=expert_id)
  481. break
  482. else:
  483. # Skip loading extra bias for GPTQ models.
  484. if name.endswith(".bias") and name not in params_dict:
  485. continue
  486. # Skip experts that are not assigned to this worker.
  487. if (("mlp.experts." in name
  488. or "mlp.shared_experts." in name)
  489. and name not in params_dict):
  490. continue
  491. param = params_dict[name]
  492. weight_loader = getattr(param, "weight_loader",
  493. default_weight_loader)
  494. weight_loader(param, loaded_weight)