deepseek_v3.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. # Adapted from
  2. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  3. # Copyright 2023 The PygmalionAI team.
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only DeepseekV3 model."""
  24. from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
  25. import torch
  26. from torch import nn
  27. from transformers import PretrainedConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig
  30. from aphrodite.common.sequence import IntermediateTensors
  31. from aphrodite.distributed import (get_pp_group,
  32. get_tensor_model_parallel_world_size,
  33. tensor_model_parallel_all_reduce)
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.fused_moe import FusedMoE
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  38. MergedColumnParallelLinear,
  39. ReplicatedLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.quantization import QuantizationConfig
  49. from .utils import (PPMissingLayer, is_pp_missing_parameter,
  50. make_empty_intermediate_tensors_factory, make_layers)
  51. class DeepseekV3MLP(nn.Module):
  52. def __init__(
  53. self,
  54. hidden_size: int,
  55. intermediate_size: int,
  56. hidden_act: str,
  57. quant_config: Optional[QuantizationConfig] = None,
  58. reduce_results: bool = True,
  59. prefix: str = "",
  60. ) -> None:
  61. super().__init__()
  62. self.gate_up_proj = MergedColumnParallelLinear(
  63. hidden_size, [intermediate_size] * 2,
  64. bias=False,
  65. quant_config=quant_config,
  66. prefix=f"{prefix}.gate_up_proj")
  67. self.down_proj = RowParallelLinear(intermediate_size,
  68. hidden_size,
  69. bias=False,
  70. quant_config=quant_config,
  71. reduce_results=reduce_results,
  72. prefix=f"{prefix}.down_proj")
  73. if hidden_act != "silu":
  74. raise ValueError(f"Unsupported activation: {hidden_act}. "
  75. "Only silu is supported for now.")
  76. self.act_fn = SiluAndMul()
  77. def forward(self, x):
  78. gate_up, _ = self.gate_up_proj(x)
  79. x = self.act_fn(gate_up)
  80. x, _ = self.down_proj(x)
  81. return x
  82. class DeepseekV3MoE(nn.Module):
  83. def __init__(
  84. self,
  85. config: PretrainedConfig,
  86. quant_config: Optional[QuantizationConfig] = None,
  87. prefix: str = "",
  88. ):
  89. super().__init__()
  90. self.tp_size = get_tensor_model_parallel_world_size()
  91. self.routed_scaling_factor = config.routed_scaling_factor
  92. self.n_shared_experts = config.n_shared_experts
  93. self.routed_scaling_factor = config.routed_scaling_factor
  94. if self.tp_size > config.n_routed_experts:
  95. raise ValueError(
  96. f"Tensor parallel size {self.tp_size} is greater than "
  97. f"the number of experts {config.n_routed_experts}.")
  98. if config.hidden_act != "silu":
  99. raise ValueError(f"Unsupported activation: {config.hidden_act}. "
  100. "Only silu is supported for now.")
  101. self.gate = ReplicatedLinear(config.hidden_size,
  102. config.n_routed_experts,
  103. bias=False,
  104. quant_config=None,
  105. prefix=f"{prefix}.gate")
  106. if config.topk_method == "noaux_tc":
  107. self.gate.e_score_correction_bias = nn.Parameter(
  108. torch.empty(config.n_routed_experts))
  109. else:
  110. self.gate.e_score_correction_bias = None
  111. self.experts = FusedMoE(
  112. num_experts=config.n_routed_experts,
  113. top_k=config.num_experts_per_tok,
  114. hidden_size=config.hidden_size,
  115. intermediate_size=config.moe_intermediate_size,
  116. reduce_results=False,
  117. renormalize=config.norm_topk_prob,
  118. quant_config=quant_config,
  119. use_grouped_topk=True,
  120. num_expert_group=config.n_group,
  121. topk_group=config.topk_group,
  122. prefix=f"{prefix}.experts",
  123. scoring_func=config.scoring_func,
  124. e_score_correction_bias=self.gate.e_score_correction_bias)
  125. if config.n_shared_experts is not None:
  126. intermediate_size = (config.moe_intermediate_size *
  127. config.n_shared_experts)
  128. self.shared_experts = DeepseekV3MLP(
  129. hidden_size=config.hidden_size,
  130. intermediate_size=intermediate_size,
  131. hidden_act=config.hidden_act,
  132. quant_config=quant_config,
  133. reduce_results=False,
  134. )
  135. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  136. num_tokens, hidden_dim = hidden_states.shape
  137. hidden_states = hidden_states.view(-1, hidden_dim)
  138. if self.n_shared_experts is not None:
  139. shared_output = self.shared_experts(hidden_states)
  140. # router_logits: (num_tokens, n_experts)
  141. router_logits, _ = self.gate(hidden_states)
  142. final_hidden_states = self.experts(
  143. hidden_states=hidden_states,
  144. router_logits=router_logits) * self.routed_scaling_factor
  145. if shared_output is not None:
  146. final_hidden_states = final_hidden_states + shared_output
  147. if self.tp_size > 1:
  148. final_hidden_states = tensor_model_parallel_all_reduce(
  149. final_hidden_states)
  150. return final_hidden_states.view(num_tokens, hidden_dim)
  151. def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  152. import math
  153. if scale <= 1:
  154. return 1.0
  155. return 0.1 * mscale * math.log(scale) + 1.0
  156. class DeepseekV3Attention(nn.Module):
  157. def __init__(
  158. self,
  159. config: PretrainedConfig,
  160. hidden_size: int,
  161. num_heads: int,
  162. qk_nope_head_dim: int,
  163. qk_rope_head_dim: int,
  164. v_head_dim: int,
  165. q_lora_rank: int,
  166. kv_lora_rank: int,
  167. rope_theta: float = 10000,
  168. rope_scaling: Optional[Dict[str, Any]] = None,
  169. max_position_embeddings: int = 8192,
  170. cache_config: Optional[CacheConfig] = None,
  171. quant_config: Optional[QuantizationConfig] = None,
  172. prefix: str = "",
  173. ) -> None:
  174. super().__init__()
  175. self.hidden_size = hidden_size
  176. self.qk_nope_head_dim = qk_nope_head_dim
  177. self.qk_rope_head_dim = qk_rope_head_dim
  178. self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
  179. self.v_head_dim = v_head_dim
  180. self.q_lora_rank = q_lora_rank
  181. self.kv_lora_rank = kv_lora_rank
  182. self.num_heads = num_heads
  183. tp_size = get_tensor_model_parallel_world_size()
  184. assert num_heads % tp_size == 0
  185. self.num_local_heads = num_heads // tp_size
  186. self.scaling = self.qk_head_dim**-0.5
  187. self.rope_theta = rope_theta
  188. self.max_position_embeddings = max_position_embeddings
  189. if self.q_lora_rank is not None:
  190. self.q_a_proj = ReplicatedLinear(self.hidden_size,
  191. self.q_lora_rank,
  192. bias=False,
  193. quant_config=quant_config,
  194. prefix=f"{prefix}.q_a_proj")
  195. self.q_a_layernorm = RMSNorm(self.q_lora_rank,
  196. eps=config.rms_norm_eps)
  197. self.q_b_proj = ColumnParallelLinear(q_lora_rank,
  198. self.num_heads *
  199. self.qk_head_dim,
  200. bias=False,
  201. quant_config=quant_config,
  202. prefix=f"{prefix}.q_b_proj")
  203. else:
  204. self.q_proj = ColumnParallelLinear(self.hidden_size,
  205. self.num_heads *
  206. self.qk_head_dim,
  207. bias=False,
  208. quant_config=quant_config,
  209. prefix=f"{prefix}.q_proj")
  210. self.kv_a_proj_with_mqa = ReplicatedLinear(
  211. self.hidden_size,
  212. self.kv_lora_rank + self.qk_rope_head_dim,
  213. bias=False,
  214. quant_config=quant_config,
  215. prefix=f"{prefix}.kv_a_proj_with_mqa")
  216. self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
  217. eps=config.rms_norm_eps)
  218. self.kv_b_proj = ColumnParallelLinear(
  219. self.kv_lora_rank,
  220. self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
  221. bias=False,
  222. quant_config=quant_config,
  223. prefix=f"{prefix}.kv_b_proj")
  224. # O projection.
  225. self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
  226. self.hidden_size,
  227. bias=False,
  228. quant_config=quant_config,
  229. prefix=f"{prefix}.o_proj")
  230. rope_scaling["rope_type"] = 'deepseek_yarn'
  231. self.rotary_emb = get_rope(qk_rope_head_dim,
  232. rotary_dim=qk_rope_head_dim,
  233. max_position=max_position_embeddings,
  234. base=rope_theta,
  235. rope_scaling=rope_scaling,
  236. is_neox_style=False)
  237. if rope_scaling:
  238. mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
  239. scaling_factor = rope_scaling["factor"]
  240. mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
  241. self.scaling = self.scaling * mscale * mscale
  242. # self.attn = Attention(self.num_heads,
  243. # self.qk_head_dim,
  244. # self.scaling,
  245. # num_kv_heads=self.num_heads)
  246. # TODO, support head_size 192
  247. self.attn = Attention(self.num_local_heads,
  248. 256,
  249. self.scaling,
  250. num_kv_heads=self.num_local_heads,
  251. cache_config=cache_config,
  252. quant_config=quant_config,
  253. prefix=f"{prefix}.attn")
  254. def forward(
  255. self,
  256. positions: torch.Tensor,
  257. hidden_states: torch.Tensor,
  258. kv_cache: torch.Tensor,
  259. attn_metadata: AttentionMetadata,
  260. ) -> torch.Tensor:
  261. if self.q_lora_rank is not None:
  262. q = self.q_a_proj(hidden_states)[0]
  263. q = self.q_a_layernorm(q)
  264. q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
  265. self.qk_head_dim)
  266. else:
  267. q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
  268. self.qk_head_dim)
  269. q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
  270. dim=-1)
  271. latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
  272. kv_a, _ = latent_cache.split(
  273. [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  274. latent_cache = latent_cache.unsqueeze(1)
  275. kv_a = self.kv_a_layernorm(kv_a.contiguous())
  276. kv = self.kv_b_proj(kv_a)[0]
  277. kv = kv.view(-1, self.num_local_heads,
  278. self.qk_nope_head_dim + self.v_head_dim)
  279. k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
  280. k_pe = latent_cache[:, :, self.kv_lora_rank:]
  281. q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
  282. q[..., self.qk_nope_head_dim:] = q_pe
  283. k = torch.empty_like(q)
  284. k[..., :self.qk_nope_head_dim] = k_nope
  285. k[..., self.qk_nope_head_dim:] = k_pe
  286. q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
  287. value=0).view(-1,
  288. self.num_local_heads * 256)
  289. k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
  290. value=0).view(-1,
  291. self.num_local_heads * 256)
  292. v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
  293. value=0).view(-1,
  294. self.num_local_heads * 256)
  295. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  296. attn_output = attn_output.view(
  297. -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
  298. -1, self.num_local_heads * self.v_head_dim)
  299. output, _ = self.o_proj(attn_output)
  300. return output
  301. class DeepseekV3DecoderLayer(nn.Module):
  302. def __init__(
  303. self,
  304. config: PretrainedConfig,
  305. prefix: str,
  306. cache_config: Optional[CacheConfig] = None,
  307. quant_config: Optional[QuantizationConfig] = None,
  308. ) -> None:
  309. super().__init__()
  310. self.hidden_size = config.hidden_size
  311. rope_theta = getattr(config, "rope_theta", 10000)
  312. rope_scaling = getattr(config, "rope_scaling", None)
  313. max_position_embeddings = getattr(config, "max_position_embeddings",
  314. 8192)
  315. # DecoderLayers are created with `make_layers` which passes the prefix
  316. # with the layer's index.
  317. layer_idx = int(prefix.split(sep='.')[-1])
  318. self.self_attn = DeepseekV3Attention(
  319. config=config,
  320. hidden_size=self.hidden_size,
  321. num_heads=config.num_attention_heads,
  322. qk_nope_head_dim=config.qk_nope_head_dim,
  323. qk_rope_head_dim=config.qk_rope_head_dim,
  324. v_head_dim=config.v_head_dim,
  325. q_lora_rank=config.q_lora_rank
  326. if hasattr(config, "q_lora_rank") else None,
  327. kv_lora_rank=config.kv_lora_rank,
  328. rope_theta=rope_theta,
  329. rope_scaling=rope_scaling,
  330. max_position_embeddings=max_position_embeddings,
  331. cache_config=cache_config,
  332. quant_config=quant_config,
  333. prefix=f"{prefix}.self_attn",
  334. )
  335. if (config.n_routed_experts is not None
  336. and layer_idx >= config.first_k_dense_replace
  337. and layer_idx % config.moe_layer_freq == 0):
  338. self.mlp = DeepseekV3MoE(
  339. config=config,
  340. quant_config=quant_config,
  341. prefix=f"{prefix}.mlp",
  342. )
  343. else:
  344. self.mlp = DeepseekV3MLP(
  345. hidden_size=config.hidden_size,
  346. intermediate_size=config.intermediate_size,
  347. hidden_act=config.hidden_act,
  348. quant_config=quant_config,
  349. prefix=f"{prefix}.mlp",
  350. )
  351. self.input_layernorm = RMSNorm(config.hidden_size,
  352. eps=config.rms_norm_eps)
  353. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  354. eps=config.rms_norm_eps)
  355. def forward(
  356. self,
  357. positions: torch.Tensor,
  358. hidden_states: torch.Tensor,
  359. kv_cache: torch.Tensor,
  360. attn_metadata: AttentionMetadata,
  361. residual: Optional[torch.Tensor],
  362. ) -> torch.Tensor:
  363. # Self Attention
  364. if residual is None:
  365. residual = hidden_states
  366. hidden_states = self.input_layernorm(hidden_states)
  367. else:
  368. hidden_states, residual = self.input_layernorm(
  369. hidden_states, residual)
  370. hidden_states = self.self_attn(
  371. positions=positions,
  372. hidden_states=hidden_states,
  373. kv_cache=kv_cache,
  374. attn_metadata=attn_metadata,
  375. )
  376. # Fully Connected
  377. hidden_states, residual = self.post_attention_layernorm(
  378. hidden_states, residual)
  379. hidden_states = self.mlp(hidden_states)
  380. return hidden_states, residual
  381. # TODO(simon): check whether we support torch compile for Deepseek V3
  382. # @support_torch_compile
  383. class DeepseekV3Model(nn.Module):
  384. fall_back_to_pt_during_load = False
  385. def __init__(
  386. self, *,
  387. config: PretrainedConfig,
  388. cache_config: Optional[CacheConfig] = None,
  389. quant_config: Optional[QuantizationConfig] = None,
  390. prefix: str = ""):
  391. super().__init__()
  392. self.padding_idx = config.pad_token_id
  393. self.vocab_size = config.vocab_size
  394. if get_pp_group().is_first_rank:
  395. self.embed_tokens = VocabParallelEmbedding(
  396. config.vocab_size,
  397. config.hidden_size,
  398. )
  399. else:
  400. self.embed_tokens = PPMissingLayer()
  401. self.start_layer, self.end_layer, self.layers = make_layers(
  402. config.num_hidden_layers,
  403. lambda prefix: DeepseekV3DecoderLayer(
  404. config,
  405. prefix,
  406. cache_config=cache_config,
  407. quant_config=quant_config,
  408. ),
  409. prefix=f"{prefix}.layers")
  410. if get_pp_group().is_last_rank:
  411. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  412. else:
  413. self.norm = PPMissingLayer()
  414. self.make_empty_intermediate_tensors = (
  415. make_empty_intermediate_tensors_factory(
  416. ["hidden_states", "residual"], config.hidden_size))
  417. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  418. return self.embed_tokens(input_ids)
  419. def forward(
  420. self,
  421. input_ids: torch.Tensor,
  422. positions: torch.Tensor,
  423. kv_caches: List[torch.Tensor],
  424. attn_metadata: AttentionMetadata,
  425. intermediate_tensors: Optional[IntermediateTensors],
  426. inputs_embeds: Optional[torch.Tensor] = None,
  427. ) -> Union[torch.Tensor, IntermediateTensors]:
  428. if get_pp_group().is_first_rank:
  429. if inputs_embeds is not None:
  430. hidden_states = inputs_embeds
  431. else:
  432. hidden_states = self.get_input_embeddings(input_ids)
  433. residual = None
  434. else:
  435. assert intermediate_tensors is not None
  436. hidden_states = intermediate_tensors["hidden_states"]
  437. residual = intermediate_tensors["residual"]
  438. for i in range(self.start_layer, self.end_layer):
  439. layer = self.layers[i]
  440. hidden_states, residual = layer(positions, hidden_states,
  441. kv_caches[i - self.start_layer],
  442. attn_metadata, residual)
  443. if not get_pp_group().is_last_rank:
  444. return IntermediateTensors({
  445. "hidden_states": hidden_states,
  446. "residual": residual
  447. })
  448. hidden_states, _ = self.norm(hidden_states, residual)
  449. return hidden_states
  450. class DeepseekV3ForCausalLM(nn.Module):
  451. def __init__(
  452. self,
  453. *,
  454. config: PretrainedConfig,
  455. cache_config: Optional[CacheConfig] = None,
  456. quant_config: Optional[QuantizationConfig] = None,
  457. prefix: str = ""
  458. ):
  459. super().__init__()
  460. self.config = config
  461. self.quant_config = quant_config
  462. self.model = DeepseekV3Model(config=config,
  463. cache_config=cache_config,
  464. quant_config=quant_config,
  465. prefix="model")
  466. self.lm_head = ParallelLMHead(config.vocab_size,
  467. config.hidden_size,
  468. quant_config=quant_config)
  469. self.logits_processor = LogitsProcessor(config.vocab_size)
  470. self.sampler = Sampler()
  471. self.make_empty_intermediate_tensors = (
  472. self.model.make_empty_intermediate_tensors)
  473. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  474. return self.model.get_input_embeddings(input_ids)
  475. def forward(
  476. self,
  477. input_ids: torch.Tensor,
  478. positions: torch.Tensor,
  479. kv_caches: List[torch.Tensor],
  480. attn_metadata: AttentionMetadata,
  481. intermediate_tensors: Optional[IntermediateTensors] = None,
  482. inputs_embeds: Optional[torch.Tensor] = None,
  483. ) -> Union[torch.Tensor, IntermediateTensors]:
  484. hidden_states = self.model(input_ids, positions, kv_caches,
  485. attn_metadata, intermediate_tensors,
  486. inputs_embeds)
  487. return hidden_states
  488. def compute_logits(
  489. self,
  490. hidden_states: torch.Tensor,
  491. sampling_metadata: SamplingMetadata,
  492. ) -> Optional[torch.Tensor]:
  493. logits = self.logits_processor(self.lm_head, hidden_states,
  494. sampling_metadata)
  495. return logits
  496. def sample(
  497. self,
  498. logits: Optional[torch.Tensor],
  499. sampling_metadata: SamplingMetadata,
  500. ) -> Optional[SamplerOutput]:
  501. next_tokens = self.sampler(logits, sampling_metadata)
  502. return next_tokens
  503. def make_empty_intermediate_tensors(
  504. self, batch_size: int, dtype: torch.dtype,
  505. device: torch.device) -> IntermediateTensors:
  506. return IntermediateTensors({
  507. "hidden_states":
  508. torch.zeros((batch_size, self.config.hidden_size),
  509. dtype=dtype,
  510. device=device),
  511. "residual":
  512. torch.zeros((batch_size, self.config.hidden_size),
  513. dtype=dtype,
  514. device=device),
  515. })
  516. def load_weights(self, weights: Iterable[Tuple[str,
  517. torch.Tensor]]) -> Set[str]:
  518. stacked_params_mapping = [
  519. # (param_name, shard_name, shard_id)
  520. ("gate_up_proj", "gate_proj", 0),
  521. ("gate_up_proj", "up_proj", 1),
  522. ]
  523. # Params for weights, fp8 weight scales, fp8 activation scales
  524. # (param_name, weight_name, expert_id, shard_id)
  525. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  526. ckpt_gate_proj_name="gate_proj",
  527. ckpt_down_proj_name="down_proj",
  528. ckpt_up_proj_name="up_proj",
  529. num_experts=self.config.n_routed_experts)
  530. params_dict = dict(self.named_parameters())
  531. loaded_params: Set[str] = set()
  532. for name, loaded_weight in weights:
  533. if "rotary_emb.inv_freq" in name:
  534. continue
  535. # TODO: support nextn predict layers
  536. if self.config.num_nextn_predict_layers > 0:
  537. assert self.config.num_nextn_predict_layers == 1
  538. layer_idx = self.config.num_hidden_layers
  539. if name.startswith(f"model.layers.{layer_idx}"):
  540. continue
  541. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  542. # Skip non-stacked layers and experts (experts handled below).
  543. if weight_name not in name:
  544. continue
  545. # We have mlp.experts[0].gate_proj in the checkpoint.
  546. # Since we handle the experts below in expert_params_mapping,
  547. # we need to skip here BEFORE we update the name, otherwise
  548. # name will be updated to mlp.experts[0].gate_up_proj, which
  549. # will then be updated below in expert_params_mapping
  550. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
  551. if (("mlp.experts." in name) and name not in params_dict):
  552. continue
  553. name = name.replace(weight_name, param_name)
  554. # Skip loading extra bias for GPTQ models.
  555. if name.endswith(".bias") and name not in params_dict:
  556. continue
  557. if is_pp_missing_parameter(name, self):
  558. continue
  559. param = params_dict[name]
  560. weight_loader = param.weight_loader
  561. weight_loader(param, loaded_weight, shard_id)
  562. break
  563. else:
  564. for mapping in expert_params_mapping:
  565. param_name, weight_name, expert_id, shard_id = mapping
  566. if weight_name not in name:
  567. continue
  568. name = name.replace(weight_name, param_name)
  569. if is_pp_missing_parameter(name, self):
  570. continue
  571. param = params_dict[name]
  572. weight_loader = param.weight_loader
  573. weight_loader(param,
  574. loaded_weight,
  575. name,
  576. shard_id=shard_id,
  577. expert_id=expert_id)
  578. break
  579. else:
  580. # Skip loading extra bias for GPTQ models.
  581. if name.endswith(".bias") and name not in params_dict:
  582. continue
  583. if is_pp_missing_parameter(name, self):
  584. continue
  585. if name not in params_dict:
  586. for key in params_dict:
  587. print(key)
  588. param = params_dict[name]
  589. weight_loader = getattr(param, "weight_loader",
  590. default_weight_loader)
  591. weight_loader(param, loaded_weight)
  592. loaded_params.add(name)
  593. return loaded_params