deepseek.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Deepseek model."""
  25. from typing import Any, Dict, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import PretrainedConfig
  29. from aphrodite.modeling.metadata import InputMetadata
  30. from aphrodite.modeling.layers.activation import SiluAndMul
  31. from aphrodite.modeling.layers.attention import PagedAttention
  32. from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (
  35. LinearMethodBase,
  36. MergedColumnParallelLinear,
  37. ReplicatedLinear,
  38. QKVParallelLinear,
  39. RowParallelLinear,
  40. )
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. VocabParallelEmbedding,
  45. ParallelLMHead,
  46. )
  47. from aphrodite.modeling.megatron.communication_op import (
  48. tensor_model_parallel_all_reduce, )
  49. from aphrodite.modeling.megatron.parallel_state import (
  50. get_tensor_model_parallel_rank,
  51. get_tensor_model_parallel_world_size,
  52. )
  53. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  54. from aphrodite.modeling.hf_downloader import (
  55. default_weight_loader,
  56. hf_model_weights_iterator,
  57. )
  58. from aphrodite.common.sequence import SamplerOutput
  59. KVCache = Tuple[torch.Tensor, torch.Tensor]
  60. class DeepseekMLP(nn.Module):
  61. def __init__(
  62. self,
  63. hidden_size: int,
  64. intermediate_size: int,
  65. hidden_act: str,
  66. linear_method: Optional[LinearMethodBase] = None,
  67. reduce_results: bool = True,
  68. ) -> None:
  69. super().__init__()
  70. self.gate_up_proj = MergedColumnParallelLinear(
  71. hidden_size,
  72. [intermediate_size] * 2,
  73. bias=False,
  74. linear_method=linear_method,
  75. )
  76. self.down_proj = RowParallelLinear(
  77. intermediate_size,
  78. hidden_size,
  79. bias=False,
  80. linear_method=linear_method,
  81. reduce_results=reduce_results,
  82. )
  83. if hidden_act != "silu":
  84. raise ValueError(f"Unsupported activation: {hidden_act}. "
  85. "Only silu is supported for now.")
  86. self.act_fn = SiluAndMul()
  87. def forward(self, x):
  88. gate_up, _ = self.gate_up_proj(x)
  89. x = self.act_fn(gate_up)
  90. x, _ = self.down_proj(x)
  91. return x
  92. class DeepseekMoE(nn.Module):
  93. def __init__(
  94. self,
  95. config: PretrainedConfig,
  96. linear_method: Optional[LinearMethodBase] = None,
  97. ):
  98. super().__init__()
  99. self.config = config
  100. self.rank = get_tensor_model_parallel_rank()
  101. self.tp_size = get_tensor_model_parallel_world_size()
  102. self.n_routed_experts = config.n_routed_experts
  103. self.top_k = config.num_experts_per_tok
  104. if self.tp_size > self.n_routed_experts:
  105. raise ValueError(
  106. f"Tensor parallel size {self.tp_size} is greater than "
  107. f"the number of experts {self.n_routed_experts}.")
  108. self.experts = nn.ModuleList([
  109. DeepseekMLP(
  110. hidden_size=config.hidden_size,
  111. intermediate_size=config.moe_intermediate_size,
  112. hidden_act=config.hidden_act,
  113. linear_method=linear_method,
  114. reduce_results=False,
  115. ) for idx in range(self.n_routed_experts)
  116. ])
  117. self.pack_params()
  118. self.gate = ReplicatedLinear(
  119. config.hidden_size,
  120. self.n_routed_experts,
  121. bias=False,
  122. linear_method=None,
  123. )
  124. if config.n_shared_experts is not None:
  125. intermediate_size = (config.moe_intermediate_size *
  126. config.n_shared_experts)
  127. self.shared_experts = DeepseekMLP(
  128. hidden_size=config.hidden_size,
  129. intermediate_size=intermediate_size,
  130. hidden_act=config.hidden_act,
  131. linear_method=linear_method,
  132. reduce_results=False,
  133. )
  134. def pack_params(self):
  135. w1 = []
  136. w2 = []
  137. for expert in self.experts:
  138. w1.append(expert.gate_up_proj.weight)
  139. w2.append(expert.down_proj.weight)
  140. self.w1 = torch._utils._flatten_dense_tensors(w1)
  141. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  142. for data, param in zip(w1s, w1):
  143. param.data = data
  144. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  145. self.w2 = torch._utils._flatten_dense_tensors(w2)
  146. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  147. for data, param in zip(w2s, w2):
  148. param.data = data
  149. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  150. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  151. batch_size, sequence_length, hidden_dim = hidden_states.shape
  152. hidden_states = hidden_states.view(-1, hidden_dim)
  153. if self.config.n_shared_experts is not None:
  154. shared_output = self.shared_experts(hidden_states)
  155. # router_logits: (batch * sequence_length, n_experts)
  156. router_logits, _ = self.gate(hidden_states)
  157. final_hidden_states = fused_moe(
  158. hidden_states,
  159. self.w1,
  160. self.w2,
  161. router_logits,
  162. self.top_k,
  163. renormalize=self.config.norm_topk_prob,
  164. inplace=True,
  165. )
  166. if self.config.n_shared_experts is not None:
  167. final_hidden_states = final_hidden_states + shared_output
  168. final_hidden_states = tensor_model_parallel_all_reduce(
  169. final_hidden_states)
  170. return final_hidden_states.view(batch_size, sequence_length,
  171. hidden_dim)
  172. class DeepseekAttention(nn.Module):
  173. def __init__(
  174. self,
  175. hidden_size: int,
  176. num_heads: int,
  177. num_kv_heads: int,
  178. rope_theta: float = 10000,
  179. rope_scaling: Optional[Dict[str, Any]] = None,
  180. max_position_embeddings: int = 8192,
  181. linear_method: Optional[LinearMethodBase] = None,
  182. ) -> None:
  183. super().__init__()
  184. self.hidden_size = hidden_size
  185. tp_size = get_tensor_model_parallel_world_size()
  186. self.total_num_heads = num_heads
  187. assert self.total_num_heads % tp_size == 0
  188. self.num_heads = self.total_num_heads // tp_size
  189. self.total_num_kv_heads = num_kv_heads
  190. if self.total_num_kv_heads >= tp_size:
  191. # Number of KV heads is greater than TP size, so we partition
  192. # the KV heads across multiple tensor parallel GPUs.
  193. assert self.total_num_kv_heads % tp_size == 0
  194. else:
  195. # Number of KV heads is less than TP size, so we replicate
  196. # the KV heads across multiple tensor parallel GPUs.
  197. assert tp_size % self.total_num_kv_heads == 0
  198. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  199. self.head_dim = hidden_size // self.total_num_heads
  200. self.q_size = self.num_heads * self.head_dim
  201. self.kv_size = self.num_kv_heads * self.head_dim
  202. self.scaling = self.head_dim**-0.5
  203. self.rope_theta = rope_theta
  204. self.max_position_embeddings = max_position_embeddings
  205. self.qkv_proj = QKVParallelLinear(
  206. hidden_size,
  207. self.head_dim,
  208. self.total_num_heads,
  209. self.total_num_kv_heads,
  210. bias=False,
  211. linear_method=linear_method,
  212. )
  213. self.o_proj = RowParallelLinear(
  214. self.total_num_heads * self.head_dim,
  215. hidden_size,
  216. bias=False,
  217. linear_method=linear_method,
  218. )
  219. self.rotary_emb = get_rope(
  220. self.head_dim,
  221. rotary_dim=self.head_dim,
  222. max_position=max_position_embeddings,
  223. base=rope_theta,
  224. rope_scaling=rope_scaling,
  225. )
  226. self.attn = PagedAttention(
  227. self.num_heads,
  228. self.head_dim,
  229. self.scaling,
  230. num_kv_heads=self.num_kv_heads,
  231. )
  232. def forward(
  233. self,
  234. positions: torch.Tensor,
  235. hidden_states: torch.Tensor,
  236. kv_cache: KVCache,
  237. input_metadata: InputMetadata,
  238. ) -> torch.Tensor:
  239. qkv, _ = self.qkv_proj(hidden_states)
  240. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  241. q, k = self.rotary_emb(positions, q, k)
  242. k_cache, v_cache = kv_cache
  243. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  244. output, _ = self.o_proj(attn_output)
  245. return output
  246. class DeepseekDecoderLayer(nn.Module):
  247. def __init__(
  248. self,
  249. config: PretrainedConfig,
  250. layer_idx: int,
  251. linear_method: Optional[LinearMethodBase] = None,
  252. ) -> None:
  253. super().__init__()
  254. self.hidden_size = config.hidden_size
  255. rope_theta = getattr(config, "rope_theta", 10000)
  256. rope_scaling = getattr(config, "rope_scaling", None)
  257. max_position_embeddings = getattr(config, "max_position_embeddings",
  258. 8192)
  259. self.self_attn = DeepseekAttention(
  260. hidden_size=self.hidden_size,
  261. num_heads=config.num_attention_heads,
  262. num_kv_heads=config.num_key_value_heads,
  263. rope_theta=rope_theta,
  264. rope_scaling=rope_scaling,
  265. max_position_embeddings=max_position_embeddings,
  266. linear_method=linear_method,
  267. )
  268. if (config.n_routed_experts is not None
  269. and layer_idx >= config.first_k_dense_replace
  270. and layer_idx % config.moe_layer_freq == 0):
  271. self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
  272. else:
  273. self.mlp = DeepseekMLP(
  274. hidden_size=config.hidden_size,
  275. intermediate_size=config.intermediate_size,
  276. hidden_act=config.hidden_act,
  277. linear_method=linear_method,
  278. )
  279. self.input_layernorm = RMSNorm(config.hidden_size,
  280. eps=config.rms_norm_eps)
  281. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  282. eps=config.rms_norm_eps)
  283. def forward(
  284. self,
  285. positions: torch.Tensor,
  286. hidden_states: torch.Tensor,
  287. kv_cache: KVCache,
  288. input_metadata: InputMetadata,
  289. residual: Optional[torch.Tensor],
  290. ) -> torch.Tensor:
  291. # Self Attention
  292. if residual is None:
  293. residual = hidden_states
  294. hidden_states = self.input_layernorm(hidden_states)
  295. else:
  296. hidden_states, residual = self.input_layernorm(
  297. hidden_states, residual)
  298. hidden_states = self.self_attn(
  299. positions=positions,
  300. hidden_states=hidden_states,
  301. kv_cache=kv_cache,
  302. input_metadata=input_metadata,
  303. )
  304. # Fully Connected
  305. hidden_states, residual = self.post_attention_layernorm(
  306. hidden_states, residual)
  307. hidden_states = self.mlp(hidden_states)
  308. return hidden_states, residual
  309. class DeepseekModel(nn.Module):
  310. def __init__(
  311. self,
  312. config: PretrainedConfig,
  313. linear_method: Optional[LinearMethodBase] = None,
  314. ) -> None:
  315. super().__init__()
  316. self.padding_idx = config.pad_token_id
  317. self.vocab_size = config.vocab_size
  318. self.embed_tokens = VocabParallelEmbedding(
  319. config.vocab_size,
  320. config.hidden_size,
  321. )
  322. self.layers = nn.ModuleList([
  323. DeepseekDecoderLayer(config,
  324. layer_idx,
  325. linear_method=linear_method)
  326. for layer_idx in range(config.num_hidden_layers)
  327. ])
  328. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  329. def forward(
  330. self,
  331. input_ids: torch.Tensor,
  332. positions: torch.Tensor,
  333. kv_caches: List[KVCache],
  334. input_metadata: InputMetadata,
  335. ) -> torch.Tensor:
  336. hidden_states = self.embed_tokens(input_ids)
  337. residual = None
  338. for i in range(len(self.layers)):
  339. layer = self.layers[i]
  340. hidden_states, residual = layer(positions, hidden_states,
  341. kv_caches[i], input_metadata,
  342. residual)
  343. hidden_states, _ = self.norm(hidden_states, residual)
  344. return hidden_states
  345. class DeepseekForCausalLM(nn.Module):
  346. def __init__(
  347. self,
  348. config: PretrainedConfig,
  349. linear_method: Optional[LinearMethodBase] = None,
  350. ) -> None:
  351. super().__init__()
  352. self.config = config
  353. self.linear_method = linear_method
  354. self.model = DeepseekModel(config, linear_method)
  355. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  356. self.sampler = Sampler(config.vocab_size)
  357. self.quant_sampler = QuantSampler(config.vocab_size)
  358. def forward(
  359. self,
  360. input_ids: torch.Tensor,
  361. positions: torch.Tensor,
  362. kv_caches: List[KVCache],
  363. input_metadata: InputMetadata,
  364. ) -> torch.Tensor:
  365. hidden_states = self.model(input_ids, positions, kv_caches,
  366. input_metadata)
  367. return hidden_states
  368. def sample(
  369. self,
  370. hidden_states: Optional[torch.Tensor],
  371. sampling_metadata: SamplingMetadata,
  372. ) -> Optional[SamplerOutput]:
  373. if (self.linear_method is not None
  374. and not self.linear_method.quant_config.merge_weight()):
  375. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  376. sampling_metadata)
  377. else:
  378. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  379. sampling_metadata)
  380. return next_tokens
  381. def load_weights(
  382. self,
  383. model_name_or_path: str,
  384. cache_dir: Optional[str] = None,
  385. load_format: str = "auto",
  386. revision: Optional[str] = None,
  387. ):
  388. stacked_params_mapping = [
  389. # (param_name, shard_name, shard_id)
  390. ("qkv_proj", "q_proj", "q"),
  391. ("qkv_proj", "k_proj", "k"),
  392. ("qkv_proj", "v_proj", "v"),
  393. ("gate_up_proj", "gate_proj", 0),
  394. ("gate_up_proj", "up_proj", 1),
  395. ]
  396. params_dict = dict(self.named_parameters())
  397. for name, loaded_weight in hf_model_weights_iterator(
  398. model_name_or_path,
  399. cache_dir,
  400. load_format,
  401. revision,
  402. fall_back_to_pt=False,
  403. ):
  404. if "rotary_emb.inv_freq" in name:
  405. continue
  406. for param_name, weight_name, shard_id in stacked_params_mapping:
  407. if weight_name not in name:
  408. continue
  409. name = name.replace(weight_name, param_name)
  410. # Skip loading extra bias for GPTQ models.
  411. if name.endswith(".bias") and name not in params_dict:
  412. continue
  413. # Skip experts that are not assigned to this worker.
  414. if ("mlp.experts." in name or "mlp.shared_experts."
  415. in name) and name not in params_dict:
  416. continue
  417. param = params_dict[name]
  418. weight_loader = param.weight_loader
  419. weight_loader(param, loaded_weight, shard_id)
  420. break
  421. else:
  422. # Skip loading extra bias for GPTQ models.
  423. if name.endswith(".bias") and name not in params_dict:
  424. continue
  425. # Skip experts that are not assigned to this worker.
  426. if ("mlp.experts." in name or "mlp.shared_experts."
  427. in name) and name not in params_dict:
  428. continue
  429. param = params_dict[name]
  430. weight_loader = getattr(param, "weight_loader",
  431. default_weight_loader)
  432. weight_loader(param, loaded_weight)