llama.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import is_hip
  32. from aphrodite.distributed import (get_current_tp_rank_partition_size,
  33. get_pp_group,
  34. get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size)
  36. from aphrodite.modeling.layers.activation import SiluAndMul
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import (
  47. default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
  48. from aphrodite.modeling.models.interfaces import SupportsLoRA
  49. from aphrodite.modeling.models.utils import (PPMissingLayer,
  50. is_pp_missing_parameter,
  51. make_layers)
  52. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  53. from aphrodite.quantization.base_config import QuantizationConfig
  54. class LlamaMLP(nn.Module):
  55. def __init__(
  56. self,
  57. hidden_size: int,
  58. intermediate_size: int,
  59. hidden_act: str,
  60. quant_config: Optional[QuantizationConfig] = None,
  61. bias: bool = False,
  62. prefix: str = "",
  63. ) -> None:
  64. super().__init__()
  65. self.gate_up_proj = MergedColumnParallelLinear(
  66. input_size=hidden_size,
  67. output_sizes=[intermediate_size] * 2,
  68. bias=bias,
  69. quant_config=quant_config,
  70. prefix=f"{prefix}.gate_up_proj")
  71. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  72. output_size=hidden_size,
  73. bias=bias,
  74. quant_config=quant_config,
  75. prefix=f"{prefix}.down_proj")
  76. if hidden_act != "silu":
  77. raise ValueError(f"Unsupported activation: {hidden_act}. "
  78. "Only silu is supported for now.")
  79. self.act_fn = SiluAndMul()
  80. def forward(self, x):
  81. gate_up, _ = self.gate_up_proj(x)
  82. x = self.act_fn(gate_up)
  83. x, _ = self.down_proj(x)
  84. return x
  85. class LlamaAttention(nn.Module):
  86. def __init__(
  87. self,
  88. config: LlamaConfig,
  89. hidden_size: int,
  90. num_heads: int,
  91. num_kv_heads: int,
  92. rope_theta: float = 10000,
  93. rope_scaling: Optional[Dict[str, Any]] = None,
  94. max_position_embeddings: int = 8192,
  95. quant_config: Optional[QuantizationConfig] = None,
  96. bias: bool = False,
  97. cache_config: Optional[CacheConfig] = None,
  98. prefix: str = "",
  99. ) -> None:
  100. super().__init__()
  101. self.hidden_size = hidden_size
  102. tp_size = get_tensor_model_parallel_world_size()
  103. tp_rank = get_tensor_model_parallel_rank()
  104. self.total_num_heads = num_heads
  105. self.total_num_kv_heads = num_kv_heads
  106. self.num_kv_heads = max(
  107. 1,
  108. get_current_tp_rank_partition_size(self.total_num_kv_heads,
  109. tp_rank, tp_size))
  110. num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads
  111. self.num_heads = self.num_kv_heads * num_heads_per_kv_head
  112. # MistralConfig has an optional head_dim introduced by Mistral-Nemo
  113. self.head_dim = getattr(config, "head_dim",
  114. self.hidden_size // self.total_num_heads)
  115. self.q_size = self.num_heads * self.head_dim
  116. self.kv_size = self.num_kv_heads * self.head_dim
  117. self.scaling = self.head_dim**-0.5
  118. self.rope_theta = rope_theta
  119. self.max_position_embeddings = max_position_embeddings
  120. self.qkv_proj = QKVParallelLinear(
  121. hidden_size=hidden_size,
  122. head_size=self.head_dim,
  123. total_num_heads=self.total_num_heads,
  124. total_num_kv_heads=self.total_num_kv_heads,
  125. bias=bias,
  126. quant_config=quant_config,
  127. prefix=f"{prefix}.qkv_proj",
  128. )
  129. self.o_proj = RowParallelLinear(
  130. input_size=self.total_num_heads * self.head_dim,
  131. output_size=hidden_size,
  132. bias=bias,
  133. quant_config=quant_config,
  134. partition_multiple_of=num_heads_per_kv_head * self.head_dim,
  135. prefix=f"{prefix}.o_proj",
  136. )
  137. self.rotary_emb = get_rope(
  138. self.head_dim,
  139. rotary_dim=self.head_dim,
  140. max_position=max_position_embeddings,
  141. base=rope_theta,
  142. rope_scaling=rope_scaling,
  143. )
  144. self.attn = Attention(self.num_heads,
  145. self.head_dim,
  146. self.scaling,
  147. num_kv_heads=self.num_kv_heads,
  148. cache_config=cache_config,
  149. quant_config=quant_config)
  150. def forward(
  151. self,
  152. positions: torch.Tensor,
  153. hidden_states: torch.Tensor,
  154. kv_cache: torch.Tensor,
  155. attn_metadata: AttentionMetadata,
  156. ) -> torch.Tensor:
  157. qkv, _ = self.qkv_proj(hidden_states)
  158. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  159. q, k = self.rotary_emb(positions, q, k)
  160. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  161. output, _ = self.o_proj(attn_output)
  162. return output
  163. class LlamaDecoderLayer(nn.Module):
  164. def __init__(
  165. self,
  166. config: LlamaConfig,
  167. cache_config: Optional[CacheConfig] = None,
  168. quant_config: Optional[QuantizationConfig] = None,
  169. prefix: str = "",
  170. ) -> None:
  171. super().__init__()
  172. self.hidden_size = config.hidden_size
  173. rope_theta = getattr(config, "rope_theta", 10000)
  174. rope_scaling = getattr(config, "rope_scaling", None)
  175. if rope_scaling is not None and getattr(
  176. config, "original_max_position_embeddings", None):
  177. rope_scaling["original_max_position_embeddings"] = (
  178. config.original_max_position_embeddings)
  179. max_position_embeddings = getattr(config, "max_position_embeddings",
  180. 8192)
  181. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  182. # Support internlm/internlm-7b with bias
  183. attention_bias = getattr(config, "attention_bias", False) or getattr(
  184. config, "bias", False)
  185. self.self_attn = LlamaAttention(
  186. config=config,
  187. hidden_size=self.hidden_size,
  188. num_heads=config.num_attention_heads,
  189. num_kv_heads=getattr(config, "num_key_value_heads",
  190. config.num_attention_heads),
  191. rope_theta=rope_theta,
  192. rope_scaling=rope_scaling,
  193. max_position_embeddings=max_position_embeddings,
  194. quant_config=quant_config,
  195. bias=attention_bias,
  196. cache_config=cache_config,
  197. prefix=f"{prefix}.self_attn",
  198. )
  199. self.mlp = LlamaMLP(
  200. hidden_size=self.hidden_size,
  201. intermediate_size=config.intermediate_size,
  202. hidden_act=config.hidden_act,
  203. quant_config=quant_config,
  204. bias=getattr(config, "mlp_bias", False),
  205. prefix=f"{prefix}.mlp",
  206. )
  207. self.input_layernorm = RMSNorm(config.hidden_size,
  208. eps=config.rms_norm_eps)
  209. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  210. eps=config.rms_norm_eps)
  211. def forward(
  212. self,
  213. positions: torch.Tensor,
  214. hidden_states: torch.Tensor,
  215. kv_cache: torch.Tensor,
  216. attn_metadata: AttentionMetadata,
  217. residual: Optional[torch.Tensor],
  218. ) -> Tuple[torch.Tensor, torch.Tensor]:
  219. # Self Attention
  220. if residual is None:
  221. residual = hidden_states
  222. hidden_states = self.input_layernorm(hidden_states)
  223. else:
  224. hidden_states, residual = self.input_layernorm(
  225. hidden_states, residual)
  226. hidden_states = self.self_attn(
  227. positions=positions,
  228. hidden_states=hidden_states,
  229. kv_cache=kv_cache,
  230. attn_metadata=attn_metadata,
  231. )
  232. # Fully Connected
  233. hidden_states, residual = self.post_attention_layernorm(
  234. hidden_states, residual)
  235. hidden_states = self.mlp(hidden_states)
  236. return hidden_states, residual
  237. class LlamaModel(nn.Module):
  238. def __init__(
  239. self,
  240. config: LlamaConfig,
  241. cache_config: Optional[CacheConfig] = None,
  242. quant_config: Optional[QuantizationConfig] = None,
  243. lora_config: Optional[LoRAConfig] = None,
  244. prefix: str = "",
  245. ) -> None:
  246. super().__init__()
  247. self.config = config
  248. self.padding_idx = config.pad_token_id
  249. lora_vocab = (lora_config.lora_extra_vocab_size *
  250. (lora_config.max_loras or 1)) if lora_config else 0
  251. self.vocab_size = config.vocab_size + lora_vocab
  252. self.org_vocab_size = config.vocab_size
  253. if get_pp_group().is_first_rank or (config.tie_word_embeddings
  254. and get_pp_group().is_last_rank):
  255. self.embed_tokens = VocabParallelEmbedding(
  256. self.vocab_size,
  257. config.hidden_size,
  258. org_num_embeddings=config.vocab_size,
  259. )
  260. else:
  261. self.embed_tokens = PPMissingLayer()
  262. self.start_layer, self.end_layer, self.layers = make_layers(
  263. config.num_hidden_layers,
  264. lambda prefix: LlamaDecoderLayer(config=config,
  265. cache_config=cache_config,
  266. quant_config=quant_config,
  267. prefix=prefix),
  268. prefix=f"{prefix}.layers")
  269. if get_pp_group().is_last_rank:
  270. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  271. else:
  272. self.norm = PPMissingLayer()
  273. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  274. return self.embed_tokens(input_ids)
  275. def forward(
  276. self,
  277. input_ids: Optional[torch.Tensor],
  278. positions: torch.Tensor,
  279. kv_caches: List[torch.Tensor],
  280. attn_metadata: AttentionMetadata,
  281. intermediate_tensors: Optional[IntermediateTensors],
  282. inputs_embeds: Optional[torch.Tensor] = None,
  283. ) -> Union[torch.Tensor, IntermediateTensors]:
  284. if get_pp_group().is_first_rank:
  285. if inputs_embeds is not None:
  286. hidden_states = inputs_embeds
  287. else:
  288. hidden_states = self.get_input_embeddings(input_ids)
  289. residual = None
  290. else:
  291. assert intermediate_tensors is not None
  292. hidden_states = intermediate_tensors["hidden_states"]
  293. residual = intermediate_tensors["residual"]
  294. for i in range(self.start_layer, self.end_layer):
  295. layer = self.layers[i]
  296. hidden_states, residual = layer(
  297. positions,
  298. hidden_states,
  299. kv_caches[i - self.start_layer],
  300. attn_metadata,
  301. residual,
  302. )
  303. if not get_pp_group().is_last_rank:
  304. return IntermediateTensors({
  305. "hidden_states": hidden_states,
  306. "residual": residual
  307. })
  308. hidden_states, _ = self.norm(hidden_states, residual)
  309. return hidden_states
  310. class LlamaForCausalLM(nn.Module, SupportsLoRA):
  311. packed_modules_mapping = {
  312. "qkv_proj": [
  313. "q_proj",
  314. "k_proj",
  315. "v_proj",
  316. ],
  317. "gate_up_proj": [
  318. "gate_proj",
  319. "up_proj",
  320. ],
  321. }
  322. # LoRA specific attributes
  323. supported_lora_modules = [
  324. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
  325. "lm_head"
  326. ]
  327. embedding_modules = {
  328. "embed_tokens": "input_embeddings",
  329. "lm_head": "output_embeddings",
  330. }
  331. embedding_padding_modules = ["lm_head"]
  332. bitsandbytes_stacked_params_mapping = {
  333. # shard_name, weight_name, index
  334. "q_proj": ("qkv_proj", 0),
  335. "k_proj": ("qkv_proj", 1),
  336. "v_proj": ("qkv_proj", 2),
  337. "gate_proj": ("gate_up_proj", 0),
  338. "up_proj": ("gate_up_proj", 1),
  339. }
  340. def __init__(
  341. self,
  342. config: LlamaConfig,
  343. cache_config: Optional[CacheConfig] = None,
  344. quant_config: Optional[QuantizationConfig] = None,
  345. lora_config: Optional[LoRAConfig] = None,
  346. ) -> None:
  347. super().__init__()
  348. self.config = config
  349. self.lora_config = lora_config
  350. self.model = LlamaModel(config,
  351. cache_config,
  352. quant_config,
  353. lora_config=lora_config,
  354. prefix="model")
  355. if get_pp_group().is_last_rank:
  356. self.unpadded_vocab_size = config.vocab_size
  357. if lora_config:
  358. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  359. self.lm_head = ParallelLMHead(
  360. self.unpadded_vocab_size,
  361. config.hidden_size,
  362. org_num_embeddings=config.vocab_size,
  363. padding_size=None
  364. # We need bigger padding if using lora for kernel
  365. # compatibility
  366. if not lora_config else lora_config.lora_vocab_padding_size,
  367. quant_config=quant_config,
  368. )
  369. if config.tie_word_embeddings:
  370. self.lm_head.weight = self.model.embed_tokens.weight
  371. logit_scale = getattr(config, "logit_scale", 1.0)
  372. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  373. config.vocab_size,
  374. logit_scale)
  375. self.sampler = Sampler()
  376. else:
  377. self.lm_head = PPMissingLayer()
  378. def forward(
  379. self,
  380. input_ids: torch.Tensor,
  381. positions: torch.Tensor,
  382. kv_caches: List[torch.Tensor],
  383. attn_metadata: AttentionMetadata,
  384. intermediate_tensors: Optional[IntermediateTensors] = None,
  385. ) -> Union[torch.Tensor, IntermediateTensors]:
  386. model_output = self.model(input_ids, positions, kv_caches,
  387. attn_metadata, intermediate_tensors)
  388. return model_output
  389. def compute_logits(self, hidden_states: torch.Tensor,
  390. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  391. logits = self.logits_processor(self.lm_head, hidden_states,
  392. sampling_metadata)
  393. return logits
  394. def sample(
  395. self,
  396. logits: torch.Tensor,
  397. sampling_metadata: SamplingMetadata,
  398. ) -> Optional[SamplerOutput]:
  399. next_tokens = self.sampler(logits, sampling_metadata)
  400. return next_tokens
  401. def make_empty_intermediate_tensors(
  402. self, batch_size: int, dtype: torch.dtype,
  403. device: torch.device) -> IntermediateTensors:
  404. return IntermediateTensors({
  405. "hidden_states":
  406. torch.zeros((batch_size, self.config.hidden_size),
  407. dtype=dtype,
  408. device=device),
  409. "residual":
  410. torch.zeros((batch_size, self.config.hidden_size),
  411. dtype=dtype,
  412. device=device),
  413. })
  414. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  415. stacked_params_mapping = [
  416. # (param_name, shard_name, shard_id)
  417. (".qkv_proj", ".q_proj", "q"),
  418. (".qkv_proj", ".k_proj", "k"),
  419. (".qkv_proj", ".v_proj", "v"),
  420. (".gate_up_proj", ".gate_proj", 0),
  421. (".gate_up_proj", ".up_proj", 1),
  422. ]
  423. params_dict = dict(self.named_parameters())
  424. for name, loaded_weight in weights:
  425. if "rotary_emb.inv_freq" in name:
  426. continue
  427. if ("rotary_emb.cos_cached" in name
  428. or "rotary_emb.sin_cached" in name):
  429. # Models trained using ColossalAI may include these tensors in
  430. # the checkpoint. Skip them.
  431. continue
  432. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  433. if weight_name not in name:
  434. continue
  435. name = name.replace(weight_name, param_name)
  436. # Skip loading extra bias for GPTQ models.
  437. if name.endswith(".bias") and name not in params_dict:
  438. continue
  439. if is_pp_missing_parameter(name, self):
  440. continue
  441. param = params_dict[name]
  442. weight_loader = param.weight_loader
  443. weight_loader(param, loaded_weight, shard_id)
  444. break
  445. else:
  446. # Skip loading extra bias for GPTQ models.
  447. if name.endswith(".bias") and name not in params_dict:
  448. continue
  449. # Remapping the name of FP8 kv-scale.
  450. name = maybe_remap_kv_scale_name(name, params_dict)
  451. if name is None:
  452. continue
  453. if is_pp_missing_parameter(name, self):
  454. continue
  455. param = params_dict[name]
  456. weight_loader = getattr(param, "weight_loader",
  457. default_weight_loader)
  458. weight_loader(param, loaded_weight)
  459. # If this function is called, it should always initialize KV cache scale
  460. # factors (or else raise an exception). Thus, handled exceptions should
  461. # make sure to leave KV cache scale factors in a known good (dummy) state
  462. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  463. tp_size = get_tensor_model_parallel_world_size()
  464. tp_rank = get_tensor_model_parallel_rank()
  465. for layer_idx, scaling_factor in kv_cache_scales_loader(
  466. quantization_param_path, tp_rank, tp_size,
  467. self.config.num_hidden_layers,
  468. self.config.__class__.model_type):
  469. if not isinstance(self.model.layers[layer_idx], nn.Identity):
  470. layer_self_attn = self.model.layers[layer_idx].self_attn
  471. if is_hip():
  472. # The scaling factor convention we are assuming is
  473. # quantized_value * scaling_factor ~= true_value
  474. # which is consistent with the practice of setting
  475. # scaling_factor = tensor_amax / FPtype_max
  476. scaling_factor *= 2
  477. if hasattr(layer_self_attn, "kv_scale"):
  478. layer_self_attn.attn._kv_scale = scaling_factor
  479. else:
  480. raise RuntimeError("Self attention has no KV cache scaling "
  481. "factor attribute!")