llama.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import is_hip
  32. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  33. get_pp_group,
  34. get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size)
  36. from aphrodite.modeling.layers.activation import SiluAndMul
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import (
  47. default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
  48. from aphrodite.modeling.models.interfaces import SupportsLoRA
  49. from aphrodite.modeling.models.utils import (PPMissingLayer,
  50. is_pp_missing_parameter,
  51. make_layers)
  52. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  53. from aphrodite.quantization.base_config import QuantizationConfig
  54. from aphrodite.quantization.compressed_tensors.utils import (
  55. get_compressed_tensors_cache_scale)
  56. class LlamaMLP(nn.Module):
  57. def __init__(
  58. self,
  59. hidden_size: int,
  60. intermediate_size: int,
  61. hidden_act: str,
  62. quant_config: Optional[QuantizationConfig] = None,
  63. bias: bool = False,
  64. prefix: str = "",
  65. ) -> None:
  66. super().__init__()
  67. self.gate_up_proj = MergedColumnParallelLinear(
  68. input_size=hidden_size,
  69. output_sizes=[intermediate_size] * 2,
  70. bias=bias,
  71. quant_config=quant_config,
  72. prefix=f"{prefix}.gate_up_proj")
  73. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  74. output_size=hidden_size,
  75. bias=bias,
  76. quant_config=quant_config,
  77. prefix=f"{prefix}.down_proj")
  78. if hidden_act != "silu":
  79. raise ValueError(f"Unsupported activation: {hidden_act}. "
  80. "Only silu is supported for now.")
  81. self.act_fn = SiluAndMul()
  82. def forward(self, x):
  83. gate_up, _ = self.gate_up_proj(x)
  84. x = self.act_fn(gate_up)
  85. x, _ = self.down_proj(x)
  86. return x
  87. class LlamaAttention(nn.Module):
  88. def __init__(
  89. self,
  90. config: LlamaConfig,
  91. hidden_size: int,
  92. num_heads: int,
  93. num_kv_heads: int,
  94. rope_theta: float = 10000,
  95. rope_scaling: Optional[Dict[str, Any]] = None,
  96. max_position_embeddings: int = 8192,
  97. quant_config: Optional[QuantizationConfig] = None,
  98. bias: bool = False,
  99. cache_config: Optional[CacheConfig] = None,
  100. prefix: str = "",
  101. ) -> None:
  102. super().__init__()
  103. self.hidden_size = hidden_size
  104. tp_size = get_tensor_model_parallel_world_size()
  105. tp_rank = get_tensor_model_parallel_rank()
  106. self.total_num_heads = num_heads
  107. self.total_num_kv_heads = num_kv_heads
  108. self.num_kv_heads = max(
  109. 1,
  110. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  111. tp_rank, tp_size))
  112. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  113. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  114. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  115. self.head_dim = getattr(config, "head_dim",
  116. self.hidden_size // self.total_num_heads)
  117. self.q_size = self.num_heads * self.head_dim
  118. self.kv_size = self.num_kv_heads * self.head_dim
  119. self.scaling = self.head_dim**-0.5
  120. self.rope_theta = rope_theta
  121. self.max_position_embeddings = max_position_embeddings
  122. self.qkv_proj = QKVParallelLinear(
  123. hidden_size=hidden_size,
  124. head_size=self.head_dim,
  125. total_num_heads=self.total_num_heads,
  126. total_num_kv_heads=self.total_num_kv_heads,
  127. bias=bias,
  128. quant_config=quant_config,
  129. prefix=f"{prefix}.qkv_proj",
  130. )
  131. self.o_proj = RowParallelLinear(
  132. input_size=self.total_num_heads * self.head_dim,
  133. output_size=hidden_size,
  134. bias=bias,
  135. quant_config=quant_config,
  136. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  137. prefix=f"{prefix}.o_proj",
  138. )
  139. is_neox_style = True
  140. if quant_config is not None and quant_config.get_name() == "gguf":
  141. is_neox_style = False
  142. self.rotary_emb = get_rope(
  143. self.head_dim,
  144. rotary_dim=self.head_dim,
  145. max_position=max_position_embeddings,
  146. base=rope_theta,
  147. rope_scaling=rope_scaling,
  148. is_neox_style=is_neox_style,
  149. )
  150. self.attn = Attention(self.num_heads,
  151. self.head_dim,
  152. self.scaling,
  153. num_kv_heads=self.num_kv_heads,
  154. cache_config=cache_config,
  155. quant_config=quant_config)
  156. def forward(
  157. self,
  158. positions: torch.Tensor,
  159. hidden_states: torch.Tensor,
  160. kv_cache: torch.Tensor,
  161. attn_metadata: AttentionMetadata,
  162. ) -> torch.Tensor:
  163. qkv, _ = self.qkv_proj(hidden_states)
  164. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  165. q, k = self.rotary_emb(positions, q, k)
  166. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  167. output, _ = self.o_proj(attn_output)
  168. return output
  169. class LlamaDecoderLayer(nn.Module):
  170. def __init__(
  171. self,
  172. config: LlamaConfig,
  173. cache_config: Optional[CacheConfig] = None,
  174. quant_config: Optional[QuantizationConfig] = None,
  175. prefix: str = "",
  176. ) -> None:
  177. super().__init__()
  178. self.hidden_size = config.hidden_size
  179. rope_theta = getattr(config, "rope_theta", 10000)
  180. rope_scaling = getattr(config, "rope_scaling", None)
  181. if rope_scaling is not None and getattr(
  182. config, "original_max_position_embeddings", None):
  183. rope_scaling["original_max_position_embeddings"] = (
  184. config.original_max_position_embeddings)
  185. max_position_embeddings = getattr(config, "max_position_embeddings",
  186. 8192)
  187. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  188. # Support internlm/internlm-7b with bias
  189. attention_bias = getattr(config, "attention_bias", False) or getattr(
  190. config, "bias", False)
  191. self.self_attn = LlamaAttention(
  192. config=config,
  193. hidden_size=self.hidden_size,
  194. num_heads=config.num_attention_heads,
  195. num_kv_heads=getattr(config, "num_key_value_heads",
  196. config.num_attention_heads),
  197. rope_theta=rope_theta,
  198. rope_scaling=rope_scaling,
  199. max_position_embeddings=max_position_embeddings,
  200. quant_config=quant_config,
  201. bias=attention_bias,
  202. cache_config=cache_config,
  203. prefix=f"{prefix}.self_attn",
  204. )
  205. self.mlp = LlamaMLP(
  206. hidden_size=self.hidden_size,
  207. intermediate_size=config.intermediate_size,
  208. hidden_act=config.hidden_act,
  209. quant_config=quant_config,
  210. bias=getattr(config, "mlp_bias", False),
  211. prefix=f"{prefix}.mlp",
  212. )
  213. self.input_layernorm = RMSNorm(config.hidden_size,
  214. eps=config.rms_norm_eps)
  215. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  216. eps=config.rms_norm_eps)
  217. def forward(
  218. self,
  219. positions: torch.Tensor,
  220. hidden_states: torch.Tensor,
  221. kv_cache: torch.Tensor,
  222. attn_metadata: AttentionMetadata,
  223. residual: Optional[torch.Tensor],
  224. ) -> Tuple[torch.Tensor, torch.Tensor]:
  225. # Self Attention
  226. if residual is None:
  227. residual = hidden_states
  228. hidden_states = self.input_layernorm(hidden_states)
  229. else:
  230. hidden_states, residual = self.input_layernorm(
  231. hidden_states, residual)
  232. hidden_states = self.self_attn(
  233. positions=positions,
  234. hidden_states=hidden_states,
  235. kv_cache=kv_cache,
  236. attn_metadata=attn_metadata,
  237. )
  238. # Fully Connected
  239. hidden_states, residual = self.post_attention_layernorm(
  240. hidden_states, residual)
  241. hidden_states = self.mlp(hidden_states)
  242. return hidden_states, residual
  243. class LlamaModel(nn.Module):
  244. def __init__(
  245. self,
  246. config: LlamaConfig,
  247. cache_config: Optional[CacheConfig] = None,
  248. quant_config: Optional[QuantizationConfig] = None,
  249. lora_config: Optional[LoRAConfig] = None,
  250. prefix: str = "",
  251. ) -> None:
  252. super().__init__()
  253. self.config = config
  254. self.padding_idx = config.pad_token_id
  255. lora_vocab = (lora_config.lora_extra_vocab_size *
  256. (lora_config.max_loras or 1)) if lora_config else 0
  257. self.vocab_size = config.vocab_size + lora_vocab
  258. self.org_vocab_size = config.vocab_size
  259. if get_pp_group().is_first_rank or (config.tie_word_embeddings
  260. and get_pp_group().is_last_rank):
  261. self.embed_tokens = VocabParallelEmbedding(
  262. self.vocab_size,
  263. config.hidden_size,
  264. org_num_embeddings=config.vocab_size,
  265. quant_config=quant_config,
  266. )
  267. else:
  268. self.embed_tokens = PPMissingLayer()
  269. self.start_layer, self.end_layer, self.layers = make_layers(
  270. config.num_hidden_layers,
  271. lambda prefix: LlamaDecoderLayer(config=config,
  272. cache_config=cache_config,
  273. quant_config=quant_config,
  274. prefix=prefix),
  275. prefix=f"{prefix}.layers")
  276. if get_pp_group().is_last_rank:
  277. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  278. else:
  279. self.norm = PPMissingLayer()
  280. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  281. return self.embed_tokens(input_ids)
  282. def forward(
  283. self,
  284. input_ids: Optional[torch.Tensor],
  285. positions: torch.Tensor,
  286. kv_caches: List[torch.Tensor],
  287. attn_metadata: AttentionMetadata,
  288. intermediate_tensors: Optional[IntermediateTensors],
  289. inputs_embeds: Optional[torch.Tensor] = None,
  290. ) -> Union[torch.Tensor, IntermediateTensors]:
  291. if get_pp_group().is_first_rank:
  292. if inputs_embeds is not None:
  293. hidden_states = inputs_embeds
  294. else:
  295. hidden_states = self.get_input_embeddings(input_ids)
  296. residual = None
  297. else:
  298. assert intermediate_tensors is not None
  299. hidden_states = intermediate_tensors["hidden_states"]
  300. residual = intermediate_tensors["residual"]
  301. for i in range(self.start_layer, self.end_layer):
  302. layer = self.layers[i]
  303. hidden_states, residual = layer(
  304. positions,
  305. hidden_states,
  306. kv_caches[i - self.start_layer],
  307. attn_metadata,
  308. residual,
  309. )
  310. if not get_pp_group().is_last_rank:
  311. return IntermediateTensors({
  312. "hidden_states": hidden_states,
  313. "residual": residual
  314. })
  315. hidden_states, _ = self.norm(hidden_states, residual)
  316. return hidden_states
  317. class LlamaForCausalLM(nn.Module, SupportsLoRA):
  318. packed_modules_mapping = {
  319. "qkv_proj": [
  320. "q_proj",
  321. "k_proj",
  322. "v_proj",
  323. ],
  324. "gate_up_proj": [
  325. "gate_proj",
  326. "up_proj",
  327. ],
  328. }
  329. # LoRA specific attributes
  330. supported_lora_modules = [
  331. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  332. "lm_head"
  333. ]
  334. embedding_modules = {
  335. "embed_tokens": "input_embeddings",
  336. "lm_head": "output_embeddings",
  337. }
  338. embedding_padding_modules = ["lm_head"]
  339. bitsandbytes_stacked_params_mapping = {
  340. # shard_name, weight_name, index
  341. "q_proj": ("qkv_proj", 0),
  342. "k_proj": ("qkv_proj", 1),
  343. "v_proj": ("qkv_proj", 2),
  344. "gate_proj": ("gate_up_proj", 0),
  345. "up_proj": ("gate_up_proj", 1),
  346. }
  347. # Mistral/Llama models can also be loaded with --load-format mistral
  348. # from consolidated.safetensors checkpoints
  349. mistral_mapping = {
  350. "layers": "model.layers",
  351. "attention": "self_attn",
  352. "wq": "q_proj",
  353. "wk": "k_proj",
  354. "wv": "v_proj",
  355. "wo": "o_proj",
  356. "attention_norm": "input_layernorm",
  357. "feed_forward": "mlp",
  358. "w1": "gate_proj",
  359. "w2": "down_proj",
  360. "w3": "up_proj",
  361. "ffn_norm": "post_attention_layernorm",
  362. "tok_embeddings": "model.embed_tokens",
  363. "output": "lm_head",
  364. "norm": "model.norm"
  365. }
  366. def __init__(
  367. self,
  368. config: LlamaConfig,
  369. cache_config: Optional[CacheConfig] = None,
  370. quant_config: Optional[QuantizationConfig] = None,
  371. lora_config: Optional[LoRAConfig] = None,
  372. ) -> None:
  373. super().__init__()
  374. self.config = config
  375. self.lora_config = lora_config
  376. self.model = LlamaModel(config,
  377. cache_config,
  378. quant_config,
  379. lora_config=lora_config,
  380. prefix="model")
  381. if get_pp_group().is_last_rank:
  382. self.unpadded_vocab_size = config.vocab_size
  383. if lora_config:
  384. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  385. self.lm_head = ParallelLMHead(
  386. self.unpadded_vocab_size,
  387. config.hidden_size,
  388. org_num_embeddings=config.vocab_size,
  389. padding_size=None
  390. # We need bigger padding if using lora for kernel
  391. # compatibility
  392. if not lora_config else lora_config.lora_vocab_padding_size,
  393. quant_config=quant_config,
  394. )
  395. if config.tie_word_embeddings:
  396. self.lm_head.weight = self.model.embed_tokens.weight
  397. logit_scale = getattr(config, "logit_scale", 1.0)
  398. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  399. config.vocab_size,
  400. logit_scale)
  401. self.sampler = Sampler()
  402. else:
  403. self.lm_head = PPMissingLayer()
  404. def forward(
  405. self,
  406. input_ids: torch.Tensor,
  407. positions: torch.Tensor,
  408. kv_caches: List[torch.Tensor],
  409. attn_metadata: AttentionMetadata,
  410. intermediate_tensors: Optional[IntermediateTensors] = None,
  411. ) -> Union[torch.Tensor, IntermediateTensors]:
  412. model_output = self.model(input_ids, positions, kv_caches,
  413. attn_metadata, intermediate_tensors)
  414. return model_output
  415. def compute_logits(
  416. self,
  417. hidden_states: torch.Tensor,
  418. sampling_metadata: SamplingMetadata,
  419. ) -> Optional[torch.Tensor]:
  420. logits = self.logits_processor(self.lm_head, hidden_states,
  421. sampling_metadata)
  422. return logits
  423. def sample(
  424. self,
  425. logits: torch.Tensor,
  426. sampling_metadata: SamplingMetadata,
  427. ) -> Optional[SamplerOutput]:
  428. next_tokens = self.sampler(logits, sampling_metadata)
  429. return next_tokens
  430. def make_empty_intermediate_tensors(
  431. self, batch_size: int, dtype: torch.dtype,
  432. device: torch.device) -> IntermediateTensors:
  433. return IntermediateTensors({
  434. "hidden_states":
  435. torch.zeros((batch_size, self.config.hidden_size),
  436. dtype=dtype,
  437. device=device),
  438. "residual":
  439. torch.zeros((batch_size, self.config.hidden_size),
  440. dtype=dtype,
  441. device=device),
  442. })
  443. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  444. stacked_params_mapping = [
  445. # (param_name, shard_name, shard_id)
  446. (".qkv_proj", ".q_proj", "q"),
  447. (".qkv_proj", ".k_proj", "k"),
  448. (".qkv_proj", ".v_proj", "v"),
  449. (".gate_up_proj", ".gate_proj", 0),
  450. (".gate_up_proj", ".up_proj", 1),
  451. ]
  452. params_dict = dict(self.named_parameters())
  453. for name, loaded_weight in weights:
  454. name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
  455. if "rotary_emb.inv_freq" in name:
  456. continue
  457. if ("rotary_emb.cos_cached" in name
  458. or "rotary_emb.sin_cached" in name):
  459. # Models trained using ColossalAI may include these tensors in
  460. # the checkpoint. Skip them.
  461. continue
  462. # With tie_word_embeddings, we can skip lm_head.weight
  463. # The weight might appear unnecessarily in the files if the model is
  464. # processed with quantization, LoRA, fine-tuning, etc.
  465. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  466. continue
  467. if scale_name := get_compressed_tensors_cache_scale(name):
  468. # Loading kv cache scales for compressed-tensors quantization
  469. param = params_dict[scale_name]
  470. weight_loader = getattr(param, "weight_loader",
  471. default_weight_loader)
  472. loaded_weight = loaded_weight[0]
  473. weight_loader(param, loaded_weight)
  474. continue
  475. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  476. if weight_name not in name:
  477. continue
  478. name = name.replace(weight_name, param_name)
  479. # Skip loading extra bias for GPTQ models.
  480. if name.endswith(".bias") and name not in params_dict:
  481. continue
  482. if is_pp_missing_parameter(name, self):
  483. continue
  484. param = params_dict[name]
  485. weight_loader = param.weight_loader
  486. weight_loader(param, loaded_weight, shard_id)
  487. break
  488. else:
  489. # Skip loading extra bias for GPTQ models.
  490. if name.endswith(".bias") and name not in params_dict:
  491. continue
  492. # Remapping the name of FP8 kv-scale.
  493. name = maybe_remap_kv_scale_name(name, params_dict)
  494. if name is None:
  495. continue
  496. if is_pp_missing_parameter(name, self):
  497. continue
  498. param = params_dict[name]
  499. weight_loader = getattr(param, "weight_loader",
  500. default_weight_loader)
  501. weight_loader(param, loaded_weight)
  502. # If this function is called, it should always initialize KV cache scale
  503. # factors (or else raise an exception). Thus, handled exceptions should
  504. # make sure to leave KV cache scale factors in a known good (dummy) state
  505. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  506. tp_size = get_tensor_model_parallel_world_size()
  507. tp_rank = get_tensor_model_parallel_rank()
  508. for layer_idx, scaling_factor in kv_cache_scales_loader(
  509. quantization_param_path, tp_rank, tp_size,
  510. self.config.num_hidden_layers,
  511. self.config.__class__.model_type):
  512. if not isinstance(self.model.layers[layer_idx], nn.Identity):
  513. layer_self_attn = self.model.layers[layer_idx].self_attn
  514. if is_hip():
  515. # The scaling factor convention we are assuming is
  516. # quantized_value * scaling_factor ~= true_value
  517. # which is consistent with the practice of setting
  518. # scaling_factor = tensor_amax / FPtype_max
  519. scaling_factor *= 2
  520. if hasattr(layer_self_attn, "kv_scale"):
  521. layer_self_attn.attn._kv_scale = scaling_factor
  522. else:
  523. raise RuntimeError("Self attention has no KV cache scaling "
  524. "factor attribute!")
  525. # This function is used to remap the mistral format as
  526. # used by Mistral and Llama <=2
  527. def maybe_remap_mistral(
  528. self, name: str,
  529. loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
  530. def permute(w, n_heads):
  531. attn_in = self.config.head_dim * n_heads
  532. attn_out = self.config.hidden_size
  533. return w.view(n_heads, attn_in // n_heads // 2, 2,
  534. attn_out).transpose(1, 2).reshape(attn_in, attn_out)
  535. mapping = self.mistral_mapping
  536. modules = name.split(".")
  537. # rotary embeds should be sliced
  538. if "wk" in modules:
  539. loaded_weight = permute(loaded_weight,
  540. self.config.num_key_value_heads)
  541. elif "wq" in modules:
  542. loaded_weight = permute(loaded_weight,
  543. self.config.num_attention_heads)
  544. for item in modules:
  545. if item in mapping and mapping[item] not in name:
  546. name = name.replace(item, mapping[item])
  547. return name, loaded_weight