bitnet.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/modeling_bitnet.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Bitnet model compatible with HuggingFace weights."""
  25. # ruff: noqa: E501
  26. from typing import Dict, Iterable, List, Optional, Tuple
  27. import torch
  28. from torch import nn
  29. from transformers.configuration_utils import PretrainedConfig
  30. from loguru import logger
  31. from aphrodite.attention import Attention, AttentionMetadata
  32. from aphrodite.common.config import CacheConfig
  33. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  34. get_tensor_model_parallel_world_size)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  38. QKVParallelLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.quantization.base_config import (
  42. QuantizationConfig)
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import (
  48. default_weight_loader, kv_cache_scales_loader)
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.common.sequence import SamplerOutput
  51. from aphrodite.common.utils import is_hip, print_warning_once
  52. class BitnetConfig(PretrainedConfig):
  53. model_type = "llama"
  54. keys_to_ignore_at_inference = ["past_key_values"]
  55. def __init__(
  56. self,
  57. vocab_size=32000,
  58. hidden_size=4096,
  59. intermediate_size=11008,
  60. num_hidden_layers=32,
  61. num_attention_heads=32,
  62. num_key_value_heads=None,
  63. hidden_act="silu",
  64. max_position_embeddings=2048,
  65. initializer_range=0.02,
  66. rms_norm_eps=1e-6,
  67. use_cache=True,
  68. pad_token_id=None,
  69. bos_token_id=1,
  70. eos_token_id=2,
  71. pretraining_tp=1,
  72. tie_word_embeddings=False,
  73. rope_theta=10000.0,
  74. rope_scaling=None,
  75. attention_bias=False,
  76. attention_dropout=0.0,
  77. weight_bits=1,
  78. input_bits=8,
  79. **kwargs,
  80. ):
  81. self.vocab_size = vocab_size
  82. self.max_position_embeddings = max_position_embeddings
  83. self.hidden_size = hidden_size
  84. self.intermediate_size = intermediate_size
  85. self.num_hidden_layers = num_hidden_layers
  86. self.num_attention_heads = num_attention_heads
  87. # for backward compatibility
  88. if num_key_value_heads is None:
  89. num_key_value_heads = num_attention_heads
  90. self.num_key_value_heads = num_key_value_heads
  91. self.hidden_act = hidden_act
  92. self.initializer_range = initializer_range
  93. self.rms_norm_eps = rms_norm_eps
  94. self.pretraining_tp = pretraining_tp
  95. self.use_cache = use_cache
  96. self.rope_theta = rope_theta
  97. self.rope_scaling = rope_scaling
  98. self._rope_scaling_validation()
  99. self.attention_bias = attention_bias
  100. self.attention_dropout = attention_dropout
  101. self.weight_bits = weight_bits
  102. self.input_bits = input_bits
  103. super().__init__(
  104. pad_token_id=pad_token_id,
  105. bos_token_id=bos_token_id,
  106. eos_token_id=eos_token_id,
  107. tie_word_embeddings=tie_word_embeddings,
  108. **kwargs,
  109. )
  110. def _rope_scaling_validation(self):
  111. """
  112. Validate the `rope_scaling` configuration.
  113. """
  114. if self.rope_scaling is None:
  115. return
  116. if not isinstance(self.rope_scaling,
  117. dict) or len(self.rope_scaling) != 2:
  118. raise ValueError(
  119. "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
  120. f"got {self.rope_scaling}")
  121. rope_scaling_type = self.rope_scaling.get("type", None)
  122. rope_scaling_factor = self.rope_scaling.get("factor", None)
  123. if rope_scaling_type is None or rope_scaling_type not in [
  124. "linear", "dynamic"
  125. ]:
  126. raise ValueError(
  127. f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
  128. )
  129. if (rope_scaling_factor is None
  130. or not isinstance(rope_scaling_factor, float)
  131. or rope_scaling_factor <= 1.0):
  132. raise ValueError(
  133. f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
  134. )
  135. class BitnetMLP(nn.Module):
  136. def __init__(
  137. self,
  138. hidden_size: int,
  139. intermediate_size: int,
  140. hidden_act: str,
  141. quant_config: Optional[QuantizationConfig] = None,
  142. bias: bool = False,
  143. config: BitnetConfig = None,
  144. ) -> None:
  145. super().__init__()
  146. self.gate_up_proj = MergedColumnParallelLinear(
  147. input_size=hidden_size,
  148. output_sizes=[intermediate_size] * 2,
  149. bias=bias,
  150. quant_config=quant_config,
  151. )
  152. self.down_proj = RowParallelLinear(
  153. input_size=intermediate_size,
  154. output_size=hidden_size,
  155. bias=bias,
  156. quant_config=quant_config,
  157. )
  158. if hidden_act != "silu":
  159. raise ValueError(f"Unsupported activation: {hidden_act}. "
  160. "Only silu is supported for now.")
  161. self.act_fn = SiluAndMul()
  162. self.ffn_layernorm = RMSNorm(intermediate_size,
  163. eps=config.rms_norm_eps)
  164. def forward(self, x):
  165. gate_up, _ = self.gate_up_proj(x)
  166. x = self.act_fn(gate_up)
  167. x = self.ffn_layernorm(x)
  168. x, _ = self.down_proj(x)
  169. return x
  170. class BitnetAttention(nn.Module):
  171. def __init__(
  172. self,
  173. hidden_size: int,
  174. num_heads: int,
  175. num_kv_heads: int,
  176. rope_theta: float = 10000,
  177. rope_scaling: Optional[Dict[str, float]] = None,
  178. max_position_embeddings: int = 8192,
  179. quant_config: Optional[QuantizationConfig] = None,
  180. bias: bool = False,
  181. cache_config: Optional[CacheConfig] = None,
  182. config: BitnetConfig = None,
  183. ) -> None:
  184. super().__init__()
  185. self.hidden_size = hidden_size
  186. tp_size = get_tensor_model_parallel_world_size()
  187. self.total_num_heads = num_heads
  188. assert self.total_num_heads % tp_size == 0
  189. self.num_heads = self.total_num_heads // tp_size
  190. self.total_num_kv_heads = num_kv_heads
  191. if self.total_num_kv_heads >= tp_size:
  192. # Number of KV heads is greater than TP size, so we partition
  193. # the KV heads across multiple tensor parallel GPUs.
  194. assert self.total_num_kv_heads % tp_size == 0
  195. else:
  196. # Number of KV heads is less than TP size, so we replicate
  197. # the KV heads across multiple tensor parallel GPUs.
  198. assert tp_size % self.total_num_kv_heads == 0
  199. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  200. self.num_kv_groups = self.num_heads // self.num_kv_heads
  201. self.head_dim = hidden_size // self.total_num_heads
  202. self.padded_head_dim = self.find_flash_attn_supported_head_dims(
  203. self.head_dim)
  204. self.q_size = self.num_heads * self.head_dim
  205. self.kv_size = self.num_kv_heads * self.head_dim
  206. self.scaling = self.head_dim**-0.5
  207. self.rope_theta = rope_theta
  208. self.max_position_embeddings = max_position_embeddings
  209. self.attention_dropout = config.attention_dropout
  210. self.qkv_proj = QKVParallelLinear(
  211. hidden_size=hidden_size,
  212. head_size=self.head_dim,
  213. total_num_heads=self.total_num_heads,
  214. total_num_kv_heads=self.total_num_kv_heads,
  215. bias=bias,
  216. quant_config=quant_config,
  217. )
  218. self.o_proj = RowParallelLinear(
  219. input_size=self.total_num_heads * self.head_dim,
  220. output_size=hidden_size,
  221. bias=bias,
  222. quant_config=quant_config,
  223. )
  224. self.attn = Attention(
  225. self.num_heads,
  226. self.padded_head_dim,
  227. self.scaling,
  228. num_kv_heads=self.num_kv_heads,
  229. cache_config=cache_config,
  230. quant_config=quant_config,
  231. )
  232. self.rotary_emb = get_rope(
  233. self.head_dim,
  234. rotary_dim=self.head_dim,
  235. max_position=max_position_embeddings,
  236. base=rope_theta,
  237. rope_scaling=rope_scaling,
  238. )
  239. self.inner_attn_ln = RMSNorm(config.hidden_size,
  240. eps=config.rms_norm_eps)
  241. def find_flash_attn_supported_head_dims(self, head_dim: int) -> int:
  242. """
  243. Find the closest head dimension to the given head dimension that is supported by Flash Attention.
  244. """
  245. from aphrodite.attention.backends.flash_attn import FlashAttentionBackend
  246. FLASHATTN_SUPPORTED_HEAD_DIMS = (
  247. FlashAttentionBackend.get_supported_head_sizes())
  248. for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS:
  249. if head_dim <= supported_head_dim:
  250. return supported_head_dim
  251. raise ValueError(
  252. f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are "
  253. f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.")
  254. def forward(
  255. self,
  256. positions: torch.Tensor,
  257. hidden_states: torch.Tensor,
  258. kv_cache: Optional[torch.Tensor],
  259. attn_metadata: AttentionMetadata,
  260. ) -> torch.Tensor:
  261. # QKV projection cannot be grouped as the they
  262. # do not share the same scaling factor
  263. qkv, _ = self.qkv_proj(hidden_states)
  264. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  265. q, k = self.rotary_emb(positions, q, k)
  266. # Padding as paged attention doesn't support head_dim == 100
  267. q = torch.nn.functional.pad(
  268. q.view(-1, self.total_num_heads, self.head_dim),
  269. (0, self.padded_head_dim - self.head_dim),
  270. ).view(-1, self.total_num_heads * self.padded_head_dim)
  271. k = torch.nn.functional.pad(
  272. k.view(-1, self.num_kv_heads, self.head_dim),
  273. (0, self.padded_head_dim - self.head_dim),
  274. ).view(-1, self.total_num_kv_heads * self.padded_head_dim)
  275. v = torch.nn.functional.pad(
  276. v.view(-1, self.num_kv_heads, self.head_dim),
  277. (0, self.padded_head_dim - self.head_dim),
  278. ).view(-1, self.total_num_kv_heads * self.padded_head_dim)
  279. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  280. attn_output = attn_output.view(
  281. -1, self.total_num_heads,
  282. self.padded_head_dim)[..., :self.head_dim].reshape(
  283. -1, self.total_num_heads * self.head_dim)
  284. attn_output = self.inner_attn_ln(attn_output)
  285. output, _ = self.o_proj(attn_output)
  286. return output
  287. class BitnetDecoderLayer(nn.Module):
  288. def __init__(
  289. self,
  290. config: BitnetConfig,
  291. cache_config: Optional[CacheConfig] = None,
  292. quant_config: Optional[QuantizationConfig] = None,
  293. ) -> None:
  294. super().__init__()
  295. self.hidden_size = config.hidden_size
  296. rope_theta = getattr(config, "rope_theta", 10000)
  297. rope_scaling = getattr(config, "rope_scaling", None)
  298. if rope_scaling is not None and getattr(
  299. config, "original_max_position_embeddings", None):
  300. rope_scaling["original_max_position_embeddings"] = (
  301. config.original_max_position_embeddings)
  302. max_position_embeddings = getattr(config, "max_position_embeddings",
  303. 8192)
  304. attention_bias = getattr(config, "attention_bias", False) or getattr(
  305. config, "bias", False)
  306. self.self_attn = BitnetAttention(
  307. hidden_size=self.hidden_size,
  308. num_heads=config.num_attention_heads,
  309. num_kv_heads=getattr(config, "num_key_value_heads",
  310. config.num_attention_heads),
  311. rope_theta=rope_theta,
  312. rope_scaling=rope_scaling,
  313. max_position_embeddings=max_position_embeddings,
  314. quant_config=quant_config,
  315. bias=attention_bias,
  316. cache_config=cache_config,
  317. config=config,
  318. )
  319. self.mlp = BitnetMLP(
  320. hidden_size=self.hidden_size,
  321. intermediate_size=config.intermediate_size,
  322. hidden_act=config.hidden_act,
  323. quant_config=quant_config,
  324. bias=getattr(config, "mlp_bias", False),
  325. config=config,
  326. )
  327. self.input_layernorm = RMSNorm(config.hidden_size,
  328. eps=config.rms_norm_eps)
  329. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  330. eps=config.rms_norm_eps)
  331. def forward(
  332. self,
  333. positions: torch.Tensor,
  334. hidden_states: torch.Tensor,
  335. kv_cache: torch.Tensor,
  336. attn_metadata: AttentionMetadata,
  337. residual: Optional[torch.Tensor],
  338. ) -> Tuple[torch.Tensor, torch.Tensor]:
  339. # Self Attention
  340. if residual is None:
  341. residual = hidden_states
  342. hidden_states = self.input_layernorm(hidden_states)
  343. else:
  344. hidden_states, residual = self.input_layernorm(
  345. hidden_states, residual)
  346. hidden_states = self.self_attn(
  347. positions=positions,
  348. hidden_states=hidden_states,
  349. kv_cache=kv_cache,
  350. attn_metadata=attn_metadata,
  351. )
  352. # Fully Connected
  353. hidden_states, residual = self.post_attention_layernorm(
  354. hidden_states, residual)
  355. hidden_states = self.mlp(hidden_states)
  356. return hidden_states, residual
  357. class BitnetModel(nn.Module):
  358. def __init__(
  359. self,
  360. config: BitnetConfig,
  361. cache_config: Optional[CacheConfig] = None,
  362. quant_config: Optional[QuantizationConfig] = None,
  363. ) -> None:
  364. super().__init__()
  365. self.config = config
  366. self.padding_idx = config.pad_token_id
  367. self.vocab_size = config.vocab_size
  368. self.org_vocab_size = config.vocab_size
  369. self.embed_tokens = VocabParallelEmbedding(
  370. self.vocab_size,
  371. config.hidden_size,
  372. org_num_embeddings=config.vocab_size,
  373. )
  374. self.layers = nn.ModuleList([
  375. BitnetDecoderLayer(config=config,
  376. cache_config=cache_config,
  377. quant_config=quant_config)
  378. for _ in range(config.num_hidden_layers)
  379. ])
  380. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  381. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  382. return self.embed_tokens(input_ids)
  383. def forward(
  384. self,
  385. input_ids: Optional[torch.Tensor],
  386. positions: torch.Tensor,
  387. kv_caches: List[torch.Tensor],
  388. attn_metadata: AttentionMetadata,
  389. inputs_embeds: Optional[torch.Tensor] = None,
  390. ) -> torch.Tensor:
  391. if inputs_embeds is not None:
  392. hidden_states = inputs_embeds
  393. else:
  394. hidden_states = self.get_input_embeddings(input_ids)
  395. residual = None
  396. for i in range(len(self.layers)):
  397. layer = self.layers[i]
  398. hidden_states, residual = layer(
  399. positions,
  400. hidden_states,
  401. kv_caches[i],
  402. attn_metadata,
  403. residual,
  404. )
  405. hidden_states, _ = self.norm(hidden_states, residual)
  406. return hidden_states
  407. class BitnetForCausalLM(nn.Module):
  408. packed_modules_mapping = {
  409. "qkv_proj": [
  410. "q_proj",
  411. "k_proj",
  412. "v_proj",
  413. ],
  414. "gate_up_proj": [
  415. "gate_proj",
  416. "up_proj",
  417. ],
  418. }
  419. embedding_modules = {
  420. "embed_tokens": "input_embeddings",
  421. "lm_head": "output_embeddings",
  422. }
  423. embedding_padding_modules = ["lm_head"]
  424. def __init__(
  425. self,
  426. config: BitnetConfig,
  427. cache_config: Optional[CacheConfig] = None,
  428. quant_config: Optional[QuantizationConfig] = None,
  429. ) -> None:
  430. super().__init__()
  431. self.config = config
  432. self.model = BitnetModel(config, cache_config, quant_config)
  433. self.unpadded_vocab_size = config.vocab_size
  434. self.lm_head = ParallelLMHead(
  435. self.unpadded_vocab_size,
  436. config.hidden_size,
  437. org_num_embeddings=config.vocab_size,
  438. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  439. )
  440. if config.tie_word_embeddings:
  441. self.lm_head.weight = self.model.embed_tokens.weight
  442. logit_scale = getattr(config, "logit_scale", 1.0)
  443. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  444. config.vocab_size,
  445. logit_scale)
  446. self.sampler = Sampler()
  447. def forward(
  448. self,
  449. input_ids: torch.Tensor,
  450. positions: torch.Tensor,
  451. kv_caches: List[torch.Tensor],
  452. attn_metadata: AttentionMetadata,
  453. ) -> torch.Tensor:
  454. hidden_states = self.model(input_ids, positions, kv_caches,
  455. attn_metadata)
  456. return hidden_states
  457. def compute_logits(self, hidden_states: torch.Tensor,
  458. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  459. logits = self.logits_processor(self.lm_head, hidden_states,
  460. sampling_metadata)
  461. return logits
  462. def sample(
  463. self,
  464. logits: torch.Tensor,
  465. sampling_metadata: SamplingMetadata,
  466. ) -> Optional[SamplerOutput]:
  467. next_tokens = self.sampler(logits, sampling_metadata)
  468. return next_tokens
  469. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  470. stacked_params_mapping = [
  471. # (param_name, shard_name, shard_id)
  472. (".qkv_proj", ".q_proj", "q"),
  473. (".qkv_proj", ".k_proj", "k"),
  474. (".qkv_proj", ".v_proj", "v"),
  475. (".gate_up_proj", ".gate_proj", 0),
  476. (".gate_up_proj", ".up_proj", 1),
  477. ]
  478. params_dict = dict(self.named_parameters())
  479. for name, loaded_weight in weights:
  480. if "rotary_emb.inv_freq" in name:
  481. continue
  482. if ("rotary_emb.cos_cached" in name
  483. or "rotary_emb.sin_cached" in name):
  484. # Models trained using ColossalAI may include these tensors in
  485. # the checkpoint. Skip them.
  486. continue
  487. for param_name, weight_name, shard_id in stacked_params_mapping:
  488. if weight_name not in name:
  489. continue
  490. name = name.replace(weight_name, param_name)
  491. # Skip loading extra bias for GPTQ models.
  492. if name.endswith(".bias") and name not in params_dict:
  493. continue
  494. param = params_dict[name]
  495. weight_loader = param.weight_loader
  496. # align scaling attr with param
  497. if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0:
  498. loaded_weight = loaded_weight.view(param.data.shape)
  499. weight_loader(param, loaded_weight, shard_id)
  500. break
  501. else:
  502. # Skip loading extra bias for GPTQ models.
  503. if name.endswith(".bias") and name not in params_dict:
  504. continue
  505. # Remapping the name of FP8 kv-scale.
  506. if name.endswith("kv_scale"):
  507. remapped_kv_scale_name = name.replace(
  508. ".kv_scale", ".attn.kv_scale")
  509. if remapped_kv_scale_name not in params_dict:
  510. print_warning_once(
  511. f"Found kv scale in the checkpoint (e.g. {name}), "
  512. "but not found the expected name in the model "
  513. f"(e.g. {remapped_kv_scale_name}). kv-scale is "
  514. "not loaded.")
  515. continue
  516. else:
  517. name = remapped_kv_scale_name
  518. param = params_dict[name]
  519. weight_loader = getattr(param, "weight_loader",
  520. default_weight_loader)
  521. # align scaling attr with param
  522. if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0:
  523. loaded_weight = loaded_weight.view(param.data.shape)
  524. weight_loader(param, loaded_weight)
  525. # If this function is called, it should always initialize KV cache scale
  526. # factors (or else raise an exception). Thus, handled exceptions should
  527. # make sure to leave KV cache scale factors in a known good (dummy) state
  528. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  529. tp_size = get_tensor_model_parallel_world_size()
  530. tp_rank = get_tensor_model_parallel_rank()
  531. for layer_idx, scaling_factor in kv_cache_scales_loader(
  532. quantization_param_path,
  533. tp_rank,
  534. tp_size,
  535. self.config.num_hidden_layers,
  536. self.config.__class__.model_type,
  537. ):
  538. layer_self_attn = self.model.layers[layer_idx].self_attn
  539. if is_hip():
  540. # The scaling factor convention we are assuming is
  541. # quantized_value * scaling_factor ~= true_value
  542. # which is consistent with the practice of setting
  543. # scaling_factor = tensor_amax / FPtype_max
  544. scaling_factor *= 2
  545. if hasattr(layer_self_attn, "kv_scale"):
  546. layer_self_attn.attn._kv_scale = scaling_factor
  547. else:
  548. raise RuntimeError("Self attention has no KV cache scaling "
  549. "factor attribute!")