1
0

deepseek_v2.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only DeepseekV2 model."""
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import PretrainedConfig
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.common.utils import progress_bar
  33. from aphrodite.distributed import (get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import FusedMoE
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  39. MergedColumnParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class DeepseekV2MLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. reduce_results: bool = True,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. hidden_size, [intermediate_size] * 2,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(intermediate_size,
  65. hidden_size,
  66. bias=False,
  67. quant_config=quant_config,
  68. reduce_results=reduce_results)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class DeepseekV2MoE(nn.Module):
  79. def __init__(
  80. self,
  81. config: PretrainedConfig,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.tp_size = get_tensor_model_parallel_world_size()
  86. self.routed_scaling_factor = config.routed_scaling_factor
  87. self.n_shared_experts = config.n_shared_experts
  88. self.routed_scaling_factor = config.routed_scaling_factor
  89. if self.tp_size > config.n_routed_experts:
  90. raise ValueError(
  91. f"Tensor parallel size {self.tp_size} is greater than "
  92. f"the number of experts {config.n_routed_experts}.")
  93. if config.hidden_act != "silu":
  94. raise ValueError(f"Unsupported activation: {config.hidden_act}. "
  95. "Only silu is supported for now.")
  96. self.experts = FusedMoE(num_experts=config.n_routed_experts,
  97. top_k=config.num_experts_per_tok,
  98. hidden_size=config.hidden_size,
  99. intermediate_size=config.moe_intermediate_size,
  100. reduce_results=False,
  101. renormalize=config.norm_topk_prob,
  102. quant_config=quant_config,
  103. use_grouped_topk=True,
  104. num_expert_group=config.n_group,
  105. topk_group=config.topk_group)
  106. self.gate = ReplicatedLinear(config.hidden_size,
  107. config.n_routed_experts,
  108. bias=False,
  109. quant_config=None)
  110. if config.n_shared_experts is not None:
  111. intermediate_size = (config.moe_intermediate_size *
  112. config.n_shared_experts)
  113. self.shared_experts = DeepseekV2MLP(
  114. hidden_size=config.hidden_size,
  115. intermediate_size=intermediate_size,
  116. hidden_act=config.hidden_act,
  117. quant_config=quant_config,
  118. reduce_results=False,
  119. )
  120. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  121. num_tokens, hidden_dim = hidden_states.shape
  122. hidden_states = hidden_states.view(-1, hidden_dim)
  123. if self.n_shared_experts is not None:
  124. shared_output = self.shared_experts(hidden_states)
  125. # router_logits: (num_tokens, n_experts)
  126. router_logits, _ = self.gate(hidden_states)
  127. final_hidden_states = self.experts(
  128. hidden_states=hidden_states,
  129. router_logits=router_logits) * self.routed_scaling_factor
  130. if shared_output is not None:
  131. final_hidden_states = final_hidden_states + shared_output
  132. if self.tp_size > 1:
  133. final_hidden_states = tensor_model_parallel_all_reduce(
  134. final_hidden_states)
  135. return final_hidden_states.view(num_tokens, hidden_dim)
  136. def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  137. import math
  138. if scale <= 1:
  139. return 1.0
  140. return 0.1 * mscale * math.log(scale) + 1.0
  141. class DeepseekV2Attention(nn.Module):
  142. def __init__(
  143. self,
  144. config: PretrainedConfig,
  145. hidden_size: int,
  146. num_heads: int,
  147. qk_nope_head_dim: int,
  148. qk_rope_head_dim: int,
  149. v_head_dim: int,
  150. q_lora_rank: int,
  151. kv_lora_rank: int,
  152. rope_theta: float = 10000,
  153. rope_scaling: Optional[Dict[str, Any]] = None,
  154. max_position_embeddings: int = 8192,
  155. cache_config: Optional[CacheConfig] = None,
  156. quant_config: Optional[QuantizationConfig] = None,
  157. layer_idx=None,
  158. ) -> None:
  159. super().__init__()
  160. self.layer_idx = layer_idx
  161. self.hidden_size = hidden_size
  162. self.qk_nope_head_dim = qk_nope_head_dim
  163. self.qk_rope_head_dim = qk_rope_head_dim
  164. self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
  165. self.v_head_dim = v_head_dim
  166. self.q_lora_rank = q_lora_rank
  167. self.kv_lora_rank = kv_lora_rank
  168. self.num_heads = num_heads
  169. tp_size = get_tensor_model_parallel_world_size()
  170. assert num_heads % tp_size == 0
  171. self.num_local_heads = num_heads // tp_size
  172. self.scaling = self.qk_head_dim**-0.5
  173. self.rope_theta = rope_theta
  174. self.max_position_embeddings = max_position_embeddings
  175. if self.q_lora_rank is not None:
  176. self.q_a_proj = ReplicatedLinear(self.hidden_size,
  177. self.q_lora_rank,
  178. bias=False,
  179. quant_config=quant_config)
  180. self.q_a_layernorm = RMSNorm(self.q_lora_rank,
  181. eps=config.rms_norm_eps)
  182. self.q_b_proj = ColumnParallelLinear(q_lora_rank,
  183. self.num_heads *
  184. self.qk_head_dim,
  185. bias=False,
  186. quant_config=quant_config)
  187. else:
  188. self.q_proj = ColumnParallelLinear(self.hidden_size,
  189. self.num_heads *
  190. self.qk_head_dim,
  191. bias=False,
  192. quant_config=quant_config)
  193. self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
  194. self.kv_lora_rank +
  195. self.qk_rope_head_dim,
  196. bias=False,
  197. quant_config=quant_config)
  198. self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
  199. eps=config.rms_norm_eps)
  200. self.kv_b_proj = ColumnParallelLinear(
  201. self.kv_lora_rank,
  202. self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  203. bias=False,
  204. quant_config=quant_config)
  205. # O projection.
  206. self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
  207. self.hidden_size,
  208. bias=False,
  209. quant_config=quant_config)
  210. rope_scaling['type'] = 'deepseek_yarn'
  211. self.rotary_emb = get_rope(qk_rope_head_dim,
  212. rotary_dim=qk_rope_head_dim,
  213. max_position=max_position_embeddings,
  214. base=rope_theta,
  215. rope_scaling=rope_scaling,
  216. is_neox_style=False)
  217. if rope_scaling:
  218. mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
  219. scaling_factor = rope_scaling["factor"]
  220. mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
  221. self.scaling = self.scaling * mscale * mscale
  222. # self.attn = Attention(self.num_heads,
  223. # self.qk_head_dim,
  224. # self.scaling,
  225. # num_kv_heads=self.num_heads)
  226. # TODO, support head_size 192
  227. self.attn = Attention(self.num_local_heads,
  228. 256,
  229. self.scaling,
  230. num_kv_heads=self.num_local_heads,
  231. cache_config=cache_config,
  232. quant_config=quant_config)
  233. def forward(
  234. self,
  235. positions: torch.Tensor,
  236. hidden_states: torch.Tensor,
  237. kv_cache: torch.Tensor,
  238. attn_metadata: AttentionMetadata,
  239. ) -> torch.Tensor:
  240. if self.q_lora_rank is not None:
  241. q = self.q_a_proj(hidden_states)[0]
  242. q = self.q_a_layernorm(q)
  243. q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
  244. self.qk_head_dim)
  245. else:
  246. q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
  247. self.qk_head_dim)
  248. q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
  249. dim=-1)
  250. latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
  251. kv_a, _ = latent_cache.split(
  252. [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  253. latent_cache = latent_cache.unsqueeze(1)
  254. kv_a = self.kv_a_layernorm(kv_a.contiguous())
  255. kv = self.kv_b_proj(kv_a)[0]
  256. kv = kv.view(-1, self.num_local_heads,
  257. self.qk_nope_head_dim + self.v_head_dim)
  258. k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
  259. k_pe = latent_cache[:, :, self.kv_lora_rank:]
  260. q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
  261. q[..., self.qk_nope_head_dim:] = q_pe
  262. k = torch.empty_like(q)
  263. k[..., :self.qk_nope_head_dim] = k_nope
  264. k[..., self.qk_nope_head_dim:] = k_pe
  265. q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
  266. value=0).view(-1,
  267. self.num_local_heads * 256)
  268. k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
  269. value=0).view(-1,
  270. self.num_local_heads * 256)
  271. v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
  272. value=0).view(-1,
  273. self.num_local_heads * 256)
  274. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  275. attn_output = attn_output.view(
  276. -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
  277. -1, self.num_local_heads * self.v_head_dim)
  278. output, _ = self.o_proj(attn_output)
  279. return output
  280. class DeepseekV2DecoderLayer(nn.Module):
  281. def __init__(
  282. self,
  283. config: PretrainedConfig,
  284. layer_idx: int,
  285. cache_config: Optional[CacheConfig] = None,
  286. quant_config: Optional[QuantizationConfig] = None,
  287. ) -> None:
  288. super().__init__()
  289. self.hidden_size = config.hidden_size
  290. rope_theta = getattr(config, "rope_theta", 10000)
  291. rope_scaling = getattr(config, "rope_scaling", None)
  292. max_position_embeddings = getattr(config, "max_position_embeddings",
  293. 8192)
  294. self.self_attn = DeepseekV2Attention(
  295. config=config,
  296. hidden_size=self.hidden_size,
  297. num_heads=config.num_attention_heads,
  298. qk_nope_head_dim=config.qk_nope_head_dim,
  299. qk_rope_head_dim=config.qk_rope_head_dim,
  300. v_head_dim=config.v_head_dim,
  301. q_lora_rank=config.q_lora_rank
  302. if hasattr(config, "q_lora_rank") else None,
  303. kv_lora_rank=config.kv_lora_rank,
  304. rope_theta=rope_theta,
  305. rope_scaling=rope_scaling,
  306. max_position_embeddings=max_position_embeddings,
  307. cache_config=cache_config,
  308. quant_config=quant_config,
  309. layer_idx=layer_idx,
  310. )
  311. if (config.n_routed_experts is not None
  312. and layer_idx >= config.first_k_dense_replace
  313. and layer_idx % config.moe_layer_freq == 0):
  314. self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
  315. else:
  316. self.mlp = DeepseekV2MLP(
  317. hidden_size=config.hidden_size,
  318. intermediate_size=config.intermediate_size,
  319. hidden_act=config.hidden_act,
  320. quant_config=quant_config,
  321. )
  322. self.input_layernorm = RMSNorm(config.hidden_size,
  323. eps=config.rms_norm_eps)
  324. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  325. eps=config.rms_norm_eps)
  326. def forward(
  327. self,
  328. positions: torch.Tensor,
  329. hidden_states: torch.Tensor,
  330. kv_cache: torch.Tensor,
  331. attn_metadata: AttentionMetadata,
  332. residual: Optional[torch.Tensor],
  333. ) -> torch.Tensor:
  334. # Self Attention
  335. if residual is None:
  336. residual = hidden_states
  337. hidden_states = self.input_layernorm(hidden_states)
  338. else:
  339. hidden_states, residual = self.input_layernorm(
  340. hidden_states, residual)
  341. hidden_states = self.self_attn(
  342. positions=positions,
  343. hidden_states=hidden_states,
  344. kv_cache=kv_cache,
  345. attn_metadata=attn_metadata,
  346. )
  347. # Fully Connected
  348. hidden_states, residual = self.post_attention_layernorm(
  349. hidden_states, residual)
  350. hidden_states = self.mlp(hidden_states)
  351. return hidden_states, residual
  352. class DeepseekV2Model(nn.Module):
  353. fall_back_to_pt_during_load = False
  354. def __init__(
  355. self,
  356. config: PretrainedConfig,
  357. cache_config: Optional[CacheConfig] = None,
  358. quant_config: Optional[QuantizationConfig] = None,
  359. ) -> None:
  360. super().__init__()
  361. self.padding_idx = config.pad_token_id
  362. self.vocab_size = config.vocab_size
  363. self.embed_tokens = VocabParallelEmbedding(
  364. config.vocab_size,
  365. config.hidden_size,
  366. )
  367. self.layers = nn.ModuleList([
  368. DeepseekV2DecoderLayer(config,
  369. layer_idx,
  370. cache_config=cache_config,
  371. quant_config=quant_config)
  372. for layer_idx in range(config.num_hidden_layers)
  373. ])
  374. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  375. def forward(
  376. self,
  377. input_ids: torch.Tensor,
  378. positions: torch.Tensor,
  379. kv_caches: List[torch.Tensor],
  380. attn_metadata: AttentionMetadata,
  381. ) -> torch.Tensor:
  382. hidden_states = self.embed_tokens(input_ids)
  383. residual = None
  384. for i in range(len(self.layers)):
  385. layer = self.layers[i]
  386. hidden_states, residual = layer(positions, hidden_states,
  387. kv_caches[i], attn_metadata,
  388. residual)
  389. hidden_states, _ = self.norm(hidden_states, residual)
  390. return hidden_states
  391. class DeepseekV2ForCausalLM(nn.Module):
  392. def __init__(
  393. self,
  394. config: PretrainedConfig,
  395. cache_config: Optional[CacheConfig] = None,
  396. quant_config: Optional[QuantizationConfig] = None,
  397. ) -> None:
  398. super().__init__()
  399. self.config = config
  400. self.quant_config = quant_config
  401. self.model = DeepseekV2Model(config, cache_config, quant_config)
  402. self.lm_head = ParallelLMHead(config.vocab_size,
  403. config.hidden_size,
  404. quant_config=quant_config)
  405. self.logits_processor = LogitsProcessor(config.vocab_size)
  406. self.sampler = Sampler()
  407. def forward(
  408. self,
  409. input_ids: torch.Tensor,
  410. positions: torch.Tensor,
  411. kv_caches: List[torch.Tensor],
  412. attn_metadata: AttentionMetadata,
  413. intermediate_tensors: Optional[IntermediateTensors] = None,
  414. ) -> torch.Tensor:
  415. hidden_states = self.model(input_ids, positions, kv_caches,
  416. attn_metadata)
  417. return hidden_states
  418. def compute_logits(
  419. self,
  420. hidden_states: torch.Tensor,
  421. sampling_metadata: SamplingMetadata,
  422. ) -> Optional[torch.Tensor]:
  423. logits = self.logits_processor(self.lm_head, hidden_states,
  424. sampling_metadata)
  425. return logits
  426. def sample(
  427. self,
  428. logits: Optional[torch.Tensor],
  429. sampling_metadata: SamplingMetadata,
  430. ) -> Optional[SamplerOutput]:
  431. next_tokens = self.sampler(logits, sampling_metadata)
  432. return next_tokens
  433. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  434. stacked_params_mapping = [
  435. # (param_name, shard_name, shard_id)
  436. ("gate_up_proj", "gate_proj", 0),
  437. ("gate_up_proj", "up_proj", 1),
  438. ]
  439. # Params for weights, fp8 weight scales, fp8 activation scales
  440. # (param_name, weight_name, expert_id, shard_id)
  441. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  442. ckpt_gate_proj_name="gate_proj",
  443. ckpt_down_proj_name="down_proj",
  444. ckpt_up_proj_name="up_proj",
  445. num_experts=self.config.n_routed_experts)
  446. params_dict = dict(self.named_parameters())
  447. weights_list = list(weights)
  448. for name, loaded_weight in progress_bar(weights_list,
  449. desc="Loading modules..."):
  450. if "rotary_emb.inv_freq" in name:
  451. continue
  452. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  453. # Skip non-stacked layers and experts (experts handled below).
  454. if weight_name not in name:
  455. continue
  456. # We have mlp.experts[0].gate_proj in the checkpoint.
  457. # Since we handle the experts below in expert_params_mapping,
  458. # we need to skip here BEFORE we update the name, otherwise
  459. # name will be updated to mlp.experts[0].gate_up_proj, which
  460. # will then be updated below in expert_params_mapping
  461. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
  462. if (("mlp.experts." in name) and name not in params_dict):
  463. continue
  464. name = name.replace(weight_name, param_name)
  465. # Skip loading extra bias for GPTQ models.
  466. if name.endswith(".bias") and name not in params_dict:
  467. continue
  468. param = params_dict[name]
  469. weight_loader = param.weight_loader
  470. weight_loader(param, loaded_weight, shard_id)
  471. break
  472. else:
  473. for mapping in expert_params_mapping:
  474. param_name, weight_name, expert_id, shard_id = mapping
  475. if weight_name not in name:
  476. continue
  477. name = name.replace(weight_name, param_name)
  478. param = params_dict[name]
  479. weight_loader = param.weight_loader
  480. weight_loader(param,
  481. loaded_weight,
  482. name,
  483. shard_id=shard_id,
  484. expert_id=expert_id)
  485. break
  486. else:
  487. # Skip loading extra bias for GPTQ models.
  488. if name.endswith(".bias") and name not in params_dict:
  489. continue
  490. param = params_dict[name]
  491. weight_loader = getattr(param, "weight_loader",
  492. default_weight_loader)
  493. weight_loader(param, loaded_weight)