llama.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import LoRAConfig
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.common.utils import is_hip
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size)
  34. from aphrodite.modeling.layers.activation import SiluAndMul
  35. from aphrodite.modeling.layers.layernorm import RMSNorm
  36. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  37. QKVParallelLinear,
  38. RowParallelLinear)
  39. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.sampler import Sampler
  42. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  43. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  44. from aphrodite.modeling.model_loader.weight_utils import (
  45. default_weight_loader, kv_cache_scales_loader)
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.quantization.base_config import QuantizationConfig
  48. class LlamaMLP(nn.Module):
  49. def __init__(
  50. self,
  51. hidden_size: int,
  52. intermediate_size: int,
  53. hidden_act: str,
  54. quant_config: Optional[QKVParallelLinear] = None,
  55. ) -> None:
  56. super().__init__()
  57. self.gate_up_proj = MergedColumnParallelLinear(
  58. hidden_size, [intermediate_size] * 2,
  59. bias=False,
  60. quant_config=quant_config)
  61. self.down_proj = RowParallelLinear(intermediate_size,
  62. hidden_size,
  63. bias=False,
  64. quant_config=quant_config)
  65. if hidden_act != "silu":
  66. raise ValueError(f"Unsupported activation: {hidden_act}. "
  67. "Only silu is supported for now.")
  68. self.act_fn = SiluAndMul()
  69. def forward(self, x):
  70. gate_up, _ = self.gate_up_proj(x)
  71. x = self.act_fn(gate_up)
  72. x, _ = self.down_proj(x)
  73. return x
  74. class LlamaAttention(nn.Module):
  75. def __init__(
  76. self,
  77. hidden_size: int,
  78. num_heads: int,
  79. num_kv_heads: int,
  80. rope_theta: float = 10000,
  81. rope_scaling: Optional[Dict[str, Any]] = None,
  82. max_position_embeddings: int = 8192,
  83. quant_config: Optional[QuantizationConfig] = None,
  84. bias: bool = False,
  85. sliding_window: Optional[int] = None,
  86. ) -> None:
  87. super().__init__()
  88. self.hidden_size = hidden_size
  89. tp_size = get_tensor_model_parallel_world_size()
  90. self.total_num_heads = num_heads
  91. assert self.total_num_heads % tp_size == 0
  92. self.num_heads = self.total_num_heads // tp_size
  93. self.total_num_kv_heads = num_kv_heads
  94. if self.total_num_kv_heads >= tp_size:
  95. # Number of KV heads is greater than TP size, so we partition
  96. # the KV heads across multiple tensor parallel GPUs.
  97. assert self.total_num_kv_heads % tp_size == 0
  98. else:
  99. # Number of KV heads is less than TP size, so we replicate
  100. # the KV heads across multiple tensor parallel GPUs.
  101. assert tp_size % self.total_num_kv_heads == 0
  102. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  103. self.head_dim = hidden_size // self.total_num_heads
  104. self.q_size = self.num_heads * self.head_dim
  105. self.kv_size = self.num_kv_heads * self.head_dim
  106. self.scaling = self.head_dim**-0.5
  107. self.rope_theta = rope_theta
  108. self.max_position_embeddings = max_position_embeddings
  109. # This will be overwritten by model initialization if we are using it.
  110. # N.B. currently we only support per tensor scalar scaling factors
  111. # & only applicable to ROCm (AMD GPU).
  112. # The scaling factor convention we are assuming is
  113. # quantized_value * scaling_factor ~= true_value
  114. # which is consistent with the practice of setting
  115. # scaling_factor = tensor_amax / FPtype_max
  116. self.kv_scale = 1.0
  117. self.qkv_proj = QKVParallelLinear(
  118. hidden_size,
  119. self.head_dim,
  120. self.total_num_heads,
  121. self.total_num_kv_heads,
  122. bias=bias,
  123. quant_config=quant_config,
  124. )
  125. self.o_proj = RowParallelLinear(
  126. self.total_num_heads * self.head_dim,
  127. hidden_size,
  128. bias=bias,
  129. quant_config=quant_config,
  130. )
  131. self.rotary_emb = get_rope(
  132. self.head_dim,
  133. rotary_dim=self.head_dim,
  134. max_position=max_position_embeddings,
  135. base=rope_theta,
  136. rope_scaling=rope_scaling,
  137. )
  138. self.attn = Attention(self.num_heads,
  139. self.head_dim,
  140. self.scaling,
  141. num_kv_heads=self.num_kv_heads,
  142. sliding_window=sliding_window)
  143. def forward(
  144. self,
  145. positions: torch.Tensor,
  146. hidden_states: torch.Tensor,
  147. kv_cache: torch.Tensor,
  148. attn_metadata: AttentionMetadata,
  149. ) -> torch.Tensor:
  150. qkv, _ = self.qkv_proj(hidden_states)
  151. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  152. q, k = self.rotary_emb(positions, q, k)
  153. attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
  154. self.kv_scale)
  155. output, _ = self.o_proj(attn_output)
  156. return output
  157. class LlamaDecoderLayer(nn.Module):
  158. def __init__(
  159. self,
  160. config: LlamaConfig,
  161. quant_config: Optional[QuantizationConfig] = None,
  162. ) -> None:
  163. super().__init__()
  164. self.hidden_size = config.hidden_size
  165. rope_theta = getattr(config, "rope_theta", 10000)
  166. rope_scaling = getattr(config, "rope_scaling", None)
  167. if rope_scaling is not None and getattr(
  168. config, "original_max_position_embeddings", None):
  169. rope_scaling["original_max_position_embeddings"] = (
  170. config.original_max_position_embeddings)
  171. max_position_embeddings = getattr(config, "max_position_embeddings",
  172. 8192)
  173. sliding_window = getattr(config, "sliding_window", None)
  174. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  175. # Support internlm/internlm-7b with bias
  176. attention_bias = getattr(config, "attention_bias", False) or getattr(
  177. config, "bias", False)
  178. self.self_attn = LlamaAttention(
  179. hidden_size=self.hidden_size,
  180. num_heads=config.num_attention_heads,
  181. num_kv_heads=getattr(config, "num_key_value_heads",
  182. config.num_attention_heads),
  183. rope_theta=rope_theta,
  184. rope_scaling=rope_scaling,
  185. max_position_embeddings=max_position_embeddings,
  186. quant_config=quant_config,
  187. bias=attention_bias,
  188. sliding_window=sliding_window,
  189. )
  190. self.mlp = LlamaMLP(
  191. hidden_size=self.hidden_size,
  192. intermediate_size=config.intermediate_size,
  193. hidden_act=config.hidden_act,
  194. quant_config=quant_config,
  195. )
  196. self.input_layernorm = RMSNorm(config.hidden_size,
  197. eps=config.rms_norm_eps)
  198. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  199. eps=config.rms_norm_eps)
  200. def forward(
  201. self,
  202. positions: torch.Tensor,
  203. hidden_states: torch.Tensor,
  204. kv_cache: torch.Tensor,
  205. attn_metadata: AttentionMetadata,
  206. residual: Optional[torch.Tensor],
  207. ) -> Tuple[torch.Tensor, torch.Tensor]:
  208. # Self Attention
  209. if residual is None:
  210. residual = hidden_states
  211. hidden_states = self.input_layernorm(hidden_states)
  212. else:
  213. hidden_states, residual = self.input_layernorm(
  214. hidden_states, residual)
  215. hidden_states = self.self_attn(
  216. positions=positions,
  217. hidden_states=hidden_states,
  218. kv_cache=kv_cache,
  219. attn_metadata=attn_metadata,
  220. )
  221. # Fully Connected
  222. hidden_states, residual = self.post_attention_layernorm(
  223. hidden_states, residual)
  224. hidden_states = self.mlp(hidden_states)
  225. return hidden_states, residual
  226. class LlamaModel(nn.Module):
  227. def __init__(
  228. self,
  229. config: LlamaConfig,
  230. quant_config: Optional[QuantizationConfig] = None,
  231. lora_config: Optional[LoRAConfig] = None,
  232. ) -> None:
  233. super().__init__()
  234. self.config = config
  235. self.padding_idx = config.pad_token_id
  236. lora_vocab = (lora_config.lora_extra_vocab_size *
  237. (lora_config.max_loras or 1)) if lora_config else 0
  238. self.vocab_size = config.vocab_size + lora_vocab
  239. self.org_vocab_size = config.vocab_size
  240. self.embed_tokens = VocabParallelEmbedding(
  241. self.vocab_size,
  242. config.hidden_size,
  243. org_num_embeddings=config.vocab_size,
  244. )
  245. self.layers = nn.ModuleList([
  246. LlamaDecoderLayer(config, quant_config)
  247. for _ in range(config.num_hidden_layers)
  248. ])
  249. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  250. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  251. return self.embed_tokens(input_ids)
  252. def forward(
  253. self,
  254. input_ids: Optional[torch.Tensor],
  255. positions: torch.Tensor,
  256. kv_caches: List[torch.Tensor],
  257. attn_metadata: AttentionMetadata,
  258. inputs_embeds: Optional[torch.Tensor] = None,
  259. ) -> torch.Tensor:
  260. if inputs_embeds is not None:
  261. hidden_states = inputs_embeds
  262. else:
  263. hidden_states = self.get_input_embeddings(input_ids)
  264. residual = None
  265. for i in range(len(self.layers)):
  266. layer = self.layers[i]
  267. hidden_states, residual = layer(
  268. positions,
  269. hidden_states,
  270. kv_caches[i],
  271. attn_metadata,
  272. residual,
  273. )
  274. hidden_states, _ = self.norm(hidden_states, residual)
  275. return hidden_states
  276. class LlamaForCausalLM(nn.Module):
  277. packed_modules_mapping = {
  278. "qkv_proj": [
  279. "q_proj",
  280. "k_proj",
  281. "v_proj",
  282. ],
  283. "gate_up_proj": [
  284. "gate_proj",
  285. "up_proj",
  286. ],
  287. }
  288. # LoRA specific attributes
  289. supported_lora_modules = [
  290. "qkv_proj",
  291. "o_proj",
  292. "gate_up_proj",
  293. "down_proj",
  294. "embed_tokens",
  295. "lm_head",
  296. ]
  297. embedding_modules = {
  298. "embed_tokens": "input_embeddings",
  299. "lm_head": "output_embeddings",
  300. }
  301. embedding_padding_modules = ["lm_head"]
  302. def __init__(
  303. self,
  304. config: LlamaConfig,
  305. quant_config: Optional[QuantizationConfig] = None,
  306. lora_config: Optional[LoRAConfig] = None,
  307. ) -> None:
  308. super().__init__()
  309. self.config = config
  310. self.model = LlamaModel(config, quant_config, lora_config=lora_config)
  311. self.unpadded_vocab_size = config.vocab_size
  312. if lora_config:
  313. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  314. self.lm_head = ParallelLMHead(
  315. self.unpadded_vocab_size,
  316. config.hidden_size,
  317. org_num_embeddings=config.vocab_size,
  318. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  319. # We need bigger padding if using lora for kernel
  320. # compatibility
  321. if not lora_config else lora_config.lora_vocab_padding_size,
  322. )
  323. logit_scale = getattr(config, "logit_scale", 1.0)
  324. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  325. config.vocab_size, logit_scale)
  326. self.sampler = Sampler()
  327. def forward(
  328. self,
  329. input_ids: torch.Tensor,
  330. positions: torch.Tensor,
  331. kv_caches: List[torch.Tensor],
  332. attn_metadata: AttentionMetadata,
  333. ) -> torch.Tensor:
  334. hidden_states = self.model(input_ids, positions, kv_caches,
  335. attn_metadata)
  336. return hidden_states
  337. def compute_logits(self, hidden_states: torch.Tensor,
  338. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  339. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  340. sampling_metadata)
  341. return logits
  342. def sample(
  343. self,
  344. logits: torch.Tensor,
  345. sampling_metadata: SamplingMetadata,
  346. ) -> Optional[SamplerOutput]:
  347. next_tokens = self.sampler(logits, sampling_metadata)
  348. return next_tokens
  349. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  350. stacked_params_mapping = [
  351. # (param_name, shard_name, shard_id)
  352. (".qkv_proj", ".q_proj", "q"),
  353. (".qkv_proj", ".k_proj", "k"),
  354. (".qkv_proj", ".v_proj", "v"),
  355. (".gate_up_proj", ".gate_proj", 0),
  356. (".gate_up_proj", ".up_proj", 1),
  357. ]
  358. params_dict = dict(self.named_parameters())
  359. for name, loaded_weight in weights:
  360. if "rotary_emb.inv_freq" in name:
  361. continue
  362. if ("rotary_emb.cos_cached" in name
  363. or "rotary_emb.sin_cached" in name):
  364. # Models trained using ColossalAI may include these tensors in
  365. # the checkpoint. Skip them.
  366. continue
  367. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  368. if weight_name not in name:
  369. continue
  370. name = name.replace(weight_name, param_name)
  371. # Skip loading extra bias for GPTQ models.
  372. if name.endswith(".bias") and name not in params_dict:
  373. continue
  374. param = params_dict[name]
  375. weight_loader = param.weight_loader
  376. weight_loader(param, loaded_weight, shard_id)
  377. break
  378. else:
  379. # Skip loading extra bias for GPTQ models.
  380. if name.endswith(".bias") and name not in params_dict:
  381. continue
  382. param = params_dict[name]
  383. weight_loader = getattr(param, "weight_loader",
  384. default_weight_loader)
  385. weight_loader(param, loaded_weight)
  386. # If this function is called, it should always initialize KV cache scale
  387. # factors (or else raise an exception). Thus, handled exceptions should
  388. # make sure to leave KV cache scale factors in a known good (dummy) state
  389. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  390. tp_size = get_tensor_model_parallel_world_size()
  391. tp_rank = get_tensor_model_parallel_rank()
  392. for layer_idx, scaling_factor in kv_cache_scales_loader(
  393. quantization_param_path, tp_rank, tp_size,
  394. self.config.num_hidden_layers,
  395. self.config.__class__.model_type):
  396. layer_self_attn = self.model.layers[layer_idx].self_attn
  397. if is_hip():
  398. # The scaling factor convention we are assuming is
  399. # quantized_value * scaling_factor ~= true_value
  400. # which is consistent with the practice of setting
  401. # scaling_factor = tensor_amax / FPtype_max
  402. scaling_factor *= 2
  403. if hasattr(layer_self_attn, "kv_scale"):
  404. layer_self_attn.kv_scale = scaling_factor
  405. else:
  406. raise RuntimeError("Self attention has no KV cache scaling "
  407. "factor attribute!")