# coding=utf-8
# Adapted from
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
# Copyright 2024 The LG U+ CTO AI Tech Lab.
# Copyright 2021 The LG AI Research EXAONE Lab
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import is_hip
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.models.utils import (PPMissingLayer,
                                             is_pp_missing_parameter,
                                             make_layers)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)


class ExaoneGatedMLP(nn.Module):
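    """Gated (SwiGLU-style) feed-forward block: a merged gate/up projection,
    SiluAndMul activation, and a row-parallel down projection (``c_proj``)."""
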
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.c_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class ExaoneAttention(nn.Module):
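    """Self-attention with rotary position embeddings and grouped-query
    attention, using tensor-parallel QKV/output projections and the
    Aphrodite ``Attention`` backend."""
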
    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class ExaoneBlockAttention(nn.Module):
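    """Wrapper that nests ExaoneAttention under an ``attention`` attribute so
    that parameter names line up with the HuggingFace EXAONE checkpoint
    layout."""
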
    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.attention = ExaoneAttention(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            cache_config=cache_config,
            prefix=prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        return self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )


class ExaoneDecoderLayer(nn.Module):
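    """A single EXAONE transformer block: pre-norm self-attention (``ln_1``)
    followed by a pre-norm gated MLP (``ln_2``), with fused residual adds."""
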
    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.attn = ExaoneBlockAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = ExaoneGatedMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.activation_function,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class ExaoneModel(nn.Module):
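    """The EXAONE transformer trunk: token embeddings (``wte``), a stack of
    decoder layers (``h``), and a final ``ln_f`` RMSNorm, with each
    pipeline-parallel rank holding only its own slice of layers."""
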
    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.wte = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.wte = PPMissingLayer()
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: ExaoneDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.h",
        )
        if get_pp_group().is_last_rank:
            self.ln_f = RMSNorm(config.hidden_size,
                                eps=config.layer_norm_epsilon)
        else:
            self.ln_f = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class ExaoneForCausalLM(nn.Module, SupportsLoRA):
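    """EXAONE for causal language modeling: wraps ExaoneModel and adds the
    (optionally tied) LM head, logits processor, and sampler on the last
    pipeline-parallel rank. LoRA support is declared via ``SupportsLoRA``."""
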
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "c_fc_0",
            "c_fc_1",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "out_proj",
        "gate_up_proj",
        "c_proj",
        "wte",
        "lm_head",
    ]
    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "c_fc_0": ("gate_up_proj", 0),
        "c_fc_1": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.transformer = ExaoneModel(
            config,
            cache_config,
            quant_config,
            lora_config=lora_config,
            prefix="model",
        )
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head.weight = self.transformer.wte.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.transformer(input_ids, positions, kv_caches,
                                        attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros(
                (batch_size, self.config.hidden_size),
                dtype=dtype,
                device=device,
            ),
            "residual":
            torch.zeros(
                (batch_size, self.config.hidden_size),
                dtype=dtype,
                device=device,
            ),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
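        """Load HuggingFace-format weights, merging per-projection checkpoint
        tensors (q/k/v and c_fc_0/c_fc_1) into the stacked ``qkv_proj`` and
        ``gate_up_proj`` parameters, and skipping or remapping auxiliary
        tensors such as rotary caches and KV-cache scales."""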
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".c_fc_0", 0),
            (".gate_up_proj", ".c_fc_1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
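        """Read per-layer KV-cache scaling factors from the given quantization
        parameter file and attach them to each layer's attention module."""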
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path,
                tp_rank,
                tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type,
        ):
            if not isinstance(self.transformer.h[layer_idx], nn.Identity):
                layer_self_attn = self.transformer.h[layer_idx].attn
            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
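

# Illustrative usage sketch (not part of the model definition). This assumes
# the Aphrodite engine's standard offline `LLM` entrypoint and that this class
# is registered in the model registry under the "ExaoneForCausalLM"
# architecture; exact flags may differ between releases.
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
#               trust_remote_code=True)
#     outputs = llm.generate(["Hello, EXAONE!"],
#                            SamplingParams(max_tokens=32))
#     print(outputs[0].outputs[0].text)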