llama.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import is_hip, print_warning_once
  32. from aphrodite.distributed import (get_pp_group, get_pp_indices,
  33. get_tensor_model_parallel_rank,
  34. get_tensor_model_parallel_world_size)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  38. QKVParallelLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  45. from aphrodite.modeling.model_loader.weight_utils import (
  46. default_weight_loader, kv_cache_scales_loader)
  47. from aphrodite.modeling.models.interfaces import SupportsLoRA
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. class LlamaMLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. bias: bool = False,
  58. ) -> None:
  59. super().__init__()
  60. self.gate_up_proj = MergedColumnParallelLinear(
  61. input_size=hidden_size,
  62. output_sizes=[intermediate_size] * 2,
  63. bias=bias,
  64. quant_config=quant_config)
  65. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  66. output_size=hidden_size,
  67. bias=bias,
  68. quant_config=quant_config)
  69. if hidden_act != "silu":
  70. raise ValueError(f"Unsupported activation: {hidden_act}. "
  71. "Only silu is supported for now.")
  72. self.act_fn = SiluAndMul()
  73. def forward(self, x):
  74. gate_up, _ = self.gate_up_proj(x)
  75. x = self.act_fn(gate_up)
  76. x, _ = self.down_proj(x)
  77. return x
  78. class LlamaAttention(nn.Module):
  79. def __init__(
  80. self,
  81. config: LlamaConfig,
  82. hidden_size: int,
  83. num_heads: int,
  84. num_kv_heads: int,
  85. rope_theta: float = 10000,
  86. rope_scaling: Optional[Dict[str, Any]] = None,
  87. max_position_embeddings: int = 8192,
  88. quant_config: Optional[QuantizationConfig] = None,
  89. bias: bool = False,
  90. cache_config: Optional[CacheConfig] = None,
  91. ) -> None:
  92. super().__init__()
  93. self.hidden_size = hidden_size
  94. tp_size = get_tensor_model_parallel_world_size()
  95. self.total_num_heads = num_heads
  96. assert self.total_num_heads % tp_size == 0
  97. self.num_heads = self.total_num_heads // tp_size
  98. self.total_num_kv_heads = num_kv_heads
  99. if self.total_num_kv_heads >= tp_size:
  100. # Number of KV heads is greater than TP size, so we partition
  101. # the KV heads across multiple tensor parallel GPUs.
  102. assert self.total_num_kv_heads % tp_size == 0
  103. else:
  104. # Number of KV heads is less than TP size, so we replicate
  105. # the KV heads across multiple tensor parallel GPUs.
  106. assert tp_size % self.total_num_kv_heads == 0
  107. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  108. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  109. self.head_dim = getattr(config, "head_dim",
  110. self.hidden_size // self.total_num_heads)
  111. self.q_size = self.num_heads * self.head_dim
  112. self.kv_size = self.num_kv_heads * self.head_dim
  113. self.scaling = self.head_dim**-0.5
  114. self.rope_theta = rope_theta
  115. self.max_position_embeddings = max_position_embeddings
  116. self.qkv_proj = QKVParallelLinear(
  117. hidden_size=hidden_size,
  118. head_size=self.head_dim,
  119. total_num_heads=self.total_num_heads,
  120. total_num_kv_heads=self.total_num_kv_heads,
  121. bias=bias,
  122. quant_config=quant_config,
  123. )
  124. self.o_proj = RowParallelLinear(
  125. input_size=self.total_num_heads * self.head_dim,
  126. output_size=hidden_size,
  127. bias=bias,
  128. quant_config=quant_config,
  129. )
  130. self.rotary_emb = get_rope(
  131. self.head_dim,
  132. rotary_dim=self.head_dim,
  133. max_position=max_position_embeddings,
  134. base=rope_theta,
  135. rope_scaling=rope_scaling,
  136. )
  137. self.attn = Attention(self.num_heads,
  138. self.head_dim,
  139. self.scaling,
  140. num_kv_heads=self.num_kv_heads,
  141. cache_config=cache_config,
  142. quant_config=quant_config)
  143. def forward(
  144. self,
  145. positions: torch.Tensor,
  146. hidden_states: torch.Tensor,
  147. kv_cache: torch.Tensor,
  148. attn_metadata: AttentionMetadata,
  149. ) -> torch.Tensor:
  150. qkv, _ = self.qkv_proj(hidden_states)
  151. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  152. q, k = self.rotary_emb(positions, q, k)
  153. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  154. output, _ = self.o_proj(attn_output)
  155. return output
  156. class LlamaDecoderLayer(nn.Module):
  157. def __init__(
  158. self,
  159. config: LlamaConfig,
  160. cache_config: Optional[CacheConfig] = None,
  161. quant_config: Optional[QuantizationConfig] = None,
  162. ) -> None:
  163. super().__init__()
  164. self.hidden_size = config.hidden_size
  165. rope_theta = getattr(config, "rope_theta", 10000)
  166. rope_scaling = getattr(config, "rope_scaling", None)
  167. if rope_scaling is not None and getattr(
  168. config, "original_max_position_embeddings", None):
  169. rope_scaling["original_max_position_embeddings"] = (
  170. config.original_max_position_embeddings)
  171. max_position_embeddings = getattr(config, "max_position_embeddings",
  172. 8192)
  173. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  174. # Support internlm/internlm-7b with bias
  175. attention_bias = getattr(config, "attention_bias", False) or getattr(
  176. config, "bias", False)
  177. self.self_attn = LlamaAttention(
  178. config=config,
  179. hidden_size=self.hidden_size,
  180. num_heads=config.num_attention_heads,
  181. num_kv_heads=getattr(config, "num_key_value_heads",
  182. config.num_attention_heads),
  183. rope_theta=rope_theta,
  184. rope_scaling=rope_scaling,
  185. max_position_embeddings=max_position_embeddings,
  186. quant_config=quant_config,
  187. bias=attention_bias,
  188. cache_config=cache_config,
  189. )
  190. self.mlp = LlamaMLP(
  191. hidden_size=self.hidden_size,
  192. intermediate_size=config.intermediate_size,
  193. hidden_act=config.hidden_act,
  194. quant_config=quant_config,
  195. bias=getattr(config, "mlp_bias", False),
  196. )
  197. self.input_layernorm = RMSNorm(config.hidden_size,
  198. eps=config.rms_norm_eps)
  199. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  200. eps=config.rms_norm_eps)
  201. def forward(
  202. self,
  203. positions: torch.Tensor,
  204. hidden_states: torch.Tensor,
  205. kv_cache: torch.Tensor,
  206. attn_metadata: AttentionMetadata,
  207. residual: Optional[torch.Tensor],
  208. ) -> Tuple[torch.Tensor, torch.Tensor]:
  209. # Self Attention
  210. if residual is None:
  211. residual = hidden_states
  212. hidden_states = self.input_layernorm(hidden_states)
  213. else:
  214. hidden_states, residual = self.input_layernorm(
  215. hidden_states, residual)
  216. hidden_states = self.self_attn(
  217. positions=positions,
  218. hidden_states=hidden_states,
  219. kv_cache=kv_cache,
  220. attn_metadata=attn_metadata,
  221. )
  222. # Fully Connected
  223. hidden_states, residual = self.post_attention_layernorm(
  224. hidden_states, residual)
  225. hidden_states = self.mlp(hidden_states)
  226. return hidden_states, residual
  227. class LlamaModel(nn.Module):
  228. def __init__(
  229. self,
  230. config: LlamaConfig,
  231. cache_config: Optional[CacheConfig] = None,
  232. quant_config: Optional[QuantizationConfig] = None,
  233. lora_config: Optional[LoRAConfig] = None,
  234. ) -> None:
  235. super().__init__()
  236. self.config = config
  237. self.padding_idx = config.pad_token_id
  238. lora_vocab = (lora_config.lora_extra_vocab_size *
  239. (lora_config.max_loras or 1)) if lora_config else 0
  240. self.vocab_size = config.vocab_size + lora_vocab
  241. self.org_vocab_size = config.vocab_size
  242. self.embed_tokens = VocabParallelEmbedding(
  243. self.vocab_size,
  244. config.hidden_size,
  245. org_num_embeddings=config.vocab_size,
  246. )
  247. self.start_layer, self.end_layer = get_pp_indices(
  248. config.num_hidden_layers,
  249. get_pp_group().rank_in_group,
  250. get_pp_group().world_size)
  251. self.layers = nn.ModuleList(
  252. [nn.Identity() for _ in range(self.start_layer)] + [
  253. LlamaDecoderLayer(config=config,
  254. cache_config=cache_config,
  255. quant_config=quant_config)
  256. for _ in range(self.start_layer, self.end_layer)
  257. ] + [
  258. nn.Identity()
  259. for _ in range(self.end_layer, config.num_hidden_layers)
  260. ])
  261. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  262. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  263. return self.embed_tokens(input_ids)
  264. def forward(
  265. self,
  266. input_ids: Optional[torch.Tensor],
  267. positions: torch.Tensor,
  268. kv_caches: List[torch.Tensor],
  269. attn_metadata: AttentionMetadata,
  270. intermediate_tensors: Optional[IntermediateTensors],
  271. inputs_embeds: Optional[torch.Tensor] = None,
  272. ) -> Union[torch.Tensor, IntermediateTensors]:
  273. if get_pp_group().is_first_rank:
  274. if inputs_embeds is not None:
  275. hidden_states = inputs_embeds
  276. else:
  277. hidden_states = self.get_input_embeddings(input_ids)
  278. residual = None
  279. else:
  280. assert intermediate_tensors is not None
  281. hidden_states = intermediate_tensors["hidden_states"]
  282. residual = intermediate_tensors["residual"]
  283. for i in range(self.start_layer, self.end_layer):
  284. layer = self.layers[i]
  285. hidden_states, residual = layer(
  286. positions,
  287. hidden_states,
  288. kv_caches[i - self.start_layer],
  289. attn_metadata,
  290. residual,
  291. )
  292. if not get_pp_group().is_last_rank:
  293. return IntermediateTensors({
  294. "hidden_states": hidden_states,
  295. "residual": residual
  296. })
  297. hidden_states, _ = self.norm(hidden_states, residual)
  298. return hidden_states
  299. class LlamaForCausalLM(nn.Module, SupportsLoRA):
  300. packed_modules_mapping = {
  301. "qkv_proj": [
  302. "q_proj",
  303. "k_proj",
  304. "v_proj",
  305. ],
  306. "gate_up_proj": [
  307. "gate_proj",
  308. "up_proj",
  309. ],
  310. }
  311. # LoRA specific attributes
  312. supported_lora_modules = [
  313. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  314. "lm_head"
  315. ]
  316. embedding_modules = {
  317. "embed_tokens": "input_embeddings",
  318. "lm_head": "output_embeddings",
  319. }
  320. embedding_padding_modules = ["lm_head"]
  321. bitsandbytes_stacked_params_mapping = {
  322. # shard_name, weight_name, index
  323. "q_proj": ("qkv_proj", 0),
  324. "k_proj": ("qkv_proj", 1),
  325. "v_proj": ("qkv_proj", 2),
  326. "gate_proj": ("gate_up_proj", 0),
  327. "up_proj": ("gate_up_proj", 1),
  328. }
  329. def __init__(
  330. self,
  331. config: LlamaConfig,
  332. cache_config: Optional[CacheConfig] = None,
  333. quant_config: Optional[QuantizationConfig] = None,
  334. lora_config: Optional[LoRAConfig] = None,
  335. ) -> None:
  336. super().__init__()
  337. self.config = config
  338. self.lora_config = lora_config
  339. self.model = LlamaModel(config,
  340. cache_config,
  341. quant_config,
  342. lora_config=lora_config)
  343. self.unpadded_vocab_size = config.vocab_size
  344. if lora_config:
  345. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  346. self.lm_head = ParallelLMHead(
  347. self.unpadded_vocab_size,
  348. config.hidden_size,
  349. org_num_embeddings=config.vocab_size,
  350. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  351. # We need bigger padding if using lora for kernel
  352. # compatibility
  353. if not lora_config else lora_config.lora_vocab_padding_size,
  354. quant_config=quant_config,
  355. )
  356. if config.tie_word_embeddings:
  357. self.lm_head.weight = self.model.embed_tokens.weight
  358. logit_scale = getattr(config, "logit_scale", 1.0)
  359. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  360. config.vocab_size, logit_scale)
  361. self.sampler = Sampler()
  362. def forward(
  363. self,
  364. input_ids: torch.Tensor,
  365. positions: torch.Tensor,
  366. kv_caches: List[torch.Tensor],
  367. attn_metadata: AttentionMetadata,
  368. intermediate_tensors: Optional[IntermediateTensors] = None,
  369. ) -> Union[torch.Tensor, IntermediateTensors]:
  370. model_output = self.model(input_ids, positions, kv_caches,
  371. attn_metadata, intermediate_tensors)
  372. return model_output
  373. def compute_logits(self, hidden_states: torch.Tensor,
  374. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  375. logits = self.logits_processor(self.lm_head, hidden_states,
  376. sampling_metadata)
  377. return logits
  378. def sample(
  379. self,
  380. logits: torch.Tensor,
  381. sampling_metadata: SamplingMetadata,
  382. ) -> Optional[SamplerOutput]:
  383. next_tokens = self.sampler(logits, sampling_metadata)
  384. return next_tokens
  385. def make_empty_intermediate_tensors(
  386. self, batch_size: int, dtype: torch.dtype,
  387. device: torch.device) -> IntermediateTensors:
  388. return IntermediateTensors({
  389. "hidden_states":
  390. torch.zeros((batch_size, self.config.hidden_size),
  391. dtype=dtype,
  392. device=device),
  393. "residual":
  394. torch.zeros((batch_size, self.config.hidden_size),
  395. dtype=dtype,
  396. device=device),
  397. })
  398. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  399. stacked_params_mapping = [
  400. # (param_name, shard_name, shard_id)
  401. (".qkv_proj", ".q_proj", "q"),
  402. (".qkv_proj", ".k_proj", "k"),
  403. (".qkv_proj", ".v_proj", "v"),
  404. (".gate_up_proj", ".gate_proj", 0),
  405. (".gate_up_proj", ".up_proj", 1),
  406. ]
  407. params_dict = dict(self.named_parameters())
  408. for name, loaded_weight in weights:
  409. if "rotary_emb.inv_freq" in name:
  410. continue
  411. if ("rotary_emb.cos_cached" in name
  412. or "rotary_emb.sin_cached" in name):
  413. # Models trained using ColossalAI may include these tensors in
  414. # the checkpoint. Skip them.
  415. continue
  416. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  417. if weight_name not in name:
  418. continue
  419. name = name.replace(weight_name, param_name)
  420. # Skip loading extra bias for GPTQ models.
  421. if name.endswith(".bias") and name not in params_dict:
  422. continue
  423. try:
  424. param = params_dict[name]
  425. weight_loader = param.weight_loader
  426. weight_loader(param, loaded_weight, shard_id)
  427. except KeyError:
  428. pass
  429. break
  430. else:
  431. # Skip loading extra bias for GPTQ models.
  432. if name.endswith(".bias") and name not in params_dict:
  433. continue
  434. # Remapping the name of FP8 kv-scale.
  435. if name.endswith("kv_scale"):
  436. remapped_kv_scale_name = name.replace(
  437. ".kv_scale", ".attn.kv_scale")
  438. if remapped_kv_scale_name not in params_dict:
  439. print_warning_once(
  440. f"Found kv scale in the checkpoint (e.g. {name}), "
  441. "but not found the expected name in the model "
  442. f"(e.g. {remapped_kv_scale_name}). kv-scale is "
  443. "not loaded.")
  444. continue
  445. else:
  446. name = remapped_kv_scale_name
  447. try:
  448. param = params_dict[name]
  449. weight_loader = getattr(param, "weight_loader",
  450. default_weight_loader)
  451. weight_loader(param, loaded_weight)
  452. except KeyError:
  453. pass
  454. # If this function is called, it should always initialize KV cache scale
  455. # factors (or else raise an exception). Thus, handled exceptions should
  456. # make sure to leave KV cache scale factors in a known good (dummy) state
  457. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  458. tp_size = get_tensor_model_parallel_world_size()
  459. tp_rank = get_tensor_model_parallel_rank()
  460. for layer_idx, scaling_factor in kv_cache_scales_loader(
  461. quantization_param_path, tp_rank, tp_size,
  462. self.config.num_hidden_layers,
  463. self.config.__class__.model_type):
  464. if not isinstance(self.model.layers[layer_idx], nn.Identity):
  465. layer_self_attn = self.model.layers[layer_idx].self_attn
  466. if is_hip():
  467. # The scaling factor convention we are assuming is
  468. # quantized_value * scaling_factor ~= true_value
  469. # which is consistent with the practice of setting
  470. # scaling_factor = tensor_amax / FPtype_max
  471. scaling_factor *= 2
  472. if hasattr(layer_self_attn, "kv_scale"):
  473. layer_self_attn.attn._kv_scale = scaling_factor
  474. else:
  475. raise RuntimeError("Self attention has no KV cache scaling "
  476. "factor attribute!")