# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import is_hip
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.models.utils import (PPMissingLayer,
                                             is_pp_missing_parameter,
                                             make_layers)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)


class SolarMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class SolarAttention(nn.Module):

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
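        # For example, with 8 KV heads and tp_size=4 each rank keeps 2 KV
        # heads, while with 2 KV heads and tp_size=4 each KV head is shared
        # by 2 ranks and num_kv_heads stays at 1 per rank.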
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
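        # Fused QKV projection, per-rank Q/K/V split, rotary position
        # embedding, attention over the KV cache, then the output
        # projection back to hidden_size.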
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class SolarDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = SolarAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = SolarMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class SolarModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: SolarDecoderLayer(config=config,
                                             cache_config=cache_config,
                                             quant_config=quant_config,
                                             prefix=prefix),
            prefix=f"{prefix}.layers")
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
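        # The bskcn_* entries of the Solar config drive the model's
        # layer-skip connections: layers listed in bskcn_1/bskcn_2 snapshot
        # the current hidden states and residual, and layers listed in
        # bskcn_3/bskcn_4 blend those snapshots back in with weight bskcn_tv
        # (the config holds a training value and an inference value;
        # "bskcn" presumably abbreviates "backbone skip connection").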
        bskcn_h_1 = None
        bskcn_h_2 = None
        bskcn_r_1 = None
        bskcn_r_2 = None
        bskcn_tv = (self.config.bskcn_tv[0]
                    if self.training else self.config.bskcn_tv[1])

        for i in range(self.start_layer, self.end_layer):
            if i in self.config.bskcn_1:
                bskcn_h_1 = hidden_states.clone()
                bskcn_r_1 = residual.clone()
            if i in self.config.bskcn_2:
                bskcn_h_2 = hidden_states.clone()
                bskcn_r_2 = residual.clone()
            if i in self.config.bskcn_3:
                hidden_states = bskcn_h_1*bskcn_tv + hidden_states*(1-bskcn_tv)
                residual = bskcn_r_1*bskcn_tv + residual*(1-bskcn_tv)
            if i in self.config.bskcn_4:
                hidden_states = bskcn_h_2*bskcn_tv + hidden_states*(1-bskcn_tv)
                residual = bskcn_r_2*bskcn_tv + residual*(1-bskcn_tv)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class SolarForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = SolarModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config,
                                prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
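
    # Presumably used by the pipeline-parallel runner to preallocate the
    # buffers that receive hidden_states/residual from the previous stage;
    # the shapes mirror what SolarModel.forward returns on non-last ranks.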
    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
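        # Checkpoint weights named *.q_proj/.k_proj/.v_proj (or
        # .gate_proj/.up_proj) are remapped onto the corresponding shard of
        # the fused qkv_proj / gate_up_proj parameters created above.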
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            if not isinstance(self.model.layers[layer_idx], nn.Identity):
                layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")