llama.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import LoRAConfig
  30. from aphrodite.common.sequence import SamplerOutput
  31. from aphrodite.common.utils import is_hip
  32. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  33. hf_model_weights_iterator,
  34. kv_cache_scales_loader)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  38. LinearMethodBase,
  39. MergedColumnParallelLinear,
  40. QKVParallelLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  48. get_tensor_model_parallel_world_size)
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. class LlamaMLP(nn.Module):
  51. def __init__(
  52. self,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. hidden_act: str,
  56. linear_method: Optional[LinearMethodBase] = None,
  57. ) -> None:
  58. super().__init__()
  59. if (linear_method is not None
  60. and not linear_method.quant_config.merge_weight()):
  61. self.merge_weight = False
  62. self.gate_proj = ColumnParallelLinear(
  63. hidden_size,
  64. intermediate_size,
  65. bias=False,
  66. linear_method=linear_method,
  67. )
  68. self.up_proj = ColumnParallelLinear(
  69. hidden_size,
  70. intermediate_size,
  71. bias=False,
  72. linear_method=linear_method,
  73. )
  74. else:
  75. self.merge_weight = True
  76. self.gate_up_proj = MergedColumnParallelLinear(
  77. hidden_size,
  78. [intermediate_size] * 2,
  79. bias=False,
  80. linear_method=linear_method,
  81. )
  82. self.down_proj = RowParallelLinear(
  83. intermediate_size,
  84. hidden_size,
  85. bias=False,
  86. linear_method=linear_method,
  87. )
  88. if hidden_act != "silu":
  89. raise ValueError(f"Unsupported activation: {hidden_act}. "
  90. "Only silu is supported for now.")
  91. self.act_fn = SiluAndMul()
  92. def forward(self, x):
  93. if self.merge_weight:
  94. gate_up, _ = self.gate_up_proj(x)
  95. else:
  96. up, _ = self.up_proj(x)
  97. gate, _ = self.gate_proj(x)
  98. gate_up = torch.cat([gate, up], dim=-1)
  99. x = self.act_fn(gate_up)
  100. x, _ = self.down_proj(x)
  101. return x
  102. class LlamaAttention(nn.Module):
  103. def __init__(
  104. self,
  105. hidden_size: int,
  106. num_heads: int,
  107. num_kv_heads: int,
  108. rope_theta: float = 10000,
  109. rope_scaling: Optional[Dict[str, Any]] = None,
  110. max_position_embeddings: int = 8192,
  111. linear_method: Optional[LinearMethodBase] = None,
  112. bias: bool = False,
  113. sliding_window: Optional[int] = None,
  114. ) -> None:
  115. super().__init__()
  116. self.hidden_size = hidden_size
  117. tp_size = get_tensor_model_parallel_world_size()
  118. self.total_num_heads = num_heads
  119. assert self.total_num_heads % tp_size == 0
  120. self.num_heads = self.total_num_heads // tp_size
  121. self.total_num_kv_heads = num_kv_heads
  122. if self.total_num_kv_heads >= tp_size:
  123. # Number of KV heads is greater than TP size, so we partition
  124. # the KV heads across multiple tensor parallel GPUs.
  125. assert self.total_num_kv_heads % tp_size == 0
  126. else:
  127. # Number of KV heads is less than TP size, so we replicate
  128. # the KV heads across multiple tensor parallel GPUs.
  129. assert tp_size % self.total_num_kv_heads == 0
  130. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  131. self.head_dim = hidden_size // self.total_num_heads
  132. self.q_size = self.num_heads * self.head_dim
  133. self.kv_size = self.num_kv_heads * self.head_dim
  134. self.scaling = self.head_dim**-0.5
  135. self.rope_theta = rope_theta
  136. self.max_position_embeddings = max_position_embeddings
  137. # This will be overwritten by model initialization if we are using it.
  138. # N.B. currently we only support per tensor scalar scaling factors
  139. # & only applicable to ROCm (AMD GPU).
  140. # The scaling factor convention we are assuming is
  141. # quantized_value * scaling_factor ~= true_value
  142. # which is consistent with the practice of setting
  143. # scaling_factor = tensor_amax / FPtype_max
  144. self.kv_scale = 1.0
  145. if (linear_method is not None
  146. and not linear_method.quant_config.merge_weight()):
  147. self.merge_weight = False
  148. self.q_proj = ColumnParallelLinear(hidden_size,
  149. self.q_size,
  150. bias=bias,
  151. linear_method=linear_method)
  152. self.k_proj = ColumnParallelLinear(
  153. hidden_size,
  154. self.kv_size,
  155. bias=bias,
  156. linear_method=linear_method,
  157. )
  158. self.v_proj = ColumnParallelLinear(
  159. hidden_size,
  160. self.kv_size,
  161. bias=bias,
  162. linear_method=linear_method,
  163. )
  164. else:
  165. self.merge_weight = True
  166. self.qkv_proj = QKVParallelLinear(
  167. hidden_size,
  168. self.head_dim,
  169. self.total_num_heads,
  170. self.total_num_kv_heads,
  171. bias=bias,
  172. linear_method=linear_method,
  173. )
  174. self.o_proj = RowParallelLinear(
  175. self.total_num_heads * self.head_dim,
  176. hidden_size,
  177. bias=bias,
  178. linear_method=linear_method,
  179. )
  180. self.rotary_emb = get_rope(
  181. self.head_dim,
  182. rotary_dim=self.head_dim,
  183. max_position=max_position_embeddings,
  184. base=rope_theta,
  185. rope_scaling=rope_scaling,
  186. is_neox_style=True,
  187. )
  188. self.attn = Attention(
  189. self.num_heads,
  190. self.head_dim,
  191. self.scaling,
  192. num_kv_heads=self.num_kv_heads,
  193. sliding_window=sliding_window,
  194. )
  195. def forward(
  196. self,
  197. positions: torch.Tensor,
  198. hidden_states: torch.Tensor,
  199. kv_cache: torch.Tensor,
  200. attn_metadata: AttentionMetadata,
  201. # kv_quant_param: List[float],
  202. ) -> torch.Tensor:
  203. if self.merge_weight:
  204. qkv, _ = self.qkv_proj(hidden_states)
  205. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  206. dim=-1)
  207. else:
  208. q, _ = self.q_proj(hidden_states)
  209. k, _ = self.k_proj(hidden_states)
  210. v, _ = self.v_proj(hidden_states)
  211. q, k = self.rotary_emb(positions, q, k)
  212. attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
  213. self.kv_scale)
  214. output, _ = self.o_proj(attn_output)
  215. return output
  216. class LlamaDecoderLayer(nn.Module):
  217. def __init__(
  218. self,
  219. config: LlamaConfig,
  220. linear_method: Optional[LinearMethodBase] = None,
  221. ) -> None:
  222. super().__init__()
  223. self.hidden_size = config.hidden_size
  224. rope_theta = getattr(config, "rope_theta", 10000)
  225. rope_scaling = getattr(config, "rope_scaling", None)
  226. max_position_embeddings = getattr(config, "max_position_embeddings",
  227. 8192)
  228. sliding_window = getattr(config, "sliding_window", None)
  229. # Support abacusai/Smaug-72B-v0.1 with attention_bias
  230. # Support internlm/internlm-7b with bias
  231. attention_bias = getattr(config, "attention_bias", False) or getattr(
  232. config, "bias", False)
  233. self.self_attn = LlamaAttention(
  234. hidden_size=self.hidden_size,
  235. num_heads=config.num_attention_heads,
  236. num_kv_heads=getattr(config, "num_key_value_heads",
  237. config.num_attention_heads),
  238. rope_theta=rope_theta,
  239. rope_scaling=rope_scaling,
  240. max_position_embeddings=max_position_embeddings,
  241. linear_method=linear_method,
  242. bias=attention_bias,
  243. sliding_window=sliding_window,
  244. )
  245. self.mlp = LlamaMLP(
  246. hidden_size=self.hidden_size,
  247. intermediate_size=config.intermediate_size,
  248. hidden_act=config.hidden_act,
  249. linear_method=linear_method,
  250. )
  251. self.input_layernorm = RMSNorm(config.hidden_size,
  252. eps=config.rms_norm_eps)
  253. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  254. eps=config.rms_norm_eps)
  255. if config.model_type == "Yi":
  256. # Some old Yi finetunes and quants have not been llama-fied
  257. self.ln1 = self.input_layernorm
  258. self.ln2 = self.post_attention_layernorm
  259. def forward(
  260. self,
  261. positions: torch.Tensor,
  262. hidden_states: torch.Tensor,
  263. kv_cache: torch.Tensor,
  264. attn_metadata: AttentionMetadata,
  265. residual: Optional[torch.Tensor],
  266. # kv_quant_param: List[float],
  267. ) -> Tuple[torch.Tensor, torch.Tensor]:
  268. # Self Attention
  269. if residual is None:
  270. residual = hidden_states
  271. hidden_states = self.input_layernorm(hidden_states)
  272. else:
  273. hidden_states, residual = self.input_layernorm(
  274. hidden_states, residual)
  275. hidden_states = self.self_attn(
  276. positions=positions,
  277. hidden_states=hidden_states,
  278. kv_cache=kv_cache,
  279. attn_metadata=attn_metadata,
  280. # kv_quant_param=kv_quant_param,
  281. )
  282. # Fully Connected
  283. hidden_states, residual = self.post_attention_layernorm(
  284. hidden_states, residual)
  285. hidden_states = self.mlp(hidden_states)
  286. return hidden_states, residual
  287. class LlamaModel(nn.Module):
  288. def __init__(
  289. self,
  290. config: LlamaConfig,
  291. linear_method: Optional[LinearMethodBase] = None,
  292. lora_config: Optional[LoRAConfig] = None,
  293. ) -> None:
  294. super().__init__()
  295. self.config = config
  296. self.padding_idx = config.pad_token_id
  297. lora_vocab = ((lora_config.lora_extra_vocab_size *
  298. (lora_config.max_loras or 1)) if lora_config else 0)
  299. self.vocab_size = config.vocab_size + lora_vocab
  300. self.org_vocab_size = config.vocab_size
  301. self.embed_tokens = VocabParallelEmbedding(
  302. self.vocab_size,
  303. config.hidden_size,
  304. linear_method=linear_method,
  305. org_num_embeddings=config.vocab_size,
  306. )
  307. self.layers = nn.ModuleList([
  308. LlamaDecoderLayer(config, linear_method)
  309. for _ in range(config.num_hidden_layers)
  310. ])
  311. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  312. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  313. return self.embed_tokens(input_ids)
  314. def forward(
  315. self,
  316. input_ids: Optional[torch.Tensor],
  317. positions: torch.Tensor,
  318. kv_caches: List[torch.Tensor],
  319. attn_metadata: AttentionMetadata,
  320. inputs_embeds: Optional[torch.Tensor] = None,
  321. ) -> torch.Tensor:
  322. if inputs_embeds is not None:
  323. hidden_states = inputs_embeds
  324. else:
  325. hidden_states = self.get_input_embeddings(input_ids)
  326. residual = None
  327. for i in range(len(self.layers)):
  328. layer = self.layers[i]
  329. hidden_states, residual = layer(
  330. positions,
  331. hidden_states,
  332. kv_caches[i],
  333. attn_metadata,
  334. residual,
  335. # attn_metadata.kv_quant_params[i]
  336. # if attn_metadata.kv_quant_params is not None else None,
  337. )
  338. hidden_states, _ = self.norm(hidden_states, residual)
  339. return hidden_states
  340. class LlamaForCausalLM(nn.Module):
  341. packed_modules_mapping = {
  342. "qkv_proj": [
  343. "q_proj",
  344. "k_proj",
  345. "v_proj",
  346. ],
  347. "gate_up_proj": [
  348. "gate_proj",
  349. "up_proj",
  350. ],
  351. }
  352. # LoRA specific attributes
  353. supported_lora_modules = [
  354. "qkv_proj",
  355. "o_proj",
  356. "gate_up_proj",
  357. "down_proj",
  358. "embed_tokens",
  359. "lm_head",
  360. ]
  361. embedding_modules = {
  362. "embed_tokens": "input_embeddings",
  363. "lm_head": "output_embeddings",
  364. }
  365. embedding_padding_modules = ["lm_head"]
  366. def __init__(
  367. self,
  368. config: LlamaConfig,
  369. linear_method: Optional[LinearMethodBase] = None,
  370. lora_config: Optional[LoRAConfig] = None,
  371. ) -> None:
  372. super().__init__()
  373. self.config = config
  374. self.linear_method = linear_method
  375. self.model = LlamaModel(config, linear_method, lora_config=lora_config)
  376. self.unpadded_vocab_size = config.vocab_size
  377. if lora_config:
  378. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  379. self.lm_head = ParallelLMHead(
  380. self.unpadded_vocab_size,
  381. config.hidden_size,
  382. org_num_embeddings=config.vocab_size,
  383. linear_method=linear_method,
  384. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  385. # We need bigger padding if using lora for kernel
  386. # compatibility
  387. if not lora_config else lora_config.lora_vocab_padding_size,
  388. )
  389. logit_scale = getattr(config, "logit_scale", 1.0)
  390. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  391. config.vocab_size, logit_scale)
  392. self.sampler = Sampler()
  393. def forward(
  394. self,
  395. input_ids: torch.Tensor,
  396. positions: torch.Tensor,
  397. kv_caches: List[torch.Tensor],
  398. attn_metadata: AttentionMetadata,
  399. ) -> torch.Tensor:
  400. hidden_states = self.model(input_ids, positions, kv_caches,
  401. attn_metadata)
  402. return hidden_states
  403. def compute_logits(self, hidden_states: torch.Tensor,
  404. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  405. logits = self.logits_processor(self.lm_head, hidden_states,
  406. sampling_metadata)
  407. return logits
  408. def sample(
  409. self,
  410. logits: torch.Tensor,
  411. sampling_metadata: SamplingMetadata,
  412. ) -> Optional[SamplerOutput]:
  413. next_tokens = self.sampler(logits, sampling_metadata)
  414. return next_tokens
  415. def load_weights(
  416. self,
  417. model_name_or_path: str,
  418. cache_dir: Optional[str] = None,
  419. load_format: str = "auto",
  420. revision: Optional[str] = None,
  421. ):
  422. stacked_params_mapping = [
  423. # (param_name, shard_name, shard_id)
  424. ("qkv_proj", "q_proj", "q"),
  425. ("qkv_proj", "k_proj", "k"),
  426. ("qkv_proj", "v_proj", "v"),
  427. ("gate_up_proj", "gate_proj", 0),
  428. ("gate_up_proj", "up_proj", 1),
  429. ]
  430. if (self.linear_method is not None
  431. and not self.linear_method.quant_config.merge_weight()):
  432. stacked_params_mapping = []
  433. params_dict = dict(self.named_parameters())
  434. for name, loaded_weight in hf_model_weights_iterator(
  435. model_name_or_path, cache_dir, load_format, revision,
  436. self.config):
  437. if "rotary_emb.inv_freq" in name:
  438. continue
  439. if ("rotary_emb.cos_cached" in name
  440. or "rotary_emb.sin_cached" in name):
  441. # Models trained using ColossalAI may include these tensors in
  442. # the checkpoint. Skip them.
  443. continue
  444. for param_name, weight_name, shard_id in stacked_params_mapping:
  445. if weight_name not in name:
  446. continue
  447. name = name.replace(weight_name, param_name)
  448. # Skip loading extra bias for GPTQ models.
  449. if name.endswith(".bias") and name not in params_dict:
  450. continue
  451. param = params_dict[name]
  452. weight_loader = param.weight_loader
  453. weight_loader(param, loaded_weight, shard_id)
  454. break
  455. else:
  456. # Skip loading extra bias for GPTQ models.
  457. if name.endswith(".bias") and name not in params_dict:
  458. continue
  459. param = params_dict[name]
  460. weight_loader = getattr(param, "weight_loader",
  461. default_weight_loader)
  462. weight_loader(param, loaded_weight)
  463. # If this function is called, it should always initialize KV cache scale
  464. # factors (or else raise an exception). Thus, handled exceptions should
  465. # make sure to leave KV cache scale factors in a known good (dummy) state
  466. def load_kv_cache_scales(self, quantization_param_path: str) -> None:
  467. tp_size = get_tensor_model_parallel_world_size()
  468. tp_rank = get_tensor_model_parallel_rank()
  469. for layer_idx, scaling_factor in kv_cache_scales_loader(
  470. quantization_param_path, tp_rank, tp_size,
  471. self.config.num_hidden_layers,
  472. self.config.__class__.model_type):
  473. layer_self_attn = self.model.layers[layer_idx].self_attn
  474. if is_hip():
  475. # The scaling factor convention we are assuming is
  476. # quantized_value * scaling_factor ~= true_value
  477. # which is consistent with the practice of setting
  478. # scaling_factor = tensor_amax / FPtype_max
  479. scaling_factor *= 2
  480. if hasattr(layer_self_attn, "kv_scale"):
  481. layer_self_attn.kv_scale = scaling_factor
  482. else:
  483. raise RuntimeError("Self attention has no KV cache scaling "
  484. "factor attribute!")