deepseek.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  24. """Inference-only Deepseek model."""
  25. from typing import Any, Dict, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. import torch.nn.functional as F
  29. from transformers import PretrainedConfig
  30. from aphrodite.modeling.metadata import InputMetadata
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.attention import PagedAttention
  33. from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
  34. from aphrodite.modeling.layers.layernorm import RMSNorm
  35. from aphrodite.modeling.layers.linear import (
  36. LinearMethodBase, MergedColumnParallelLinear, ReplicatedLinear,
  37. QKVParallelLinear, RowParallelLinear, ColumnParallelLinear)
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. VocabParallelEmbedding, ParallelLMHead)
  42. from aphrodite.modeling.megatron.communication_op import (
  43. tensor_model_parallel_all_reduce)
  44. from aphrodite.modeling.megatron.parallel_state import (
  45. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  48. hf_model_weights_iterator)
  49. from aphrodite.common.sequence import SamplerOutput
  50. KVCache = Tuple[torch.Tensor, torch.Tensor]
  51. class DeepseekMLP(nn.Module):
  52. def __init__(
  53. self,
  54. hidden_size: int,
  55. intermediate_size: int,
  56. hidden_act: str,
  57. linear_method: Optional[LinearMethodBase] = None,
  58. reduce_results: bool = True,
  59. ) -> None:
  60. super().__init__()
  61. if linear_method is not None and not linear_method.quant_config.merge_weight(
  62. ):
  63. self.merge_weight = False
  64. self.gate_proj = ColumnParallelLinear(hidden_size,
  65. intermediate_size,
  66. bias=False,
  67. linear_method=linear_method)
  68. self.up_proj = ColumnParallelLinear(hidden_size,
  69. intermediate_size,
  70. bias=False,
  71. linear_method=linear_method)
  72. else:
  73. self.merge_weight = True
  74. self.gate_up_proj = MergedColumnParallelLinear(
  75. hidden_size, [intermediate_size] * 2,
  76. bias=False,
  77. linear_method=linear_method)
  78. self.down_proj = RowParallelLinear(intermediate_size,
  79. hidden_size,
  80. bias=False,
  81. linear_method=linear_method,
  82. reduce_results=reduce_results)
  83. if hidden_act != "silu":
  84. raise ValueError(f"Unsupported activation: {hidden_act}. "
  85. "Only silu is supported for now.")
  86. self.act_fn = SiluAndMul()
  87. def forward(self, x):
  88. if self.merge_weight:
  89. gate_up, _ = self.gate_up_proj(x)
  90. else:
  91. up, _ = self.up_proj(x)
  92. gate, _ = self.gate_proj(x)
  93. gate_up = torch.cat([gate, up], dim=-1)
  94. x = self.act_fn(gate_up)
  95. x, _ = self.down_proj(x)
  96. return x
  97. class DeepseekMoE(nn.Module):
  98. def __init__(
  99. self,
  100. config: PretrainedConfig,
  101. linear_method: Optional[LinearMethodBase] = None,
  102. ):
  103. super().__init__()
  104. self.config = config
  105. self.rank = get_tensor_model_parallel_rank()
  106. self.tp_size = get_tensor_model_parallel_world_size()
  107. self.n_routed_experts = config.n_routed_experts
  108. self.top_k = config.num_experts_per_tok
  109. if self.tp_size > self.n_routed_experts:
  110. raise ValueError(
  111. f"Tensor parallel size {self.tp_size} is greater than "
  112. f"the number of experts {self.n_routed_experts}.")
  113. self.experts = nn.ModuleList([
  114. DeepseekMLP(hidden_size=config.hidden_size,
  115. intermediate_size=config.moe_intermediate_size,
  116. hidden_act=config.hidden_act,
  117. linear_method=linear_method,
  118. reduce_results=False)
  119. for idx in range(self.n_routed_experts)
  120. ])
  121. self.pack_params()
  122. self.gate = ReplicatedLinear(config.hidden_size,
  123. self.n_routed_experts,
  124. bias=False,
  125. linear_method=None)
  126. if config.n_shared_experts is not None:
  127. intermediate_size = config.moe_intermediate_size * config.n_shared_experts
  128. self.shared_experts = DeepseekMLP(
  129. hidden_size=config.hidden_size,
  130. intermediate_size=intermediate_size,
  131. hidden_act=config.hidden_act,
  132. linear_method=linear_method,
  133. reduce_results=False,
  134. )
  135. def pack_params(self):
  136. w1 = []
  137. w2 = []
  138. for expert in self.experts:
  139. w1.append(expert.gate_up_proj.weight)
  140. w2.append(expert.down_proj.weight)
  141. self.w1 = torch._utils._flatten_dense_tensors(w1)
  142. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  143. for data, param in zip(w1s, w1):
  144. param.data = data
  145. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  146. self.w2 = torch._utils._flatten_dense_tensors(w2)
  147. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  148. for data, param in zip(w2s, w2):
  149. param.data = data
  150. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  151. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  152. batch_size, sequence_length, hidden_dim = hidden_states.shape
  153. hidden_states = hidden_states.view(-1, hidden_dim)
  154. if self.config.n_shared_experts is not None:
  155. shared_output = self.shared_experts(hidden_states)
  156. # router_logits: (batch * sequence_length, n_experts)
  157. router_logits, _ = self.gate(hidden_states)
  158. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  159. routing_weights, selected_experts = torch.topk(routing_weights,
  160. self.top_k,
  161. dim=-1)
  162. if self.config.norm_topk_prob:
  163. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  164. final_hidden_states = fused_moe(hidden_states,
  165. self.w1,
  166. self.w2,
  167. routing_weights,
  168. selected_experts,
  169. inplace=True)
  170. if self.config.n_shared_experts is not None:
  171. final_hidden_states = final_hidden_states + shared_output
  172. final_hidden_states = tensor_model_parallel_all_reduce(
  173. final_hidden_states)
  174. return final_hidden_states.view(batch_size, sequence_length,
  175. hidden_dim)
  176. class DeepseekAttention(nn.Module):
  177. def __init__(
  178. self,
  179. hidden_size: int,
  180. num_heads: int,
  181. num_kv_heads: int,
  182. rope_theta: float = 10000,
  183. rope_scaling: Optional[Dict[str, Any]] = None,
  184. max_position_embeddings: int = 8192,
  185. linear_method: Optional[LinearMethodBase] = None,
  186. ) -> None:
  187. super().__init__()
  188. self.hidden_size = hidden_size
  189. tp_size = get_tensor_model_parallel_world_size()
  190. self.total_num_heads = num_heads
  191. assert self.total_num_heads % tp_size == 0
  192. self.num_heads = self.total_num_heads // tp_size
  193. self.total_num_kv_heads = num_kv_heads
  194. if self.total_num_kv_heads >= tp_size:
  195. # Number of KV heads is greater than TP size, so we partition
  196. # the KV heads across multiple tensor parallel GPUs.
  197. assert self.total_num_kv_heads % tp_size == 0
  198. else:
  199. # Number of KV heads is less than TP size, so we replicate
  200. # the KV heads across multiple tensor parallel GPUs.
  201. assert tp_size % self.total_num_kv_heads == 0
  202. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  203. self.head_dim = hidden_size // self.total_num_heads
  204. self.q_size = self.num_heads * self.head_dim
  205. self.kv_size = self.num_kv_heads * self.head_dim
  206. self.scaling = self.head_dim**-0.5
  207. self.rope_theta = rope_theta
  208. self.max_position_embeddings = max_position_embeddings
  209. if linear_method is not None and not linear_method.quant_config.merge_weight(
  210. ):
  211. self.merge_weight = False
  212. self.q_proj = ColumnParallelLinear(hidden_size,
  213. self.q_size,
  214. bias=False,
  215. linear_method=linear_method)
  216. self.k_proj = ColumnParallelLinear(hidden_size,
  217. self.kv_size,
  218. bias=False,
  219. linear_method=linear_method)
  220. self.v_proj = ColumnParallelLinear(hidden_size,
  221. self.kv_size,
  222. bias=False,
  223. linear_method=linear_method)
  224. else:
  225. self.merge_weight = True
  226. self.qkv_proj = QKVParallelLinear(
  227. hidden_size,
  228. self.head_dim,
  229. self.total_num_heads,
  230. self.total_num_kv_heads,
  231. bias=False,
  232. linear_method=linear_method,
  233. )
  234. self.o_proj = RowParallelLinear(
  235. self.total_num_heads * self.head_dim,
  236. hidden_size,
  237. bias=False,
  238. linear_method=linear_method,
  239. )
  240. is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
  241. ) is None else linear_method.quant_config.rope_style()
  242. self.rotary_emb = get_rope(
  243. self.head_dim,
  244. rotary_dim=self.head_dim,
  245. max_position=max_position_embeddings,
  246. base=rope_theta,
  247. rope_scaling=rope_scaling,
  248. is_neox_style=is_neox_style,
  249. )
  250. self.attn = PagedAttention(self.num_heads,
  251. self.head_dim,
  252. self.scaling,
  253. num_kv_heads=self.num_kv_heads)
  254. def forward(
  255. self,
  256. positions: torch.Tensor,
  257. hidden_states: torch.Tensor,
  258. kv_cache: KVCache,
  259. input_metadata: InputMetadata,
  260. ) -> torch.Tensor:
  261. if self.merge_weight:
  262. qkv, _ = self.qkv_proj(hidden_states)
  263. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  264. dim=-1)
  265. else:
  266. q, _ = self.q_proj(hidden_states)
  267. k, _ = self.k_proj(hidden_states)
  268. v, _ = self.v_proj(hidden_states)
  269. q, k = self.rotary_emb(positions, q, k)
  270. k_cache, v_cache = kv_cache
  271. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  272. output, _ = self.o_proj(attn_output)
  273. return output
  274. class DeepseekDecoderLayer(nn.Module):
  275. def __init__(
  276. self,
  277. config: PretrainedConfig,
  278. layer_idx: int,
  279. linear_method: Optional[LinearMethodBase] = None,
  280. ) -> None:
  281. super().__init__()
  282. self.hidden_size = config.hidden_size
  283. rope_theta = getattr(config, "rope_theta", 10000)
  284. rope_scaling = getattr(config, "rope_scaling", None)
  285. max_position_embeddings = getattr(config, "max_position_embeddings",
  286. 8192)
  287. self.self_attn = DeepseekAttention(
  288. hidden_size=self.hidden_size,
  289. num_heads=config.num_attention_heads,
  290. num_kv_heads=config.num_key_value_heads,
  291. rope_theta=rope_theta,
  292. rope_scaling=rope_scaling,
  293. max_position_embeddings=max_position_embeddings,
  294. linear_method=linear_method,
  295. )
  296. if (config.n_routed_experts is not None and \
  297. layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
  298. self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
  299. else:
  300. self.mlp = DeepseekMLP(
  301. hidden_size=config.hidden_size,
  302. intermediate_size=config.intermediate_size,
  303. hidden_act=config.hidden_act,
  304. linear_method=linear_method,
  305. )
  306. self.input_layernorm = RMSNorm(config.hidden_size,
  307. eps=config.rms_norm_eps)
  308. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  309. eps=config.rms_norm_eps)
  310. def forward(
  311. self,
  312. positions: torch.Tensor,
  313. hidden_states: torch.Tensor,
  314. kv_cache: KVCache,
  315. input_metadata: InputMetadata,
  316. residual: Optional[torch.Tensor],
  317. ) -> torch.Tensor:
  318. # Self Attention
  319. if residual is None:
  320. residual = hidden_states
  321. hidden_states = self.input_layernorm(hidden_states)
  322. else:
  323. hidden_states, residual = self.input_layernorm(
  324. hidden_states, residual)
  325. hidden_states = self.self_attn(
  326. positions=positions,
  327. hidden_states=hidden_states,
  328. kv_cache=kv_cache,
  329. input_metadata=input_metadata,
  330. )
  331. # Fully Connected
  332. hidden_states, residual = self.post_attention_layernorm(
  333. hidden_states, residual)
  334. hidden_states = self.mlp(hidden_states)
  335. return hidden_states, residual
  336. class DeepseekModel(nn.Module):
  337. def __init__(
  338. self,
  339. config: PretrainedConfig,
  340. linear_method: Optional[LinearMethodBase] = None,
  341. ) -> None:
  342. super().__init__()
  343. self.padding_idx = config.pad_token_id
  344. self.vocab_size = config.vocab_size
  345. self.embed_tokens = VocabParallelEmbedding(
  346. config.vocab_size,
  347. config.hidden_size,
  348. linear_method=linear_method,
  349. )
  350. self.layers = nn.ModuleList([
  351. DeepseekDecoderLayer(config,
  352. layer_idx,
  353. linear_method=linear_method)
  354. for layer_idx in range(config.num_hidden_layers)
  355. ])
  356. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  357. def forward(
  358. self,
  359. input_ids: torch.Tensor,
  360. positions: torch.Tensor,
  361. kv_caches: List[KVCache],
  362. input_metadata: InputMetadata,
  363. ) -> torch.Tensor:
  364. hidden_states = self.embed_tokens(input_ids)
  365. residual = None
  366. for i in range(len(self.layers)):
  367. layer = self.layers[i]
  368. hidden_states, residual = layer(positions, hidden_states,
  369. kv_caches[i], input_metadata,
  370. residual)
  371. hidden_states, _ = self.norm(hidden_states, residual)
  372. return hidden_states
  373. class DeepseekForCausalLM(nn.Module):
  374. def __init__(
  375. self,
  376. config: PretrainedConfig,
  377. linear_method: Optional[LinearMethodBase] = None,
  378. ) -> None:
  379. super().__init__()
  380. self.config = config
  381. self.linear_method = linear_method
  382. self.model = DeepseekModel(config, linear_method)
  383. self.lm_head = ParallelLMHead(config.vocab_size,
  384. config.hidden_size,
  385. linear_method=linear_method)
  386. self.sampler = Sampler(config.vocab_size)
  387. def forward(
  388. self,
  389. input_ids: torch.Tensor,
  390. positions: torch.Tensor,
  391. kv_caches: List[KVCache],
  392. input_metadata: InputMetadata,
  393. ) -> torch.Tensor:
  394. hidden_states = self.model(input_ids, positions, kv_caches,
  395. input_metadata)
  396. return hidden_states
  397. def sample(
  398. self,
  399. hidden_states: Optional[torch.Tensor],
  400. sampling_metadata: SamplingMetadata,
  401. ) -> Optional[SamplerOutput]:
  402. next_tokens = self.sampler(self.lm_head(hidden_states),
  403. sampling_metadata)
  404. return next_tokens
  405. def load_weights(self,
  406. model_name_or_path: str,
  407. cache_dir: Optional[str] = None,
  408. load_format: str = "auto",
  409. revision: Optional[str] = None):
  410. stacked_params_mapping = [
  411. # (param_name, shard_name, shard_id)
  412. ("qkv_proj", "q_proj", "q"),
  413. ("qkv_proj", "k_proj", "k"),
  414. ("qkv_proj", "v_proj", "v"),
  415. ("gate_up_proj", "gate_proj", 0),
  416. ("gate_up_proj", "up_proj", 1),
  417. ]
  418. if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
  419. ):
  420. stacked_params_mapping = []
  421. params_dict = dict(self.named_parameters())
  422. for name, loaded_weight in hf_model_weights_iterator(
  423. model_name_or_path,
  424. cache_dir,
  425. load_format,
  426. revision,
  427. fall_back_to_pt=False):
  428. if "rotary_emb.inv_freq" in name:
  429. continue
  430. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  431. if weight_name not in name:
  432. continue
  433. name = name.replace(weight_name, param_name)
  434. # Skip loading extra bias for GPTQ models.
  435. if name.endswith(".bias") and name not in params_dict:
  436. continue
  437. # Skip experts that are not assigned to this worker.
  438. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  439. and name not in params_dict):
  440. continue
  441. param = params_dict[name]
  442. weight_loader = param.weight_loader
  443. weight_loader(param, loaded_weight, shard_id)
  444. break
  445. else:
  446. # Skip loading extra bias for GPTQ models.
  447. if name.endswith(".bias") and name not in params_dict:
  448. continue
  449. # Skip experts that are not assigned to this worker.
  450. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  451. and name not in params_dict):
  452. continue
  453. param = params_dict[name]
  454. weight_loader = getattr(param, "weight_loader",
  455. default_weight_loader)
  456. weight_loader(param, loaded_weight)