hunyuan.py

# coding=utf-8
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
# (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""

import re
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import is_hip
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

class HunYuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj",
                                           reduce_results=reduce_results)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

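# Mixture-of-experts block: a ReplicatedLinear gate produces router logits
# and FusedMoE dispatches each token to its top-k experts. When
# `use_mixed_mlp_moe` is set, a shared HunYuanMLP additionally processes
# every token and its output is added to the routed experts' output before
# the tensor-parallel all-reduce.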
class HunYuanSparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.experts = FusedMoE(num_experts=config.num_experts,
                                top_k=config.moe_topk,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=False,
                                renormalize=config.moe_topk > 1,
                                quant_config=quant_config)
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None)
        if config.use_mixed_mlp_moe > 0:
            self.shared_mlp = HunYuanMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size *
                config.num_shared_expert,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )
        else:
            self.shared_mlp = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_mlp is not None:
            shared_output = self.shared_mlp(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)

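# Attention layer with two modes. "self" layers project Q, K and V from the
# hidden states; "cross" layers (used for CLA-style cross-layer KV sharing)
# only project Q and reuse the K/V states passed in from an earlier layer.
# When `use_qk_norm` is set, per-head RMSNorm is applied to Q and K.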
class HunYuanAttention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attention_type: str = "self",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.use_qk_norm = config.use_qk_norm
        self.attention_type = attention_type

        if attention_type == "self":
            self.qkv_proj = QKVParallelLinear(
                hidden_size=hidden_size,
                head_size=self.head_dim,
                total_num_heads=self.total_num_heads,
                total_num_kv_heads=self.total_num_kv_heads,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        elif attention_type == "cross":
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                hidden_size,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        else:
            raise RuntimeError(
                f"Unsupported attention type: {attention_type}")
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim,
                                           eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim,
                                         eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.attention_type == "self":
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
            q, k = self.rotary_emb(positions, q, k)
            ori_k = k
            if self.use_qk_norm:
                q = self.query_layernorm(q.view(-1, self.num_heads,
                                                self.head_dim).contiguous())
                k = self.key_layernorm(k.view(-1, self.num_kv_heads,
                                              self.head_dim).contiguous())
        elif self.attention_type == "cross":
            assert kv_states is not None
            # Reuse the KV states shared by the previous layer.
            ori_k, v = kv_states
            k = ori_k
            q, _ = self.q_proj(hidden_states)
            k_tmp = torch.empty_like(k)  # TODO: redundant rotary embedding
            q, _ = self.rotary_emb(positions, q, k_tmp)
            if self.use_qk_norm:
                q = self.query_layernorm(q.view(-1, self.num_heads,
                                                self.head_dim).contiguous())
                k = self.key_layernorm(k.view(-1, self.num_kv_heads,
                                              self.head_dim).contiguous())
        else:
            raise RuntimeError(
                f"Unsupported attention type: {self.attention_type}")

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output, (ori_k, v)

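# Decoder layer: pre-norm attention followed by either a dense HunYuanMLP or
# a HunYuanSparseMoeBlock (when `num_experts` is configured). Layers whose
# index is not a multiple of `cla_share_factor` are built as "cross"
# attention layers and reuse the KV states of the preceding "self" layer.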
class HunYuanDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_id: int = -1,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        cla_factor = getattr(config, "cla_share_factor", 1)
        attention_type = "cross" \
            if layer_id >= 0 and layer_id % cla_factor != 0 else "self"
        self.self_attn = HunYuanAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attention_type=attention_type,
        )
        if getattr(config, "num_experts", None):
            self.mlp = HunYuanSparseMoeBlock(config=config,
                                             quant_config=quant_config)
        else:
            self.mlp = HunYuanMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        kv_states: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor,
               Tuple[torch.Tensor, torch.Tensor]]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states, ori_kv_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            kv_states=kv_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual, ori_kv_states

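# Transformer stack with pipeline-parallel support: ranks that do not own
# the embeddings or the final norm hold PPMissingLayer placeholders. The
# forward pass keeps the KV states produced by each "self" layer so that the
# following cross-attention (CLA) layer can reuse them.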
class HunYuanModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: HunYuanDecoderLayer(config=config,
                                               layer_id=int(
                                                   prefix.split(".")[-1]),
                                               cache_config=cache_config,
                                               quant_config=quant_config,
                                               prefix=prefix),
            prefix=f"{prefix}.layers")
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        cla_factor = getattr(self.config, "cla_share_factor", 1)
        prev_kv_states = None
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual, kv_states = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                # kv_caches[(i - self.start_layer) // cla_factor],
                attn_metadata,
                residual,
                prev_kv_states,
            )

            if (i - self.start_layer) % cla_factor == 0:
                prev_kv_states = kv_states
            else:
                prev_kv_states = None

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

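# Top-level causal LM: wraps HunYuanModel with an (optionally weight-tied)
# ParallelLMHead, logits processing and sampling on the last pipeline rank,
# plus the LoRA / BitsAndBytes parameter mappings used by the loader.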
class HunYuanForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = HunYuanModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config,
                                  prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

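    # Checkpoint loading: q/k/v and gate/up shards are fused into the packed
    # qkv_proj / gate_up_proj parameters, expert weights go through the
    # FusedMoE expert mapping, and everything else falls back to
    # default_weight_loader. Cross-attention (CLA) layers only ship q_proj,
    # so their weights skip the QKV packing path.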
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        cla_factor = getattr(self.config, "cla_share_factor", 1)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if getattr(self.config, "num_experts", None):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.config.num_experts)
        else:
            expert_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors
                # in the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight.
            # The weight might appear unnecessarily in the files if the
            # model is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                # Cross layers only have q_proj; skip the QKV packing.
                if weight_name == ".q_proj":
                    match = re.search(r'layers\.\d+', name)
                    if match:
                        layer_id = int(match.group(0).split('.')[-1])
                        if cla_factor > 1 and layer_id % cla_factor != 0:
                            continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    if "mlp.gate.wg." in name:
                        name = name.replace("wg.", "")
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            if not isinstance(self.model.layers[layer_idx], nn.Identity):
                layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")