mixtral.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import MixtralConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.common.utils import print_warning_once
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.fused_moe import fused_moe
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  38. ReplicatedLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  45. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.modeling.utils import set_weight_attrs
  48. from aphrodite.quantization.base_config import QuantizationConfig
  49. from aphrodite.quantization.fp8 import Fp8Config, scaled_fp8_quant
  50. class MixtralMoE(nn.Module):
  51. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  52. across all ranks.
  53. Each expert's weights are sharded across all ranks and a fused MoE
  54. kernel is used for the forward pass, and finally we reduce the outputs
  55. across ranks.
  56. """
  57. def __init__(
  58. self,
  59. num_experts: int,
  60. top_k: int,
  61. hidden_size: int,
  62. intermediate_size: int,
  63. params_dtype: Optional[torch.dtype] = None,
  64. tp_size: Optional[int] = None,
  65. quant_config: Optional[QuantizationConfig] = None,
  66. ):
  67. super().__init__()
  68. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  69. self.num_total_experts = num_experts
  70. self.top_k = top_k
  71. self.hidden_size = hidden_size
  72. self.intermediate_size = intermediate_size // self.tp_size
  73. self.quant_config = quant_config
  74. # FIXME(pcmoritz): Make this more general to support different
  75. # quantization schemes
  76. self.use_fp8 = isinstance(quant_config, Fp8Config)
  77. if params_dtype is None:
  78. params_dtype = torch.get_default_dtype()
  79. self.params_dtype = params_dtype
  80. # Gate always runs at half / full precision for now.
  81. self.gate = ReplicatedLinear(self.hidden_size,
  82. self.num_total_experts,
  83. bias=False,
  84. params_dtype=self.params_dtype,
  85. quant_config=None)
  86. if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
  87. params_dtype = torch.float8_e4m3fn
  88. self.w13_weight = nn.Parameter(
  89. torch.empty(self.num_total_experts,
  90. 2 * self.intermediate_size,
  91. self.hidden_size,
  92. dtype=params_dtype))
  93. self.w2_weight = nn.Parameter(
  94. torch.empty(self.num_total_experts,
  95. self.hidden_size,
  96. self.intermediate_size,
  97. dtype=params_dtype))
  98. set_weight_attrs(self.w13_weight, {
  99. "weight_loader": self.weight_loader,
  100. })
  101. set_weight_attrs(self.w2_weight, {
  102. "weight_loader": self.weight_loader,
  103. })
  104. # Used for fp8.
  105. self.w13_scale = None
  106. self.w2_scale = None
  107. self.a13_scale = None
  108. self.a2_scale = None
  109. if self.use_fp8:
  110. # WEIGHT_SCALE (for fp8)
  111. self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
  112. dtype=torch.float32),
  113. requires_grad=False)
  114. self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
  115. dtype=torch.float32),
  116. requires_grad=False)
  117. # If loading fp8 checkpoint, pass the weight loaders.
  118. # If loading an fp16 checkpoint, do not (we will quantize in
  119. # process_weights_after_loading()
  120. if quant_config.is_checkpoint_fp8_serialized:
  121. set_weight_attrs(self.w13_scale, {
  122. "weight_loader": self.weight_loader,
  123. })
  124. set_weight_attrs(self.w2_scale, {
  125. "weight_loader": self.weight_loader,
  126. })
  127. # ACT_SCALE (for fp8)
  128. if quant_config.activation_scheme == "static":
  129. if not quant_config.is_checkpoint_fp8_serialized:
  130. raise ValueError(
  131. "Found static activation scheme for checkpoint that "
  132. "was not serialized fp8.")
  133. self.a13_scale = nn.Parameter(torch.zeros(
  134. self.num_total_experts, dtype=torch.float32),
  135. requires_grad=False)
  136. self.a2_scale = nn.Parameter(torch.zeros(
  137. self.num_total_experts, dtype=torch.float32),
  138. requires_grad=False)
  139. set_weight_attrs(self.a13_scale, {
  140. "weight_loader": self.weight_loader,
  141. })
  142. set_weight_attrs(self.a2_scale, {
  143. "weight_loader": self.weight_loader,
  144. })
  145. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  146. weight_name: str, expert_id: int):
  147. tp_rank = get_tensor_model_parallel_rank()
  148. param_data = param.data
  149. shard_size = self.intermediate_size
  150. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  151. if weight_name.endswith("w1.weight"):
  152. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  153. if weight_name.endswith("w3.weight"):
  154. param_data[expert_id,
  155. shard_size:2 * shard_size, :] = loaded_weight[shard, :]
  156. if weight_name.endswith("w2.weight"):
  157. param_data[expert_id, :, :] = loaded_weight[:, shard]
  158. if "act_scale" in weight_name or "weight_scale" in weight_name:
  159. param_data[expert_id] = loaded_weight
  160. def process_weights_after_loading(self):
  161. # Fp8 is the only case where we need to process after loading.
  162. if not self.use_fp8:
  163. return
  164. # If checkpoint is fp16, quantize here.
  165. if not self.quant_config.is_checkpoint_fp8_serialized:
  166. w13_weight = torch.empty_like(self.w13_weight.data,
  167. dtype=torch.float8_e4m3fn)
  168. w2_weight = torch.empty_like(self.w2_weight.data,
  169. dtype=torch.float8_e4m3fn)
  170. for expert in range(self.num_total_experts):
  171. w13_weight[
  172. expert, :, :], self.w13_scale[expert] = scaled_fp8_quant(
  173. self.w13_weight.data[expert, :, :])
  174. w2_weight[
  175. expert, :, :], self.w2_scale[expert] = scaled_fp8_quant(
  176. self.w2_weight.data[expert, :, :])
  177. self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
  178. self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
  179. # If checkpoint is fp8 + static, cleanup act_scales.
  180. # Since state_dict has an act_scale per expert but our kernels
  181. # are passed one act_scale shared across all experts.
  182. elif self.quant_config.activation_scheme == "static":
  183. if self.a13_scale is None or self.a2_scale is None:
  184. raise ValueError(
  185. "QuantConfig has static quantization, but found "
  186. "activation scales are None.")
  187. if (not all_close_1d(self.a13_scale)
  188. or not all_close_1d(self.a2_scale)):
  189. print_warning_once(
  190. "Found act_scales that are not equal for fp8 MoE layer. "
  191. "Using the maximum across experts for each layer. ")
  192. self.a13_scale = nn.Parameter(self.a13_scale.max(),
  193. requires_grad=False)
  194. self.a2_scale = nn.Parameter(self.a2_scale.max(),
  195. requires_grad=False)
  196. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  197. num_tokens, hidden_size = hidden_states.shape
  198. hidden_states = hidden_states.view(-1, self.hidden_size)
  199. # router_logits: (num_tokens, n_experts)
  200. router_logits, _ = self.gate(hidden_states)
  201. final_hidden_states = fused_moe(hidden_states,
  202. self.w13_weight,
  203. self.w2_weight,
  204. router_logits,
  205. self.top_k,
  206. renormalize=True,
  207. inplace=True,
  208. use_fp8=self.use_fp8,
  209. w1_scale=self.w13_scale,
  210. w2_scale=self.w2_scale,
  211. a1_scale=self.a13_scale,
  212. a2_scale=self.a2_scale)
  213. if self.tp_size > 1:
  214. final_hidden_states = tensor_model_parallel_all_reduce(
  215. final_hidden_states)
  216. return final_hidden_states.view(num_tokens, hidden_size)
  217. class MixtralAttention(nn.Module):
  218. def __init__(self,
  219. hidden_size: int,
  220. num_heads: int,
  221. num_kv_heads: int,
  222. max_position: int = 4096 * 32,
  223. rope_theta: float = 10000,
  224. cache_config: Optional[CacheConfig] = None,
  225. quant_config: Optional[QuantizationConfig] = None) -> None:
  226. super().__init__()
  227. self.hidden_size = hidden_size
  228. tp_size = get_tensor_model_parallel_world_size()
  229. self.total_num_heads = num_heads
  230. assert self.total_num_heads % tp_size == 0
  231. self.num_heads = self.total_num_heads // tp_size
  232. self.total_num_kv_heads = num_kv_heads
  233. if self.total_num_kv_heads >= tp_size:
  234. # Number of KV heads is greater than TP size, so we partition
  235. # the KV heads across multiple tensor parallel GPUs.
  236. assert self.total_num_kv_heads % tp_size == 0
  237. else:
  238. # Number of KV heads is less than TP size, so we replicate
  239. # the KV heads across multiple tensor parallel GPUs.
  240. assert tp_size % self.total_num_kv_heads == 0
  241. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  242. self.head_dim = hidden_size // self.total_num_heads
  243. self.q_size = self.num_heads * self.head_dim
  244. self.kv_size = self.num_kv_heads * self.head_dim
  245. self.scaling = self.head_dim**-0.5
  246. self.rope_theta = rope_theta
  247. if isinstance(
  248. quant_config,
  249. Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
  250. print_warning_once(
  251. "For Mixtral FP8 quantization, we currently do not quantize "
  252. "the attention layers until their FP8 performance is improved."
  253. )
  254. quant_config = None
  255. self.qkv_proj = QKVParallelLinear(
  256. hidden_size,
  257. self.head_dim,
  258. self.total_num_heads,
  259. self.total_num_kv_heads,
  260. bias=False,
  261. quant_config=quant_config,
  262. )
  263. self.o_proj = RowParallelLinear(
  264. self.total_num_heads * self.head_dim,
  265. hidden_size,
  266. bias=False,
  267. quant_config=quant_config,
  268. )
  269. self.rotary_emb = get_rope(
  270. self.head_dim,
  271. rotary_dim=self.head_dim,
  272. max_position=max_position,
  273. base=int(self.rope_theta),
  274. is_neox_style=True,
  275. )
  276. self.attn = Attention(self.num_heads,
  277. self.head_dim,
  278. self.scaling,
  279. num_kv_heads=self.num_kv_heads,
  280. cache_config=cache_config,
  281. quant_config=quant_config)
  282. def forward(
  283. self,
  284. positions: torch.Tensor,
  285. hidden_states: torch.Tensor,
  286. kv_cache: torch.Tensor,
  287. attn_metadata: AttentionMetadata,
  288. ) -> torch.Tensor:
  289. qkv, _ = self.qkv_proj(hidden_states)
  290. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  291. q, k = self.rotary_emb(positions, q, k)
  292. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  293. output, _ = self.o_proj(attn_output)
  294. return output
  295. class MixtralDecoderLayer(nn.Module):
  296. def __init__(
  297. self,
  298. config: MixtralConfig,
  299. cache_config: Optional[CacheConfig] = None,
  300. quant_config: Optional[QuantizationConfig] = None,
  301. ) -> None:
  302. super().__init__()
  303. self.hidden_size = config.hidden_size
  304. # Requires transformers > 4.32.0
  305. rope_theta = getattr(config, "rope_theta", 10000)
  306. self.self_attn = MixtralAttention(
  307. hidden_size=self.hidden_size,
  308. num_heads=config.num_attention_heads,
  309. max_position=config.max_position_embeddings,
  310. num_kv_heads=config.num_key_value_heads,
  311. rope_theta=rope_theta,
  312. cache_config=cache_config,
  313. quant_config=quant_config)
  314. self.block_sparse_moe = MixtralMoE(
  315. num_experts=config.num_local_experts,
  316. top_k=config.num_experts_per_tok,
  317. hidden_size=config.hidden_size,
  318. intermediate_size=config.intermediate_size,
  319. quant_config=quant_config)
  320. self.input_layernorm = RMSNorm(config.hidden_size,
  321. eps=config.rms_norm_eps)
  322. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  323. eps=config.rms_norm_eps)
  324. def forward(
  325. self,
  326. positions: torch.Tensor,
  327. hidden_states: torch.Tensor,
  328. kv_cache: torch.Tensor,
  329. attn_metadata: AttentionMetadata,
  330. residual: Optional[torch.Tensor],
  331. ) -> torch.Tensor:
  332. # Self Attention
  333. if residual is None:
  334. residual = hidden_states
  335. hidden_states = self.input_layernorm(hidden_states)
  336. else:
  337. hidden_states, residual = self.input_layernorm(
  338. hidden_states, residual)
  339. hidden_states = self.self_attn(
  340. positions=positions,
  341. hidden_states=hidden_states,
  342. kv_cache=kv_cache,
  343. attn_metadata=attn_metadata,
  344. )
  345. # Fully Connected
  346. hidden_states, residual = self.post_attention_layernorm(
  347. hidden_states, residual)
  348. hidden_states = self.block_sparse_moe(hidden_states)
  349. return hidden_states, residual
  350. class MixtralModel(nn.Module):
  351. def __init__(
  352. self,
  353. config: MixtralConfig,
  354. cache_config: Optional[CacheConfig] = None,
  355. quant_config: Optional[QuantizationConfig] = None,
  356. lora_config: Optional[LoRAConfig] = None,
  357. ) -> None:
  358. super().__init__()
  359. self.padding_idx = config.pad_token_id
  360. lora_vocab = (lora_config.lora_extra_vocab_size *
  361. (lora_config.max_loras or 1)) if lora_config else 0
  362. self.vocab_size = config.vocab_size + lora_vocab
  363. self.org_vocab_size = config.vocab_size
  364. self.embed_tokens = VocabParallelEmbedding(
  365. self.vocab_size,
  366. config.hidden_size,
  367. org_num_embeddings=config.vocab_size,
  368. )
  369. self.layers = nn.ModuleList([
  370. MixtralDecoderLayer(config,
  371. cache_config,
  372. quant_config=quant_config)
  373. for _ in range(config.num_hidden_layers)
  374. ])
  375. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  376. def forward(
  377. self,
  378. input_ids: torch.Tensor,
  379. positions: torch.Tensor,
  380. kv_caches: List[torch.Tensor],
  381. attn_metadata: AttentionMetadata,
  382. ) -> torch.Tensor:
  383. hidden_states = self.embed_tokens(input_ids)
  384. residual = None
  385. for i in range(len(self.layers)):
  386. layer = self.layers[i]
  387. hidden_states, residual = layer(positions, hidden_states,
  388. kv_caches[i], attn_metadata,
  389. residual)
  390. hidden_states, _ = self.norm(hidden_states, residual)
  391. return hidden_states
  392. class MixtralForCausalLM(nn.Module):
  393. fall_back_to_pt_during_load = False
  394. packed_modules_mapping = {
  395. "qkv_proj": [
  396. "q_proj",
  397. "k_proj",
  398. "v_proj",
  399. ],
  400. }
  401. # LoRA specific attributes
  402. supported_lora_modules = [
  403. "qkv_proj",
  404. "o_proj",
  405. "embed_tokens",
  406. "lm_head",
  407. ]
  408. embedding_modules = {
  409. "embed_tokens": "input_embeddings",
  410. "lm_head": "output_embeddings",
  411. }
  412. embedding_padding_modules = ["lm_head"]
  413. def __init__(
  414. self,
  415. config: MixtralConfig,
  416. cache_config: Optional[CacheConfig] = None,
  417. quant_config: Optional[QuantizationConfig] = None,
  418. lora_config: Optional[LoRAConfig] = None,
  419. ) -> None:
  420. super().__init__()
  421. self.config = config
  422. self.model = MixtralModel(config,
  423. cache_config,
  424. quant_config,
  425. lora_config=lora_config)
  426. self.unpadded_vocab_size = config.vocab_size
  427. if lora_config:
  428. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  429. self.lm_head = ParallelLMHead(
  430. self.unpadded_vocab_size,
  431. config.hidden_size,
  432. org_num_embeddings=config.vocab_size,
  433. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  434. # We need bigger padding if using lora for kernel
  435. # compatibility
  436. if not lora_config else lora_config.lora_vocab_padding_size,
  437. )
  438. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  439. config.vocab_size)
  440. self.sampler = Sampler()
  441. def forward(
  442. self,
  443. input_ids: torch.Tensor,
  444. positions: torch.Tensor,
  445. kv_caches: List[torch.Tensor],
  446. attn_metadata: AttentionMetadata,
  447. ) -> torch.Tensor:
  448. hidden_states = self.model(input_ids, positions, kv_caches,
  449. attn_metadata)
  450. return hidden_states
  451. def compute_logits(self, hidden_states: torch.Tensor,
  452. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  453. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  454. sampling_metadata)
  455. return logits
  456. def sample(
  457. self,
  458. logits: Optional[torch.Tensor],
  459. sampling_metadata: SamplingMetadata,
  460. ) -> Optional[SamplerOutput]:
  461. next_tokens = self.sampler(logits, sampling_metadata)
  462. return next_tokens
  463. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  464. stacked_params_mapping = [
  465. # (param_name, shard_name, shard_id)
  466. ("qkv_proj", "q_proj", "q"),
  467. ("qkv_proj", "k_proj", "k"),
  468. ("qkv_proj", "v_proj", "v"),
  469. ]
  470. expert_params_mapping = [
  471. # These are the weight scales for the experts
  472. # (param_name, weight_name, expert_id)
  473. ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
  474. f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
  475. for expert_id in range(self.config.num_local_experts)
  476. for weight_name in ["w1", "w2", "w3"]
  477. ] + [
  478. # These are the weights for the experts
  479. # (param_name, weight_name, expert_id)
  480. ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
  481. f"experts.{expert_id}.{weight_name}.weight", expert_id)
  482. for expert_id in range(self.config.num_local_experts)
  483. for weight_name in ["w1", "w2", "w3"]
  484. ] + [
  485. # These are the activation scales for the experts
  486. # (param_name, weight_name, expert_id)
  487. ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
  488. f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
  489. for expert_id in range(self.config.num_local_experts)
  490. for weight_name in ["w1", "w2", "w3"]
  491. ]
  492. params_dict = dict(self.named_parameters())
  493. for name, loaded_weight in weights:
  494. if "rotary_emb.inv_freq" in name:
  495. continue
  496. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  497. if weight_name not in name:
  498. continue
  499. name = name.replace(weight_name, param_name)
  500. # Skip loading extra bias for GPTQ models.
  501. if name.endswith(".bias") and name not in params_dict:
  502. continue
  503. param = params_dict[name]
  504. weight_loader = param.weight_loader
  505. weight_loader(param, loaded_weight, shard_id)
  506. break
  507. else:
  508. for param_name, weight_name, expert_id in expert_params_mapping:
  509. if weight_name not in name:
  510. continue
  511. name = name.replace(weight_name, param_name)
  512. param = params_dict[name]
  513. weight_loader = param.weight_loader
  514. weight_loader(param,
  515. loaded_weight,
  516. weight_name,
  517. expert_id=expert_id)
  518. break
  519. else:
  520. # Skip loading extra bias for GPTQ models.
  521. if name.endswith(".bias") and name not in params_dict:
  522. continue
  523. # Remapping the name of FP8 kv-scale.
  524. if name.endswith("kv_scale"):
  525. remapped_kv_scale_name = name.replace(
  526. ".kv_scale", ".attn.kv_scale")
  527. if remapped_kv_scale_name not in params_dict:
  528. print_warning_once(
  529. "Found kv scale in the checkpoint "
  530. f"(e.g. {name}), but not found the expected "
  531. f"name in the model "
  532. f"(e.g. {remapped_kv_scale_name}). "
  533. "kv-scale is not loaded.")
  534. continue
  535. else:
  536. name = remapped_kv_scale_name
  537. param = params_dict[name]
  538. weight_loader = getattr(param, "weight_loader",
  539. default_weight_loader)
  540. weight_loader(param, loaded_weight)
  541. def all_close_1d(x: torch.Tensor) -> bool:
  542. assert len(x.shape) == 1
  543. return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))