llama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  32. get_tensor_model_parallel_world_size)
  33. from aphrodite.modeling.layers.activation import SiluAndMul
  34. from aphrodite.modeling.layers.layernorm import RMSNorm
  35. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  36. QKVParallelLinear,
  37. RowParallelLinear)
  38. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  43. from aphrodite.modeling.model_loader.weight_utils import (
  44. default_weight_loader, kv_cache_scales_loader)
  45. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  46. from aphrodite.quantization.base_config import QuantizationConfig
  47. from aphrodite.common.utils import is_hip, print_warning_once
  48. class LlamaMLP(nn.Module):
  49. def __init__(
  50. self,
  51. hidden_size: int,
  52. intermediate_size: int,
  53. hidden_act: str,
  54. quant_config: Optional[QuantizationConfig] = None,
  55. bias: bool = False,
  56. ) -> None:
  57. super().__init__()
  58. self.gate_up_proj = MergedColumnParallelLinear(
  59. input_size=hidden_size,
  60. output_sizes=[intermediate_size] * 2,
  61. bias=bias,
  62. quant_config=quant_config)
  63. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  64. output_size=hidden_size,
  65. bias=bias,
  66. quant_config=quant_config)
  67. if hidden_act != "silu":
  68. raise ValueError(f"Unsupported activation: {hidden_act}. "
  69. "Only silu is supported for now.")
  70. self.act_fn = SiluAndMul()
  71. def forward(self, x):
  72. gate_up, _ = self.gate_up_proj(x)
  73. x = self.act_fn(gate_up)
  74. x, _ = self.down_proj(x)
  75. return x
  76. class LlamaAttention(nn.Module):
  77. def __init__(
  78. self,
  79. config: LlamaConfig,
  80. hidden_size: int,
  81. num_heads: int,
  82. num_kv_heads: int,
  83. rope_theta: float = 10000,
  84. rope_scaling: Optional[Dict[str, Any]] = None,
  85. max_position_embeddings: int = 8192,
  86. quant_config: Optional[QuantizationConfig] = None,
  87. bias: bool = False,
  88. cache_config: Optional[CacheConfig] = None,
  89. ) -> None:
  90. super().__init__()
  91. self.hidden_size = hidden_size
  92. tp_size = get_tensor_model_parallel_world_size()
  93. self.total_num_heads = num_heads
  94. assert self.total_num_heads % tp_size == 0
  95. self.num_heads = self.total_num_heads // tp_size
  96. self.total_num_kv_heads = num_kv_heads
  97. if self.total_num_kv_heads >= tp_size:
  98. # Number of KV heads is greater than TP size, so we partition
  99. # the KV heads across multiple tensor parallel GPUs.
  100. assert self.total_num_kv_heads % tp_size == 0
  101. else:
  102. # Number of KV heads is less than TP size, so we replicate
  103. # the KV heads across multiple tensor parallel GPUs.
  104. assert tp_size % self.total_num_kv_heads == 0
  105. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  106. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  107. self.head_dim = getattr(config, "head_dim",
  108. self.hidden_size // self.total_num_heads)
  109. self.q_size = self.num_heads * self.head_dim
  110. self.kv_size = self.num_kv_heads * self.head_dim
  111. self.scaling = self.head_dim**-0.5
  112. self.rope_theta = rope_theta
  113. self.max_position_embeddings = max_position_embeddings
  114. self.qkv_proj = QKVParallelLinear(
  115. hidden_size=hidden_size,
  116. head_size=self.head_dim,
  117. total_num_heads=self.total_num_heads,
  118. total_num_kv_heads=self.total_num_kv_heads,
  119. bias=bias,
  120. quant_config=quant_config,
  121. )
  122. self.o_proj = RowParallelLinear(
  123. input_size=self.total_num_heads * self.head_dim,
  124. output_size=hidden_size,
  125. bias=bias,
  126. quant_config=quant_config,
  127. )
  128. self.rotary_emb = get_rope(
  129. self.head_dim,
  130. rotary_dim=self.head_dim,
  131. max_position=max_position_embeddings,
  132. base=rope_theta,
  133. rope_scaling=rope_scaling,
  134. )
  135. self.attn = Attention(self.num_heads,
  136. self.head_dim,
  137. self.scaling,
  138. num_kv_heads=self.num_kv_heads,
  139. cache_config=cache_config,
  140. quant_config=quant_config)
  141. def forward(
  142. self,
  143. positions: torch.Tensor,
  144. hidden_states: torch.Tensor,
  145. kv_cache: torch.Tensor,
  146. attn_metadata: AttentionMetadata,
  147. ) -> torch.Tensor:
  148. qkv, _ = self.qkv_proj(hidden_states)
  149. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  150. q, k = self.rotary_emb(positions, q, k)
  151. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  152. output, _ = self.o_proj(attn_output)
  153. return output
  154. class LlamaDecoderLayer(nn.Module):
  155. def __init__(
  156. self,
  157. config: LlamaConfig,
  158. cache_config: Optional[CacheConfig] = None,
  159. quant_config: Optional[QuantizationConfig] = None,
  160. ) -> None:
  161. super().__init__()
  162. self.hidden_size = config.hidden_size
  163. rope_theta = getattr(config, "rope_theta", 10000)
  164. rope_scaling = getattr(config, "rope_scaling", None)
  165. if rope_scaling is not None and getattr(
  166. config, "original_max_position_embeddings", None):
  167. rope_scaling["original_max_position_embeddings"] = (
  168. config.original_max_position_embeddings)
  169. max_position_embeddings = getattr(config, "max_position_embeddings",
  170. 8192)
  171. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  172. # Support internlm/internlm-7b with bias
  173. attention_bias = getattr(config, "attention_bias", False) or getattr(
  174. config, "bias", False)
  175. self.self_attn = LlamaAttention(
  176. config=config,
  177. hidden_size=self.hidden_size,
  178. num_heads=config.num_attention_heads,
  179. num_kv_heads=getattr(config, "num_key_value_heads",
  180. config.num_attention_heads),
  181. rope_theta=rope_theta,
  182. rope_scaling=rope_scaling,
  183. max_position_embeddings=max_position_embeddings,
  184. quant_config=quant_config,
  185. bias=attention_bias,
  186. cache_config=cache_config,
  187. )
  188. self.mlp = LlamaMLP(
  189. hidden_size=self.hidden_size,
  190. intermediate_size=config.intermediate_size,
  191. hidden_act=config.hidden_act,
  192. quant_config=quant_config,
  193. bias=getattr(config, "mlp_bias", False),
  194. )
  195. self.input_layernorm = RMSNorm(config.hidden_size,
  196. eps=config.rms_norm_eps)
  197. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  198. eps=config.rms_norm_eps)
  199. def forward(
  200. self,
  201. positions: torch.Tensor,
  202. hidden_states: torch.Tensor,
  203. kv_cache: torch.Tensor,
  204. attn_metadata: AttentionMetadata,
  205. residual: Optional[torch.Tensor],
  206. ) -> Tuple[torch.Tensor, torch.Tensor]:
  207. # Self Attention
  208. if residual is None:
  209. residual = hidden_states
  210. hidden_states = self.input_layernorm(hidden_states)
  211. else:
  212. hidden_states, residual = self.input_layernorm(
  213. hidden_states, residual)
  214. hidden_states = self.self_attn(
  215. positions=positions,
  216. hidden_states=hidden_states,
  217. kv_cache=kv_cache,
  218. attn_metadata=attn_metadata,
  219. )
  220. # Fully Connected
  221. hidden_states, residual = self.post_attention_layernorm(
  222. hidden_states, residual)
  223. hidden_states = self.mlp(hidden_states)
  224. return hidden_states, residual
  225. class LlamaModel(nn.Module):
  226. def __init__(
  227. self,
  228. config: LlamaConfig,
  229. cache_config: Optional[CacheConfig] = None,
  230. quant_config: Optional[QuantizationConfig] = None,
  231. lora_config: Optional[LoRAConfig] = None,
  232. ) -> None:
  233. super().__init__()
  234. self.config = config
  235. self.padding_idx = config.pad_token_id
  236. lora_vocab = (lora_config.lora_extra_vocab_size *
  237. (lora_config.max_loras or 1)) if lora_config else 0
  238. self.vocab_size = config.vocab_size + lora_vocab
  239. self.org_vocab_size = config.vocab_size
  240. self.embed_tokens = VocabParallelEmbedding(
  241. self.vocab_size,
  242. config.hidden_size,
  243. org_num_embeddings=config.vocab_size,
  244. )
  245. self.layers = nn.ModuleList([
  246. LlamaDecoderLayer(config=config,
  247. cache_config=cache_config,
  248. quant_config=quant_config)
  249. for idx in range(config.num_hidden_layers)
  250. ])
  251. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  252. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  253. return self.embed_tokens(input_ids)
  254. def forward(
  255. self,
  256. input_ids: Optional[torch.Tensor],
  257. positions: torch.Tensor,
  258. kv_caches: List[torch.Tensor],
  259. attn_metadata: AttentionMetadata,
  260. inputs_embeds: Optional[torch.Tensor] = None,
  261. ) -> torch.Tensor:
  262. if inputs_embeds is not None:
  263. hidden_states = inputs_embeds
  264. else:
  265. hidden_states = self.get_input_embeddings(input_ids)
  266. residual = None
  267. for i in range(len(self.layers)):
  268. layer = self.layers[i]
  269. hidden_states, residual = layer(
  270. positions,
  271. hidden_states,
  272. kv_caches[i],
  273. attn_metadata,
  274. residual,
  275. )
  276. hidden_states, _ = self.norm(hidden_states, residual)
  277. return hidden_states
  278. class LlamaForCausalLM(nn.Module):
  279. packed_modules_mapping = {
  280. "qkv_proj": [
  281. "q_proj",
  282. "k_proj",
  283. "v_proj",
  284. ],
  285. "gate_up_proj": [
  286. "gate_proj",
  287. "up_proj",
  288. ],
  289. }
  290. # LoRA specific attributes
  291. supported_lora_modules = [
  292. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  293. "lm_head"
  294. ]
  295. embedding_modules = {
  296. "embed_tokens": "input_embeddings",
  297. "lm_head": "output_embeddings",
  298. }
  299. embedding_padding_modules = ["lm_head"]
  300. def __init__(
  301. self,
  302. config: LlamaConfig,
  303. cache_config: Optional[CacheConfig] = None,
  304. quant_config: Optional[QuantizationConfig] = None,
  305. lora_config: Optional[LoRAConfig] = None,
  306. ) -> None:
  307. super().__init__()
  308. self.config = config
  309. self.model = LlamaModel(config,
  310. cache_config,
  311. quant_config,
  312. lora_config=lora_config)
  313. self.unpadded_vocab_size = config.vocab_size
  314. if lora_config:
  315. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  316. self.lm_head = ParallelLMHead(
  317. self.unpadded_vocab_size,
  318. config.hidden_size,
  319. org_num_embeddings=config.vocab_size,
  320. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  321. # We need bigger padding if using lora for kernel
  322. # compatibility
  323. if not lora_config else lora_config.lora_vocab_padding_size,
  324. )
  325. if config.tie_word_embeddings:
  326. self.lm_head.weight = self.model.embed_tokens.weight
  327. logit_scale = getattr(config, "logit_scale", 1.0)
  328. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  329. config.vocab_size, logit_scale)
  330. self.sampler = Sampler()
  331. def forward(
  332. self,
  333. input_ids: torch.Tensor,
  334. positions: torch.Tensor,
  335. kv_caches: List[torch.Tensor],
  336. attn_metadata: AttentionMetadata,
  337. ) -> torch.Tensor:
  338. hidden_states = self.model(input_ids, positions, kv_caches,
  339. attn_metadata)
  340. return hidden_states
  341. def compute_logits(self, hidden_states: torch.Tensor,
  342. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  343. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  344. sampling_metadata)
  345. return logits
  346. def sample(
  347. self,
  348. logits: torch.Tensor,
  349. sampling_metadata: SamplingMetadata,
  350. ) -> Optional[SamplerOutput]:
  351. next_tokens = self.sampler(logits, sampling_metadata)
  352. return next_tokens
  353. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  354. stacked_params_mapping = [
  355. # (param_name, shard_name, shard_id)
  356. (".qkv_proj", ".q_proj", "q"),
  357. (".qkv_proj", ".k_proj", "k"),
  358. (".qkv_proj", ".v_proj", "v"),
  359. (".gate_up_proj", ".gate_proj", 0),
  360. (".gate_up_proj", ".up_proj", 1),
  361. ]
  362. params_dict = dict(self.named_parameters())
  363. for name, loaded_weight in weights:
  364. if "rotary_emb.inv_freq" in name:
  365. continue
  366. if ("rotary_emb.cos_cached" in name
  367. or "rotary_emb.sin_cached" in name):
  368. # Models trained using ColossalAI may include these tensors in
  369. # the checkpoint. Skip them.
  370. continue
  371. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  372. if weight_name not in name:
  373. continue
  374. name = name.replace(weight_name, param_name)
  375. # Skip loading extra bias for GPTQ models.
  376. if name.endswith(".bias") and name not in params_dict:
  377. continue
  378. param = params_dict[name]
  379. weight_loader = param.weight_loader
  380. weight_loader(param, loaded_weight, shard_id)
  381. break
  382. else:
  383. # Skip loading extra bias for GPTQ models.
  384. if name.endswith(".bias") and name not in params_dict:
  385. continue
  386. # Remapping the name of FP8 kv-scale.
  387. if name.endswith("kv_scale"):
  388. remapped_kv_scale_name = name.replace(
  389. ".kv_scale", ".attn.kv_scale")
  390. if remapped_kv_scale_name not in params_dict:
  391. print_warning_once(
  392. f"Found kv scale in the checkpoint (e.g. {name}), "
  393. "but not found the expected name in the model "
  394. f"(e.g. {remapped_kv_scale_name}). kv-scale is "
  395. "not loaded.")
  396. continue
  397. else:
  398. name = remapped_kv_scale_name
  399. param = params_dict[name]
  400. weight_loader = getattr(param, "weight_loader",
  401. default_weight_loader)
  402. weight_loader(param, loaded_weight)
  403. # If this function is called, it should always initialize KV cache scale
  404. # factors (or else raise an exception). Thus, handled exceptions should
  405. # make sure to leave KV cache scale factors in a known good (dummy) state
  406. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  407. tp_size = get_tensor_model_parallel_world_size()
  408. tp_rank = get_tensor_model_parallel_rank()
  409. for layer_idx, scaling_factor in kv_cache_scales_loader(
  410. quantization_param_path, tp_rank, tp_size,
  411. self.config.num_hidden_layers,
  412. self.config.__class__.model_type):
  413. layer_self_attn = self.model.layers[layer_idx].self_attn
  414. if is_hip():
  415. # The scaling factor convention we are assuming is
  416. # quantized_value * scaling_factor ~= true_value
  417. # which is consistent with the practice of setting
  418. # scaling_factor = tensor_amax / FPtype_max
  419. scaling_factor *= 2
  420. if hasattr(layer_self_attn, "kv_scale"):
  421. layer_self_attn.attn._kv_scale = scaling_factor
  422. else:
  423. raise RuntimeError("Self attention has no KV cache scaling "
  424. "factor attribute!")