# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite import _custom_ops as ops
from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.common.utils import print_warning_once
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.fp8 import (Fp8Config, per_tensor_dequantize,
                                        per_tensor_quantize)


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each
    expert across all ranks.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and the outputs are finally all-reduced
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(quant_config, Fp8Config)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
                                                   2 * self.intermediate_size,
                                                   self.hidden_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
                                                  self.hidden_size,
                                                  self.intermediate_size,
                                                  dtype=params_dtype),
                                      requires_grad=False)

        set_weight_attrs(self.w13_weight, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2_weight, {
            "weight_loader": self.weight_loader,
        })

        # Used for fp8.
        self.w13_scale = None
        self.w2_scale = None
        self.a13_scale = None
        self.a2_scale = None

        if self.use_fp8:
            # WEIGHT_SCALE (for fp8)
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     2,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    dtype=torch.float32),
                                         requires_grad=False)

            # If loading an fp8 checkpoint, pass the weight loaders.
            # If loading an fp16 checkpoint, do not (we will quantize in
            # process_weights_after_loading()).
            if quant_config.is_checkpoint_fp8_serialized:
                set_weight_attrs(self.w13_scale, {
                    "weight_loader": self.weight_loader,
                })
                set_weight_attrs(self.w2_scale, {
                    "weight_loader": self.weight_loader,
                })

            # ACT_SCALE (for fp8)
            if quant_config.activation_scheme == "static":
                if not quant_config.is_checkpoint_fp8_serialized:
                    raise ValueError(
                        "Found static activation scheme for checkpoint that "
                        "was not serialized fp8.")
                self.a13_scale = nn.Parameter(torch.ones(
                    self.num_total_experts, dtype=torch.float32),
                                              requires_grad=False)
                self.a2_scale = nn.Parameter(torch.ones(
                    self.num_total_experts, dtype=torch.float32),
                                             requires_grad=False)

                set_weight_attrs(self.a13_scale, {
                    "weight_loader": self.weight_loader,
                })
                set_weight_attrs(self.a2_scale, {
                    "weight_loader": self.weight_loader,
                })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

        # Loading scales
        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "act_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        elif "weight_scale" in weight_name:
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            assert "w1" in weight_name or "w3" in weight_name
            shard_id = 0 if "w1" in weight_name else 1
            param_data[expert_id][shard_id] = loaded_weight
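
    # A note on the slicing convention in weight_loader: w1/w3 (the gate/up
    # projections in HF Mixtral) are sharded along their output dimension, so
    # rank r copies rows [r * I_tp, (r + 1) * I_tp) of the checkpoint tensor
    # into its half of w13_weight, while w2 (the down projection) is sharded
    # along its input dimension and rank r copies the matching columns,
    # where I_tp = intermediate_size // tp_size.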

    def process_weights_after_loading(self):
        # Fp8 is the only case where we need to process after loading.
        if not self.use_fp8:
            return

        # If the checkpoint is fp16, quantize here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(self.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(self.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            for expert in range(self.num_total_experts):
                w13_weight[expert, :, :], self.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        self.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], self.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        self.w2_weight.data[expert, :, :])
            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

        else:
            # If the checkpoint is fp8 + static, clean up the act_scales:
            # the state_dict has one input_scale per expert, but our kernels
            # take a single input_scale shared across all experts.
            if self.quant_config.activation_scheme == "static":
                if self.a13_scale is None or self.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(self.a13_scale)
                        or not all_close_1d(self.a2_scale)):
                    print_warning_once(
                        "Found act_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                self.a13_scale = nn.Parameter(self.a13_scale.max(),
                                              requires_grad=False)
                self.a2_scale = nn.Parameter(self.a2_scale.max(),
                                             requires_grad=False)

            assert self.w13_scale is not None
            shard_size = self.intermediate_size
            max_w13_scales = self.w13_scale.max(dim=1).values
            for expert_id in range(self.num_total_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        self.w13_weight[expert_id][start:start +
                                                   shard_size, :],
                        self.w13_scale[expert_id][shard_id])
                    self.w13_weight[expert_id][
                        start:start + shard_size, :] = per_tensor_quantize(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
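
    # Sketch of the re-quantization above, assuming per_tensor_dequantize /
    # per_tensor_quantize follow the usual per-tensor fp8 recipe (an
    # assumption, not copied from their implementation):
    #   dq = q_shard.to(torch.float32) * shard_scale
    #   q_new = (dq / max_scale).clamp(fp8_min, fp8_max).to(float8_e4m3fn)
    # so both w1 and w3 shards of an expert can share one weight scale.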

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w13_weight,
                                        self.w2_weight,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True,
                                        use_fp8=self.use_fp8,
                                        w1_scale=self.w13_scale,
                                        w2_scale=self.w2_scale,
                                        a1_scale=self.a13_scale,
                                        a2_scale=self.a2_scale)
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)
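

# For reference, a minimal unfused sketch of what `fused_moe` computes for
# Mixtral: top-k softmax gating with renormalization, then a weighted sum of
# per-expert SwiGLU MLP outputs. This helper is illustrative only, is never
# called in this file, and assumes unsharded (tp_size == 1), unquantized
# weights with the layout used above (w13: [E, 2*I, H], w2: [E, H, I]).
def _reference_moe_sketch(hidden_states: torch.Tensor, w13: torch.Tensor,
                          w2: torch.Tensor, router_logits: torch.Tensor,
                          top_k: int) -> torch.Tensor:
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    # renormalize=True: the selected weights are rescaled to sum to one.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    intermediate_size = w2.shape[-1]
    out = torch.zeros_like(hidden_states)
    for token in range(hidden_states.shape[0]):
        x = hidden_states[token]
        for weight, expert in zip(topk_weights[token], topk_ids[token]):
            gate_up = w13[expert] @ x                         # [2 * I]
            gate = gate_up[:intermediate_size]                # w1 (gate proj)
            up = gate_up[intermediate_size:]                  # w3 (up proj)
            y = w2[expert] @ (nn.functional.silu(gate) * up)  # [H]
            out[token] += weight.to(out.dtype) * y
    return out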


class MixtralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
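
    # For example, Mixtral-8x7B uses 32 attention heads with 8 KV heads
    # (grouped-query attention): with tp_size=2 each rank holds 16 query
    # heads and 4 KV heads, while with tp_size=16 the KV heads are
    # replicated so every rank still has at least one.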

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
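
    # Note on the residual handling above: the RMSNorm layers fuse the
    # residual addition, which is why they take and return the residual
    # stream, and why this layer returns (hidden_states, residual) instead
    # of a single tensor.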


class MixtralModel(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config,
                                cache_config,
                                quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module, SupportsLoRA):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = [
            # These are the weight scales for the experts
            # (param_name, weight_name, expert_id)
            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ] + [
            # These are the weights for the experts
            # (param_name, weight_name, expert_id)
            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ] + [
            # These are the activation scales for the experts
            # (param_name, weight_name, expert_id)
            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
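
        # Example of how a checkpoint key is matched by expert_params_mapping
        # (names follow the HF Mixtral layout; shown only as an illustration):
        #   "model.layers.0.block_sparse_moe.experts.3.w1.weight"
        #       -> param "model.layers.0.block_sparse_moe.w13_weight",
        #          expert_id=3, loaded into the first half of dim 1
        #   "model.layers.0.block_sparse_moe.experts.3.w2.weight"
        #       -> param "model.layers.0.block_sparse_moe.w2_weight",
        #          expert_id=3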

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remap the name of the FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but could not find the "
                                "expected name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))