deepseek_v2.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only DeepseekV2 model."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import PretrainedConfig
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import fused_experts, grouped_topk
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  39. MergedColumnParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class DeepseekV2MLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. reduce_results: bool = True,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. hidden_size, [intermediate_size] * 2,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(intermediate_size,
  65. hidden_size,
  66. bias=False,
  67. quant_config=quant_config,
  68. reduce_results=reduce_results)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class DeepseekV2MoE(nn.Module):
  79. def __init__(
  80. self,
  81. config: PretrainedConfig,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.config = config
  86. self.rank = get_tensor_model_parallel_rank()
  87. self.tp_size = get_tensor_model_parallel_world_size()
  88. self.n_routed_experts = config.n_routed_experts
  89. self.top_k = config.num_experts_per_tok
  90. self.routed_scaling_factor = config.routed_scaling_factor
  91. if self.tp_size > self.n_routed_experts:
  92. raise ValueError(
  93. f"Tensor parallel size {self.tp_size} is greater than "
  94. f"the number of experts {self.n_routed_experts}.")
  95. self.experts = nn.ModuleList([
  96. DeepseekV2MLP(hidden_size=config.hidden_size,
  97. intermediate_size=config.moe_intermediate_size,
  98. hidden_act=config.hidden_act,
  99. quant_config=quant_config,
  100. reduce_results=False)
  101. for idx in range(self.n_routed_experts)
  102. ])
  103. self.pack_params()
  104. self.gate = ReplicatedLinear(config.hidden_size,
  105. self.n_routed_experts,
  106. bias=False,
  107. quant_config=None)
  108. if config.n_shared_experts is not None:
  109. intermediate_size = (config.moe_intermediate_size *
  110. config.n_shared_experts)
  111. self.shared_experts = DeepseekV2MLP(
  112. hidden_size=config.hidden_size,
  113. intermediate_size=intermediate_size,
  114. hidden_act=config.hidden_act,
  115. quant_config=quant_config,
  116. reduce_results=False,
  117. )
  118. def pack_params(self):
  119. w1 = []
  120. w2 = []
  121. for expert in self.experts:
  122. w1.append(expert.gate_up_proj.weight)
  123. w2.append(expert.down_proj.weight)
  124. self.w1 = torch._utils._flatten_dense_tensors(w1)
  125. w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
  126. for data, param in zip(w1s, w1):
  127. param.data = data
  128. self.w1 = self.w1.view(len(w1), *w1s[0].shape)
  129. self.w2 = torch._utils._flatten_dense_tensors(w2)
  130. w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
  131. for data, param in zip(w2s, w2):
  132. param.data = data
  133. self.w2 = self.w2.view(len(w2), *w2s[0].shape)
  134. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  135. num_tokens, hidden_dim = hidden_states.shape
  136. hidden_states = hidden_states.view(-1, hidden_dim)
  137. if self.config.n_shared_experts is not None:
  138. shared_output = self.shared_experts(hidden_states)
  139. # router_logits: (num_tokens, n_experts)
  140. router_logits, _ = self.gate(hidden_states)
  141. topk_weights, topk_ids = grouped_topk(
  142. hidden_states,
  143. router_logits,
  144. self.top_k,
  145. renormalize=self.config.norm_topk_prob,
  146. num_expert_group=self.config.n_group,
  147. topk_group=self.config.topk_group)
  148. final_hidden_states = fused_experts(
  149. hidden_states,
  150. self.w1,
  151. self.w2,
  152. topk_weights,
  153. topk_ids,
  154. inplace=True) * self.routed_scaling_factor
  155. if self.config.n_shared_experts is not None:
  156. final_hidden_states = final_hidden_states + shared_output
  157. final_hidden_states = tensor_model_parallel_all_reduce(
  158. final_hidden_states)
  159. return final_hidden_states.view(num_tokens, hidden_dim)
  160. def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  161. import math
  162. if scale <= 1:
  163. return 1.0
  164. return 0.1 * mscale * math.log(scale) + 1.0
  165. class DeepseekV2Attention(nn.Module):
  166. def __init__(
  167. self,
  168. config: PretrainedConfig,
  169. hidden_size: int,
  170. num_heads: int,
  171. qk_nope_head_dim: int,
  172. qk_rope_head_dim: int,
  173. v_head_dim: int,
  174. q_lora_rank: int,
  175. kv_lora_rank: int,
  176. rope_theta: float = 10000,
  177. rope_scaling: Optional[Dict[str, Any]] = None,
  178. max_position_embeddings: int = 8192,
  179. cache_config: Optional[CacheConfig] = None,
  180. quant_config: Optional[QuantizationConfig] = None,
  181. layer_idx=None,
  182. ) -> None:
  183. super().__init__()
  184. self.layer_idx = layer_idx
  185. self.hidden_size = hidden_size
  186. self.qk_nope_head_dim = qk_nope_head_dim
  187. self.qk_rope_head_dim = qk_rope_head_dim
  188. self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
  189. self.v_head_dim = v_head_dim
  190. self.q_lora_rank = q_lora_rank
  191. self.kv_lora_rank = kv_lora_rank
  192. self.num_heads = num_heads
  193. tp_size = get_tensor_model_parallel_world_size()
  194. assert num_heads % tp_size == 0
  195. self.num_local_heads = num_heads // tp_size
  196. self.scaling = self.qk_head_dim**-0.5
  197. self.rope_theta = rope_theta
  198. self.max_position_embeddings = max_position_embeddings
  199. if self.q_lora_rank is not None:
  200. self.q_a_proj = ReplicatedLinear(self.hidden_size,
  201. self.q_lora_rank,
  202. bias=False,
  203. quant_config=quant_config)
  204. self.q_a_layernorm = RMSNorm(self.q_lora_rank,
  205. eps=config.rms_norm_eps)
  206. self.q_b_proj = ColumnParallelLinear(q_lora_rank,
  207. self.num_heads *
  208. self.qk_head_dim,
  209. bias=False,
  210. quant_config=quant_config)
  211. else:
  212. self.q_proj = ColumnParallelLinear(self.hidden_size,
  213. self.num_heads *
  214. self.qk_head_dim,
  215. bias=False,
  216. quant_config=quant_config)
  217. self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
  218. self.kv_lora_rank +
  219. self.qk_rope_head_dim,
  220. bias=False,
  221. quant_config=quant_config)
  222. self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
  223. eps=config.rms_norm_eps)
  224. self.kv_b_proj = ColumnParallelLinear(
  225. self.kv_lora_rank,
  226. self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  227. bias=False,
  228. quant_config=quant_config)
  229. # O projection.
  230. self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
  231. self.hidden_size,
  232. bias=False,
  233. quant_config=quant_config)
  234. rope_scaling['type'] = 'deepseek_yarn'
  235. self.rotary_emb = get_rope(qk_rope_head_dim,
  236. rotary_dim=qk_rope_head_dim,
  237. max_position=max_position_embeddings,
  238. base=rope_theta,
  239. rope_scaling=rope_scaling,
  240. is_neox_style=False)
  241. if rope_scaling:
  242. mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
  243. scaling_factor = rope_scaling["factor"]
  244. mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
  245. self.scaling = self.scaling * mscale * mscale
  246. # self.attn = Attention(self.num_heads,
  247. # self.qk_head_dim,
  248. # self.scaling,
  249. # num_kv_heads=self.num_heads)
  250. # TODO, support head_size 192
  251. self.attn = Attention(self.num_local_heads,
  252. 256,
  253. self.scaling,
  254. num_kv_heads=self.num_local_heads,
  255. cache_config=cache_config,
  256. quant_config=quant_config)
  257. def forward(
  258. self,
  259. positions: torch.Tensor,
  260. hidden_states: torch.Tensor,
  261. kv_cache: torch.Tensor,
  262. attn_metadata: AttentionMetadata,
  263. ) -> torch.Tensor:
  264. if self.q_lora_rank is not None:
  265. q = self.q_a_proj(hidden_states)[0]
  266. q = self.q_a_layernorm(q)
  267. q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
  268. self.qk_head_dim)
  269. else:
  270. q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
  271. self.qk_head_dim)
  272. q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
  273. dim=-1)
  274. latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
  275. kv_a, _ = latent_cache.split(
  276. [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  277. latent_cache = latent_cache.unsqueeze(1)
  278. kv_a = self.kv_a_layernorm(kv_a.contiguous())
  279. kv = self.kv_b_proj(kv_a)[0]
  280. kv = kv.view(-1, self.num_local_heads,
  281. self.qk_nope_head_dim + self.v_head_dim)
  282. k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
  283. k_pe = latent_cache[:, :, self.kv_lora_rank:]
  284. q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
  285. q[..., self.qk_nope_head_dim:] = q_pe
  286. k = torch.empty_like(q)
  287. k[..., :self.qk_nope_head_dim] = k_nope
  288. k[..., self.qk_nope_head_dim:] = k_pe
  289. q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
  290. value=0).view(-1,
  291. self.num_local_heads * 256)
  292. k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
  293. value=0).view(-1,
  294. self.num_local_heads * 256)
  295. v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
  296. value=0).view(-1,
  297. self.num_local_heads * 256)
  298. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  299. attn_output = attn_output.view(
  300. -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
  301. -1, self.num_local_heads * self.v_head_dim)
  302. output, _ = self.o_proj(attn_output)
  303. return output
  304. class DeepseekV2DecoderLayer(nn.Module):
  305. def __init__(
  306. self,
  307. config: PretrainedConfig,
  308. layer_idx: int,
  309. cache_config: Optional[CacheConfig] = None,
  310. quant_config: Optional[QuantizationConfig] = None,
  311. ) -> None:
  312. super().__init__()
  313. self.hidden_size = config.hidden_size
  314. rope_theta = getattr(config, "rope_theta", 10000)
  315. rope_scaling = getattr(config, "rope_scaling", None)
  316. max_position_embeddings = getattr(config, "max_position_embeddings",
  317. 8192)
  318. self.self_attn = DeepseekV2Attention(
  319. config=config,
  320. hidden_size=self.hidden_size,
  321. num_heads=config.num_attention_heads,
  322. qk_nope_head_dim=config.qk_nope_head_dim,
  323. qk_rope_head_dim=config.qk_rope_head_dim,
  324. v_head_dim=config.v_head_dim,
  325. q_lora_rank=config.q_lora_rank
  326. if hasattr(config, "q_lora_rank") else None,
  327. kv_lora_rank=config.kv_lora_rank,
  328. rope_theta=rope_theta,
  329. rope_scaling=rope_scaling,
  330. max_position_embeddings=max_position_embeddings,
  331. cache_config=cache_config,
  332. quant_config=quant_config,
  333. layer_idx=layer_idx,
  334. )
  335. if (config.n_routed_experts is not None
  336. and layer_idx >= config.first_k_dense_replace
  337. and layer_idx % config.moe_layer_freq == 0):
  338. self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
  339. else:
  340. self.mlp = DeepseekV2MLP(
  341. hidden_size=config.hidden_size,
  342. intermediate_size=config.intermediate_size,
  343. hidden_act=config.hidden_act,
  344. quant_config=quant_config,
  345. )
  346. self.input_layernorm = RMSNorm(config.hidden_size,
  347. eps=config.rms_norm_eps)
  348. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  349. eps=config.rms_norm_eps)
  350. def forward(
  351. self,
  352. positions: torch.Tensor,
  353. hidden_states: torch.Tensor,
  354. kv_cache: torch.Tensor,
  355. attn_metadata: AttentionMetadata,
  356. residual: Optional[torch.Tensor],
  357. ) -> torch.Tensor:
  358. # Self Attention
  359. if residual is None:
  360. residual = hidden_states
  361. hidden_states = self.input_layernorm(hidden_states)
  362. else:
  363. hidden_states, residual = self.input_layernorm(
  364. hidden_states, residual)
  365. hidden_states = self.self_attn(
  366. positions=positions,
  367. hidden_states=hidden_states,
  368. kv_cache=kv_cache,
  369. attn_metadata=attn_metadata,
  370. )
  371. # Fully Connected
  372. hidden_states, residual = self.post_attention_layernorm(
  373. hidden_states, residual)
  374. hidden_states = self.mlp(hidden_states)
  375. return hidden_states, residual
  376. class DeepseekV2Model(nn.Module):
  377. fall_back_to_pt_during_load = False
  378. def __init__(
  379. self,
  380. config: PretrainedConfig,
  381. cache_config: Optional[CacheConfig] = None,
  382. quant_config: Optional[QuantizationConfig] = None,
  383. ) -> None:
  384. super().__init__()
  385. self.padding_idx = config.pad_token_id
  386. self.vocab_size = config.vocab_size
  387. self.embed_tokens = VocabParallelEmbedding(
  388. config.vocab_size,
  389. config.hidden_size,
  390. )
  391. self.layers = nn.ModuleList([
  392. DeepseekV2DecoderLayer(config,
  393. layer_idx,
  394. cache_config=cache_config,
  395. quant_config=quant_config)
  396. for layer_idx in range(config.num_hidden_layers)
  397. ])
  398. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  399. def forward(
  400. self,
  401. input_ids: torch.Tensor,
  402. positions: torch.Tensor,
  403. kv_caches: List[torch.Tensor],
  404. attn_metadata: AttentionMetadata,
  405. ) -> torch.Tensor:
  406. hidden_states = self.embed_tokens(input_ids)
  407. residual = None
  408. for i in range(len(self.layers)):
  409. layer = self.layers[i]
  410. hidden_states, residual = layer(positions, hidden_states,
  411. kv_caches[i], attn_metadata,
  412. residual)
  413. hidden_states, _ = self.norm(hidden_states, residual)
  414. return hidden_states
  415. class DeepseekV2ForCausalLM(nn.Module):
  416. def __init__(
  417. self,
  418. config: PretrainedConfig,
  419. cache_config: Optional[CacheConfig] = None,
  420. quant_config: Optional[QuantizationConfig] = None,
  421. ) -> None:
  422. super().__init__()
  423. self.config = config
  424. self.quant_config = quant_config
  425. self.model = DeepseekV2Model(config, cache_config, quant_config)
  426. self.lm_head = ParallelLMHead(config.vocab_size,
  427. config.hidden_size,
  428. quant_config=quant_config)
  429. self.logits_processor = LogitsProcessor(config.vocab_size)
  430. self.sampler = Sampler()
  431. def forward(
  432. self,
  433. input_ids: torch.Tensor,
  434. positions: torch.Tensor,
  435. kv_caches: List[torch.Tensor],
  436. attn_metadata: AttentionMetadata,
  437. intermediate_tensors: Optional[IntermediateTensors] = None,
  438. ) -> torch.Tensor:
  439. hidden_states = self.model(input_ids, positions, kv_caches,
  440. attn_metadata)
  441. return hidden_states
  442. def compute_logits(self, hidden_states: torch.Tensor,
  443. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  444. logits = self.logits_processor(self.lm_head, hidden_states,
  445. sampling_metadata)
  446. return logits
  447. def sample(
  448. self,
  449. logits: Optional[torch.Tensor],
  450. sampling_metadata: SamplingMetadata,
  451. ) -> Optional[SamplerOutput]:
  452. next_tokens = self.sampler(logits, sampling_metadata)
  453. return next_tokens
  454. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  455. stacked_params_mapping = [
  456. # (param_name, shard_name, shard_id)
  457. ("gate_up_proj", "gate_proj", 0),
  458. ("gate_up_proj", "up_proj", 1),
  459. ]
  460. params_dict = dict(self.named_parameters())
  461. for name, loaded_weight in weights:
  462. if "rotary_emb.inv_freq" in name:
  463. continue
  464. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  465. if weight_name not in name:
  466. continue
  467. name = name.replace(weight_name, param_name)
  468. # Skip loading extra bias for GPTQ models.
  469. if name.endswith(".bias") and name not in params_dict:
  470. continue
  471. # Skip experts that are not assigned to this worker.
  472. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  473. and name not in params_dict):
  474. continue
  475. param = params_dict[name]
  476. weight_loader = param.weight_loader
  477. weight_loader(param, loaded_weight, shard_id)
  478. break
  479. else:
  480. # Skip loading extra bias for GPTQ models.
  481. if name.endswith(".bias") and name not in params_dict:
  482. continue
  483. # Skip experts that are not assigned to this worker.
  484. if (("mlp.experts." in name or "mlp.shared_experts." in name)
  485. and name not in params_dict):
  486. continue
  487. param = params_dict[name]
  488. weight_loader = getattr(param, "weight_loader",
  489. default_weight_loader)
  490. weight_loader(param, loaded_weight)