1
0

llama.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.common.utils import is_hip, print_warning_once
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size)
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.layernorm import RMSNorm
  36. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  37. QKVParallelLinear,
  38. RowParallelLinear)
  39. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.sampler import Sampler
  42. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  43. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  44. from aphrodite.modeling.model_loader.weight_utils import (
  45. default_weight_loader, kv_cache_scales_loader)
  46. from aphrodite.modeling.models.interfaces import SupportsLoRA
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.quantization.base_config import QuantizationConfig
  49. class LlamaMLP(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. intermediate_size: int,
  54. hidden_act: str,
  55. quant_config: Optional[QuantizationConfig] = None,
  56. bias: bool = False,
  57. ) -> None:
  58. super().__init__()
  59. self.gate_up_proj = MergedColumnParallelLinear(
  60. input_size=hidden_size,
  61. output_sizes=[intermediate_size] * 2,
  62. bias=bias,
  63. quant_config=quant_config)
  64. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  65. output_size=hidden_size,
  66. bias=bias,
  67. quant_config=quant_config)
  68. if hidden_act != "silu":
  69. raise ValueError(f"Unsupported activation: {hidden_act}. "
  70. "Only silu is supported for now.")
  71. self.act_fn = SiluAndMul()
  72. def forward(self, x):
  73. gate_up, _ = self.gate_up_proj(x)
  74. x = self.act_fn(gate_up)
  75. x, _ = self.down_proj(x)
  76. return x
  77. class LlamaAttention(nn.Module):
  78. def __init__(
  79. self,
  80. config: LlamaConfig,
  81. hidden_size: int,
  82. num_heads: int,
  83. num_kv_heads: int,
  84. rope_theta: float = 10000,
  85. rope_scaling: Optional[Dict[str, Any]] = None,
  86. max_position_embeddings: int = 8192,
  87. quant_config: Optional[QuantizationConfig] = None,
  88. bias: bool = False,
  89. cache_config: Optional[CacheConfig] = None,
  90. ) -> None:
  91. super().__init__()
  92. self.hidden_size = hidden_size
  93. tp_size = get_tensor_model_parallel_world_size()
  94. self.total_num_heads = num_heads
  95. assert self.total_num_heads % tp_size == 0
  96. self.num_heads = self.total_num_heads // tp_size
  97. self.total_num_kv_heads = num_kv_heads
  98. if self.total_num_kv_heads >= tp_size:
  99. # Number of KV heads is greater than TP size, so we partition
  100. # the KV heads across multiple tensor parallel GPUs.
  101. assert self.total_num_kv_heads % tp_size == 0
  102. else:
  103. # Number of KV heads is less than TP size, so we replicate
  104. # the KV heads across multiple tensor parallel GPUs.
  105. assert tp_size % self.total_num_kv_heads == 0
  106. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  107. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  108. self.head_dim = getattr(config, "head_dim",
  109. self.hidden_size // self.total_num_heads)
  110. self.q_size = self.num_heads * self.head_dim
  111. self.kv_size = self.num_kv_heads * self.head_dim
  112. self.scaling = self.head_dim**-0.5
  113. self.rope_theta = rope_theta
  114. self.max_position_embeddings = max_position_embeddings
  115. self.qkv_proj = QKVParallelLinear(
  116. hidden_size=hidden_size,
  117. head_size=self.head_dim,
  118. total_num_heads=self.total_num_heads,
  119. total_num_kv_heads=self.total_num_kv_heads,
  120. bias=bias,
  121. quant_config=quant_config,
  122. )
  123. self.o_proj = RowParallelLinear(
  124. input_size=self.total_num_heads * self.head_dim,
  125. output_size=hidden_size,
  126. bias=bias,
  127. quant_config=quant_config,
  128. )
  129. self.rotary_emb = get_rope(
  130. self.head_dim,
  131. rotary_dim=self.head_dim,
  132. max_position=max_position_embeddings,
  133. base=rope_theta,
  134. rope_scaling=rope_scaling,
  135. )
  136. self.attn = Attention(self.num_heads,
  137. self.head_dim,
  138. self.scaling,
  139. num_kv_heads=self.num_kv_heads,
  140. cache_config=cache_config,
  141. quant_config=quant_config)
  142. def forward(
  143. self,
  144. positions: torch.Tensor,
  145. hidden_states: torch.Tensor,
  146. kv_cache: torch.Tensor,
  147. attn_metadata: AttentionMetadata,
  148. ) -> torch.Tensor:
  149. qkv, _ = self.qkv_proj(hidden_states)
  150. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  151. q, k = self.rotary_emb(positions, q, k)
  152. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  153. output, _ = self.o_proj(attn_output)
  154. return output
  155. class LlamaDecoderLayer(nn.Module):
  156. def __init__(
  157. self,
  158. config: LlamaConfig,
  159. cache_config: Optional[CacheConfig] = None,
  160. quant_config: Optional[QuantizationConfig] = None,
  161. ) -> None:
  162. super().__init__()
  163. self.hidden_size = config.hidden_size
  164. rope_theta = getattr(config, "rope_theta", 10000)
  165. rope_scaling = getattr(config, "rope_scaling", None)
  166. if rope_scaling is not None and getattr(
  167. config, "original_max_position_embeddings", None):
  168. rope_scaling["original_max_position_embeddings"] = (
  169. config.original_max_position_embeddings)
  170. max_position_embeddings = getattr(config, "max_position_embeddings",
  171. 8192)
  172. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  173. # Support internlm/internlm-7b with bias
  174. attention_bias = getattr(config, "attention_bias", False) or getattr(
  175. config, "bias", False)
  176. self.self_attn = LlamaAttention(
  177. config=config,
  178. hidden_size=self.hidden_size,
  179. num_heads=config.num_attention_heads,
  180. num_kv_heads=getattr(config, "num_key_value_heads",
  181. config.num_attention_heads),
  182. rope_theta=rope_theta,
  183. rope_scaling=rope_scaling,
  184. max_position_embeddings=max_position_embeddings,
  185. quant_config=quant_config,
  186. bias=attention_bias,
  187. cache_config=cache_config,
  188. )
  189. self.mlp = LlamaMLP(
  190. hidden_size=self.hidden_size,
  191. intermediate_size=config.intermediate_size,
  192. hidden_act=config.hidden_act,
  193. quant_config=quant_config,
  194. bias=getattr(config, "mlp_bias", False),
  195. )
  196. self.input_layernorm = RMSNorm(config.hidden_size,
  197. eps=config.rms_norm_eps)
  198. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  199. eps=config.rms_norm_eps)
  200. def forward(
  201. self,
  202. positions: torch.Tensor,
  203. hidden_states: torch.Tensor,
  204. kv_cache: torch.Tensor,
  205. attn_metadata: AttentionMetadata,
  206. residual: Optional[torch.Tensor],
  207. ) -> Tuple[torch.Tensor, torch.Tensor]:
  208. # Self Attention
  209. if residual is None:
  210. residual = hidden_states
  211. hidden_states = self.input_layernorm(hidden_states)
  212. else:
  213. hidden_states, residual = self.input_layernorm(
  214. hidden_states, residual)
  215. hidden_states = self.self_attn(
  216. positions=positions,
  217. hidden_states=hidden_states,
  218. kv_cache=kv_cache,
  219. attn_metadata=attn_metadata,
  220. )
  221. # Fully Connected
  222. hidden_states, residual = self.post_attention_layernorm(
  223. hidden_states, residual)
  224. hidden_states = self.mlp(hidden_states)
  225. return hidden_states, residual
  226. class LlamaModel(nn.Module):
  227. def __init__(
  228. self,
  229. config: LlamaConfig,
  230. cache_config: Optional[CacheConfig] = None,
  231. quant_config: Optional[QuantizationConfig] = None,
  232. lora_config: Optional[LoRAConfig] = None,
  233. ) -> None:
  234. super().__init__()
  235. self.config = config
  236. self.padding_idx = config.pad_token_id
  237. lora_vocab = (lora_config.lora_extra_vocab_size *
  238. (lora_config.max_loras or 1)) if lora_config else 0
  239. self.vocab_size = config.vocab_size + lora_vocab
  240. self.org_vocab_size = config.vocab_size
  241. self.embed_tokens = VocabParallelEmbedding(
  242. self.vocab_size,
  243. config.hidden_size,
  244. org_num_embeddings=config.vocab_size,
  245. )
  246. self.layers = nn.ModuleList([
  247. LlamaDecoderLayer(config=config,
  248. cache_config=cache_config,
  249. quant_config=quant_config)
  250. for idx in range(config.num_hidden_layers)
  251. ])
  252. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  253. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  254. return self.embed_tokens(input_ids)
  255. def forward(
  256. self,
  257. input_ids: Optional[torch.Tensor],
  258. positions: torch.Tensor,
  259. kv_caches: List[torch.Tensor],
  260. attn_metadata: AttentionMetadata,
  261. inputs_embeds: Optional[torch.Tensor] = None,
  262. ) -> torch.Tensor:
  263. if inputs_embeds is not None:
  264. hidden_states = inputs_embeds
  265. else:
  266. hidden_states = self.get_input_embeddings(input_ids)
  267. residual = None
  268. for i in range(len(self.layers)):
  269. layer = self.layers[i]
  270. hidden_states, residual = layer(
  271. positions,
  272. hidden_states,
  273. kv_caches[i],
  274. attn_metadata,
  275. residual,
  276. )
  277. hidden_states, _ = self.norm(hidden_states, residual)
  278. return hidden_states
  279. class LlamaForCausalLM(nn.Module, SupportsLoRA):
  280. packed_modules_mapping = {
  281. "qkv_proj": [
  282. "q_proj",
  283. "k_proj",
  284. "v_proj",
  285. ],
  286. "gate_up_proj": [
  287. "gate_proj",
  288. "up_proj",
  289. ],
  290. }
  291. # LoRA specific attributes
  292. supported_lora_modules = [
  293. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  294. "lm_head"
  295. ]
  296. embedding_modules = {
  297. "embed_tokens": "input_embeddings",
  298. "lm_head": "output_embeddings",
  299. }
  300. embedding_padding_modules = ["lm_head"]
  301. bitsandbytes_stacked_params_mapping = {
  302. # shard_name, weight_name, index
  303. "q_proj": ("qkv_proj", 0),
  304. "k_proj": ("qkv_proj", 1),
  305. "v_proj": ("qkv_proj", 2),
  306. "gate_proj": ("gate_up_proj", 0),
  307. "up_proj": ("gate_up_proj", 1),
  308. }
  309. def __init__(
  310. self,
  311. config: LlamaConfig,
  312. cache_config: Optional[CacheConfig] = None,
  313. quant_config: Optional[QuantizationConfig] = None,
  314. lora_config: Optional[LoRAConfig] = None,
  315. ) -> None:
  316. super().__init__()
  317. self.config = config
  318. self.lora_config = lora_config
  319. self.model = LlamaModel(config,
  320. cache_config,
  321. quant_config,
  322. lora_config=lora_config)
  323. self.unpadded_vocab_size = config.vocab_size
  324. if lora_config:
  325. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  326. self.lm_head = ParallelLMHead(
  327. self.unpadded_vocab_size,
  328. config.hidden_size,
  329. org_num_embeddings=config.vocab_size,
  330. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  331. # We need bigger padding if using lora for kernel
  332. # compatibility
  333. if not lora_config else lora_config.lora_vocab_padding_size,
  334. )
  335. if config.tie_word_embeddings:
  336. self.lm_head.weight = self.model.embed_tokens.weight
  337. logit_scale = getattr(config, "logit_scale", 1.0)
  338. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  339. config.vocab_size, logit_scale)
  340. self.sampler = Sampler()
  341. def forward(
  342. self,
  343. input_ids: torch.Tensor,
  344. positions: torch.Tensor,
  345. kv_caches: List[torch.Tensor],
  346. attn_metadata: AttentionMetadata,
  347. ) -> torch.Tensor:
  348. hidden_states = self.model(input_ids, positions, kv_caches,
  349. attn_metadata)
  350. return hidden_states
  351. def compute_logits(self, hidden_states: torch.Tensor,
  352. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  353. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  354. sampling_metadata)
  355. return logits
  356. def sample(
  357. self,
  358. logits: torch.Tensor,
  359. sampling_metadata: SamplingMetadata,
  360. ) -> Optional[SamplerOutput]:
  361. next_tokens = self.sampler(logits, sampling_metadata)
  362. return next_tokens
  363. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  364. stacked_params_mapping = [
  365. # (param_name, shard_name, shard_id)
  366. (".qkv_proj", ".q_proj", "q"),
  367. (".qkv_proj", ".k_proj", "k"),
  368. (".qkv_proj", ".v_proj", "v"),
  369. (".gate_up_proj", ".gate_proj", 0),
  370. (".gate_up_proj", ".up_proj", 1),
  371. ]
  372. params_dict = dict(self.named_parameters())
  373. for name, loaded_weight in weights:
  374. if "rotary_emb.inv_freq" in name:
  375. continue
  376. if ("rotary_emb.cos_cached" in name
  377. or "rotary_emb.sin_cached" in name):
  378. # Models trained using ColossalAI may include these tensors in
  379. # the checkpoint. Skip them.
  380. continue
  381. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  382. if weight_name not in name:
  383. continue
  384. name = name.replace(weight_name, param_name)
  385. # Skip loading extra bias for GPTQ models.
  386. if name.endswith(".bias") and name not in params_dict:
  387. continue
  388. param = params_dict[name]
  389. weight_loader = param.weight_loader
  390. weight_loader(param, loaded_weight, shard_id)
  391. break
  392. else:
  393. # Skip loading extra bias for GPTQ models.
  394. if name.endswith(".bias") and name not in params_dict:
  395. continue
  396. # Remapping the name of FP8 kv-scale.
  397. if name.endswith("kv_scale"):
  398. remapped_kv_scale_name = name.replace(
  399. ".kv_scale", ".attn.kv_scale")
  400. if remapped_kv_scale_name not in params_dict:
  401. print_warning_once(
  402. f"Found kv scale in the checkpoint (e.g. {name}), "
  403. "but not found the expected name in the model "
  404. f"(e.g. {remapped_kv_scale_name}). kv-scale is "
  405. "not loaded.")
  406. continue
  407. else:
  408. name = remapped_kv_scale_name
  409. param = params_dict[name]
  410. weight_loader = getattr(param, "weight_loader",
  411. default_weight_loader)
  412. weight_loader(param, loaded_weight)
  413. # If this function is called, it should always initialize KV cache scale
  414. # factors (or else raise an exception). Thus, handled exceptions should
  415. # make sure to leave KV cache scale factors in a known good (dummy) state
  416. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  417. tp_size = get_tensor_model_parallel_world_size()
  418. tp_rank = get_tensor_model_parallel_rank()
  419. for layer_idx, scaling_factor in kv_cache_scales_loader(
  420. quantization_param_path, tp_rank, tp_size,
  421. self.config.num_hidden_layers,
  422. self.config.__class__.model_type):
  423. layer_self_attn = self.model.layers[layer_idx].self_attn
  424. if is_hip():
  425. # The scaling factor convention we are assuming is
  426. # quantized_value * scaling_factor ~= true_value
  427. # which is consistent with the practice of setting
  428. # scaling_factor = tensor_amax / FPtype_max
  429. scaling_factor *= 2
  430. if hasattr(layer_self_attn, "kv_scale"):
  431. layer_self_attn.attn._kv_scale = scaling_factor
  432. else:
  433. raise RuntimeError("Self attention has no KV cache scaling "
  434. "factor attribute!")