llama.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import is_hip
  32. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  33. get_pp_group,
  34. get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size,
  36. tensor_model_parallel_gather)
  37. from aphrodite.modeling.layers.activation import SiluAndMul
  38. from aphrodite.modeling.layers.layernorm import RMSNorm
  39. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  40. QKVParallelLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
  43. _prune_hidden_states)
  44. from aphrodite.modeling.layers.rotary_embedding import get_rope
  45. from aphrodite.modeling.layers.sampler import Sampler
  46. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  47. ParallelLMHead, VocabParallelEmbedding)
  48. from aphrodite.modeling.model_loader.weight_utils import (
  49. default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
  50. from aphrodite.modeling.models.interfaces import SupportsLoRA
  51. from aphrodite.modeling.models.utils import (PPMissingLayer,
  52. is_pp_missing_parameter,
  53. make_layers)
  54. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  55. from aphrodite.quantization.base_config import QuantizationConfig
  56. from aphrodite.quantization.compressed_tensors.utils import (
  57. get_compressed_tensors_cache_scale)
  58. class LlamaMLP(nn.Module):
  59. def __init__(
  60. self,
  61. hidden_size: int,
  62. intermediate_size: int,
  63. hidden_act: str,
  64. quant_config: Optional[QuantizationConfig] = None,
  65. bias: bool = False,
  66. prefix: str = "",
  67. ) -> None:
  68. super().__init__()
  69. self.gate_up_proj = MergedColumnParallelLinear(
  70. input_size=hidden_size,
  71. output_sizes=[intermediate_size] * 2,
  72. bias=bias,
  73. quant_config=quant_config,
  74. prefix=f"{prefix}.gate_up_proj")
  75. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  76. output_size=hidden_size,
  77. bias=bias,
  78. quant_config=quant_config,
  79. prefix=f"{prefix}.down_proj")
  80. if hidden_act != "silu":
  81. raise ValueError(f"Unsupported activation: {hidden_act}. "
  82. "Only silu is supported for now.")
  83. self.act_fn = SiluAndMul()
  84. def forward(self, x):
  85. gate_up, _ = self.gate_up_proj(x)
  86. x = self.act_fn(gate_up)
  87. x, _ = self.down_proj(x)
  88. return x
  89. class LlamaAttention(nn.Module):
  90. def __init__(
  91. self,
  92. config: LlamaConfig,
  93. hidden_size: int,
  94. num_heads: int,
  95. num_kv_heads: int,
  96. rope_theta: float = 10000,
  97. rope_scaling: Optional[Dict[str, Any]] = None,
  98. max_position_embeddings: int = 8192,
  99. quant_config: Optional[QuantizationConfig] = None,
  100. bias: bool = False,
  101. cache_config: Optional[CacheConfig] = None,
  102. prefix: str = "",
  103. ) -> None:
  104. super().__init__()
  105. self.hidden_size = hidden_size
  106. tp_size = get_tensor_model_parallel_world_size()
  107. tp_rank = get_tensor_model_parallel_rank()
  108. self.total_num_heads = num_heads
  109. self.total_num_kv_heads = num_kv_heads
  110. self.num_kv_heads = max(
  111. 1,
  112. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  113. tp_rank, tp_size))
  114. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  115. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  116. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  117. self.head_dim = getattr(config, "head_dim",
  118. self.hidden_size // self.total_num_heads)
  119. self.q_size = self.num_heads * self.head_dim
  120. self.kv_size = self.num_kv_heads * self.head_dim
  121. self.scaling = self.head_dim**-0.5
  122. self.rope_theta = rope_theta
  123. self.max_position_embeddings = max_position_embeddings
  124. self.qkv_proj = QKVParallelLinear(
  125. hidden_size=hidden_size,
  126. head_size=self.head_dim,
  127. total_num_heads=self.total_num_heads,
  128. total_num_kv_heads=self.total_num_kv_heads,
  129. bias=bias,
  130. quant_config=quant_config,
  131. prefix=f"{prefix}.qkv_proj",
  132. )
  133. self.o_proj = RowParallelLinear(
  134. input_size=self.total_num_heads * self.head_dim,
  135. output_size=hidden_size,
  136. bias=bias,
  137. quant_config=quant_config,
  138. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  139. prefix=f"{prefix}.o_proj",
  140. )
  141. is_neox_style = True
  142. if quant_config is not None and quant_config.get_name() == "gguf":
  143. is_neox_style = False
  144. self.rotary_emb = get_rope(
  145. self.head_dim,
  146. rotary_dim=self.head_dim,
  147. max_position=max_position_embeddings,
  148. base=rope_theta,
  149. rope_scaling=rope_scaling,
  150. is_neox_style=is_neox_style,
  151. )
  152. self.attn = Attention(self.num_heads,
  153. self.head_dim,
  154. self.scaling,
  155. num_kv_heads=self.num_kv_heads,
  156. cache_config=cache_config,
  157. quant_config=quant_config)
  158. def forward(
  159. self,
  160. positions: torch.Tensor,
  161. hidden_states: torch.Tensor,
  162. kv_cache: torch.Tensor,
  163. attn_metadata: AttentionMetadata,
  164. ) -> torch.Tensor:
  165. qkv, _ = self.qkv_proj(hidden_states)
  166. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  167. q, k = self.rotary_emb(positions, q, k)
  168. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  169. output, _ = self.o_proj(attn_output)
  170. return output
  171. class LlamaDecoderLayer(nn.Module):
  172. def __init__(
  173. self,
  174. config: LlamaConfig,
  175. cache_config: Optional[CacheConfig] = None,
  176. quant_config: Optional[QuantizationConfig] = None,
  177. prefix: str = "",
  178. ) -> None:
  179. super().__init__()
  180. self.hidden_size = config.hidden_size
  181. rope_theta = getattr(config, "rope_theta", 10000)
  182. rope_scaling = getattr(config, "rope_scaling", None)
  183. if rope_scaling is not None and getattr(
  184. config, "original_max_position_embeddings", None):
  185. rope_scaling["original_max_position_embeddings"] = (
  186. config.original_max_position_embeddings)
  187. max_position_embeddings = getattr(config, "max_position_embeddings",
  188. 8192)
  189. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  190. # Support internlm/internlm-7b with bias
  191. attention_bias = getattr(config, "attention_bias", False) or getattr(
  192. config, "bias", False)
  193. self.self_attn = LlamaAttention(
  194. config=config,
  195. hidden_size=self.hidden_size,
  196. num_heads=config.num_attention_heads,
  197. num_kv_heads=getattr(config, "num_key_value_heads",
  198. config.num_attention_heads),
  199. rope_theta=rope_theta,
  200. rope_scaling=rope_scaling,
  201. max_position_embeddings=max_position_embeddings,
  202. quant_config=quant_config,
  203. bias=attention_bias,
  204. cache_config=cache_config,
  205. prefix=f"{prefix}.self_attn",
  206. )
  207. self.mlp = LlamaMLP(
  208. hidden_size=self.hidden_size,
  209. intermediate_size=config.intermediate_size,
  210. hidden_act=config.hidden_act,
  211. quant_config=quant_config,
  212. bias=getattr(config, "mlp_bias", False),
  213. prefix=f"{prefix}.mlp",
  214. )
  215. self.input_layernorm = RMSNorm(config.hidden_size,
  216. eps=config.rms_norm_eps)
  217. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  218. eps=config.rms_norm_eps)
  219. def forward(
  220. self,
  221. positions: torch.Tensor,
  222. hidden_states: torch.Tensor,
  223. kv_cache: torch.Tensor,
  224. attn_metadata: AttentionMetadata,
  225. residual: Optional[torch.Tensor],
  226. ) -> Tuple[torch.Tensor, torch.Tensor]:
  227. # Self Attention
  228. if residual is None:
  229. residual = hidden_states
  230. hidden_states = self.input_layernorm(hidden_states)
  231. else:
  232. hidden_states, residual = self.input_layernorm(
  233. hidden_states, residual)
  234. hidden_states = self.self_attn(
  235. positions=positions,
  236. hidden_states=hidden_states,
  237. kv_cache=kv_cache,
  238. attn_metadata=attn_metadata,
  239. )
  240. # Fully Connected
  241. hidden_states, residual = self.post_attention_layernorm(
  242. hidden_states, residual)
  243. hidden_states = self.mlp(hidden_states)
  244. return hidden_states, residual
  245. class LlamaModel(nn.Module):
  246. def __init__(
  247. self,
  248. config: LlamaConfig,
  249. cache_config: Optional[CacheConfig] = None,
  250. quant_config: Optional[QuantizationConfig] = None,
  251. lora_config: Optional[LoRAConfig] = None,
  252. prefix: str = "",
  253. ) -> None:
  254. super().__init__()
  255. self.config = config
  256. self.padding_idx = config.pad_token_id
  257. lora_vocab = (lora_config.lora_extra_vocab_size *
  258. (lora_config.max_loras or 1)) if lora_config else 0
  259. self.vocab_size = config.vocab_size + lora_vocab
  260. self.org_vocab_size = config.vocab_size
  261. if get_pp_group().is_first_rank or (config.tie_word_embeddings
  262. and get_pp_group().is_last_rank):
  263. self.embed_tokens = VocabParallelEmbedding(
  264. self.vocab_size,
  265. config.hidden_size,
  266. org_num_embeddings=config.vocab_size,
  267. quant_config=quant_config,
  268. )
  269. else:
  270. self.embed_tokens = PPMissingLayer()
  271. self.start_layer, self.end_layer, self.layers = make_layers(
  272. config.num_hidden_layers,
  273. lambda prefix: LlamaDecoderLayer(config=config,
  274. cache_config=cache_config,
  275. quant_config=quant_config,
  276. prefix=prefix),
  277. prefix=f"{prefix}.layers")
  278. if get_pp_group().is_last_rank:
  279. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  280. else:
  281. self.norm = PPMissingLayer()
  282. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  283. return self.embed_tokens(input_ids)
  284. def forward(
  285. self,
  286. input_ids: Optional[torch.Tensor],
  287. positions: torch.Tensor,
  288. kv_caches: List[torch.Tensor],
  289. attn_metadata: AttentionMetadata,
  290. intermediate_tensors: Optional[IntermediateTensors],
  291. inputs_embeds: Optional[torch.Tensor] = None,
  292. ) -> Union[torch.Tensor, IntermediateTensors]:
  293. if get_pp_group().is_first_rank:
  294. if inputs_embeds is not None:
  295. hidden_states = inputs_embeds
  296. else:
  297. hidden_states = self.get_input_embeddings(input_ids)
  298. residual = None
  299. else:
  300. assert intermediate_tensors is not None
  301. hidden_states = intermediate_tensors["hidden_states"]
  302. residual = intermediate_tensors["residual"]
  303. for i in range(self.start_layer, self.end_layer):
  304. layer = self.layers[i]
  305. hidden_states, residual = layer(
  306. positions,
  307. hidden_states,
  308. kv_caches[i - self.start_layer],
  309. attn_metadata,
  310. residual,
  311. )
  312. if not get_pp_group().is_last_rank:
  313. return IntermediateTensors({
  314. "hidden_states": hidden_states,
  315. "residual": residual
  316. })
  317. hidden_states, _ = self.norm(hidden_states, residual)
  318. return hidden_states
  319. class LlamaForCausalLM(nn.Module, SupportsLoRA):
  320. packed_modules_mapping = {
  321. "qkv_proj": [
  322. "q_proj",
  323. "k_proj",
  324. "v_proj",
  325. ],
  326. "gate_up_proj": [
  327. "gate_proj",
  328. "up_proj",
  329. ],
  330. }
  331. # LoRA specific attributes
  332. supported_lora_modules = [
  333. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  334. "lm_head"
  335. ]
  336. embedding_modules = {
  337. "embed_tokens": "input_embeddings",
  338. "lm_head": "output_embeddings",
  339. }
  340. embedding_padding_modules = ["lm_head"]
  341. bitsandbytes_stacked_params_mapping = {
  342. # shard_name, weight_name, index
  343. "q_proj": ("qkv_proj", 0),
  344. "k_proj": ("qkv_proj", 1),
  345. "v_proj": ("qkv_proj", 2),
  346. "gate_proj": ("gate_up_proj", 0),
  347. "up_proj": ("gate_up_proj", 1),
  348. }
  349. # Mistral/Llama models can also be loaded with --load-format mistral
  350. # from consolidated.safetensors checkpoints
  351. mistral_mapping = {
  352. "layers": "model.layers",
  353. "attention": "self_attn",
  354. "wq": "q_proj",
  355. "wk": "k_proj",
  356. "wv": "v_proj",
  357. "wo": "o_proj",
  358. "attention_norm": "input_layernorm",
  359. "feed_forward": "mlp",
  360. "w1": "gate_proj",
  361. "w2": "down_proj",
  362. "w3": "up_proj",
  363. "ffn_norm": "post_attention_layernorm",
  364. "tok_embeddings": "model.embed_tokens",
  365. "output": "lm_head",
  366. "norm": "model.norm"
  367. }
  368. def __init__(
  369. self,
  370. config: LlamaConfig,
  371. cache_config: Optional[CacheConfig] = None,
  372. quant_config: Optional[QuantizationConfig] = None,
  373. lora_config: Optional[LoRAConfig] = None,
  374. ) -> None:
  375. super().__init__()
  376. self.config = config
  377. self.lora_config = lora_config
  378. self.model = LlamaModel(config,
  379. cache_config,
  380. quant_config,
  381. lora_config=lora_config,
  382. prefix="model")
  383. if get_pp_group().is_last_rank:
  384. self.unpadded_vocab_size = config.vocab_size
  385. if lora_config:
  386. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  387. self.lm_head = ParallelLMHead(
  388. self.unpadded_vocab_size,
  389. config.hidden_size,
  390. org_num_embeddings=config.vocab_size,
  391. padding_size=None
  392. # We need bigger padding if using lora for kernel
  393. # compatibility
  394. if not lora_config else lora_config.lora_vocab_padding_size,
  395. quant_config=quant_config,
  396. )
  397. if config.tie_word_embeddings:
  398. self.lm_head.weight = self.model.embed_tokens.weight
  399. logit_scale = getattr(config, "logit_scale", 1.0)
  400. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  401. config.vocab_size,
  402. logit_scale,
  403. logits_as_input=True)
  404. self.org_vocab_size = config.vocab_size
  405. self.sampler = Sampler()
  406. else:
  407. self.lm_head = PPMissingLayer()
  408. def forward(
  409. self,
  410. input_ids: torch.Tensor,
  411. positions: torch.Tensor,
  412. kv_caches: List[torch.Tensor],
  413. attn_metadata: AttentionMetadata,
  414. intermediate_tensors: Optional[IntermediateTensors] = None,
  415. ) -> Union[torch.Tensor, IntermediateTensors]:
  416. model_output = self.model(input_ids, positions, kv_caches,
  417. attn_metadata, intermediate_tensors)
  418. return model_output
  419. def _get_logits(self,
  420. hidden_states: torch.Tensor,
  421. sampling_metadata: SamplingMetadata,
  422. ) -> torch.Tensor:
  423. hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
  424. # Get the logits for the next tokens.
  425. logits = self.lm_head.linear_method.apply(
  426. self.lm_head,
  427. hidden_states,
  428. bias=None,
  429. )
  430. logits = tensor_model_parallel_gather(logits)
  431. # Remove paddings in vocab (if any).
  432. if logits is not None:
  433. logits = logits[:, :self.org_vocab_size]
  434. return logits
  435. def compute_logits(
  436. self,
  437. hidden_states: torch.Tensor,
  438. sampling_metadata: SamplingMetadata,
  439. ) -> Optional[torch.Tensor]:
  440. logits = self.logits_processor(self.lm_head, hidden_states,
  441. sampling_metadata)
  442. return logits
  443. def sample(
  444. self,
  445. logits: torch.Tensor,
  446. sampling_metadata: SamplingMetadata,
  447. ) -> Optional[SamplerOutput]:
  448. next_tokens = self.sampler(logits, sampling_metadata)
  449. return next_tokens
  450. def make_empty_intermediate_tensors(
  451. self, batch_size: int, dtype: torch.dtype,
  452. device: torch.device) -> IntermediateTensors:
  453. return IntermediateTensors({
  454. "hidden_states":
  455. torch.zeros((batch_size, self.config.hidden_size),
  456. dtype=dtype,
  457. device=device),
  458. "residual":
  459. torch.zeros((batch_size, self.config.hidden_size),
  460. dtype=dtype,
  461. device=device),
  462. })
  463. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  464. stacked_params_mapping = [
  465. # (param_name, shard_name, shard_id)
  466. (".qkv_proj", ".q_proj", "q"),
  467. (".qkv_proj", ".k_proj", "k"),
  468. (".qkv_proj", ".v_proj", "v"),
  469. (".gate_up_proj", ".gate_proj", 0),
  470. (".gate_up_proj", ".up_proj", 1),
  471. ]
  472. params_dict = dict(self.named_parameters())
  473. for name, loaded_weight in weights:
  474. name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
  475. if "rotary_emb.inv_freq" in name:
  476. continue
  477. if ("rotary_emb.cos_cached" in name
  478. or "rotary_emb.sin_cached" in name):
  479. # Models trained using ColossalAI may include these tensors in
  480. # the checkpoint. Skip them.
  481. continue
  482. # With tie_word_embeddings, we can skip lm_head.weight
  483. # The weight might appear unnecessarily in the files if the model is
  484. # processed with quantization, LoRA, fine-tuning, etc.
  485. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  486. continue
  487. if scale_name := get_compressed_tensors_cache_scale(name):
  488. # Loading kv cache scales for compressed-tensors quantization
  489. param = params_dict[scale_name]
  490. weight_loader = getattr(param, "weight_loader",
  491. default_weight_loader)
  492. loaded_weight = loaded_weight[0]
  493. weight_loader(param, loaded_weight)
  494. continue
  495. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  496. if weight_name not in name:
  497. continue
  498. name = name.replace(weight_name, param_name)
  499. # Skip loading extra bias for GPTQ models.
  500. if name.endswith(".bias") and name not in params_dict:
  501. continue
  502. if is_pp_missing_parameter(name, self):
  503. continue
  504. param = params_dict[name]
  505. weight_loader = param.weight_loader
  506. weight_loader(param, loaded_weight, shard_id)
  507. break
  508. else:
  509. # Skip loading extra bias for GPTQ models.
  510. if name.endswith(".bias") and name not in params_dict:
  511. continue
  512. # Remapping the name of FP8 kv-scale.
  513. name = maybe_remap_kv_scale_name(name, params_dict)
  514. if name is None:
  515. continue
  516. if is_pp_missing_parameter(name, self):
  517. continue
  518. param = params_dict[name]
  519. weight_loader = getattr(param, "weight_loader",
  520. default_weight_loader)
  521. weight_loader(param, loaded_weight)
  522. # If this function is called, it should always initialize KV cache scale
  523. # factors (or else raise an exception). Thus, handled exceptions should
  524. # make sure to leave KV cache scale factors in a known good (dummy) state
  525. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  526. tp_size = get_tensor_model_parallel_world_size()
  527. tp_rank = get_tensor_model_parallel_rank()
  528. for layer_idx, scaling_factor in kv_cache_scales_loader(
  529. quantization_param_path, tp_rank, tp_size,
  530. self.config.num_hidden_layers,
  531. self.config.__class__.model_type):
  532. if not isinstance(self.model.layers[layer_idx], nn.Identity):
  533. layer_self_attn = self.model.layers[layer_idx].self_attn
  534. if is_hip():
  535. # The scaling factor convention we are assuming is
  536. # quantized_value * scaling_factor ~= true_value
  537. # which is consistent with the practice of setting
  538. # scaling_factor = tensor_amax / FPtype_max
  539. scaling_factor *= 2
  540. if hasattr(layer_self_attn, "kv_scale"):
  541. layer_self_attn.attn._kv_scale = scaling_factor
  542. else:
  543. raise RuntimeError("Self attention has no KV cache scaling "
  544. "factor attribute!")
  545. # This function is used to remap the mistral format as
  546. # used by Mistral and Llama <=2
  547. def maybe_remap_mistral(
  548. self, name: str,
  549. loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
  550. def permute(w, n_heads):
  551. attn_in = self.config.head_dim * n_heads
  552. attn_out = self.config.hidden_size
  553. return w.view(n_heads, attn_in // n_heads // 2, 2,
  554. attn_out).transpose(1, 2).reshape(attn_in, attn_out)
  555. mapping = self.mistral_mapping
  556. modules = name.split(".")
  557. # rotary embeds should be sliced
  558. if "wk" in modules:
  559. loaded_weight = permute(loaded_weight,
  560. self.config.num_key_value_heads)
  561. elif "wq" in modules:
  562. loaded_weight = permute(loaded_weight,
  563. self.config.num_attention_heads)
  564. for item in modules:
  565. if item in mapping and mapping[item] not in name:
  566. name = name.replace(item, mapping[item])
  567. return name, loaded_weight