llama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only LLaMA model compatible with HuggingFace weights."""
  24. from typing import Any, Dict, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import LlamaConfig
  28. from aphrodite.modeling.metadata import InputMetadata
  29. from aphrodite.modeling.layers.activation import SiluAndMul
  30. from aphrodite.modeling.layers.attention import Attention
  31. from aphrodite.modeling.layers.layernorm import RMSNorm
  32. from aphrodite.modeling.layers.linear import (
  33. LinearMethodBase,
  34. MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear,
  37. ColumnParallelLinear,
  38. )
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. VocabParallelEmbedding,
  43. ParallelLMHead,
  44. DEFAULT_VOCAB_PADDING_SIZE,
  45. )
  46. from aphrodite.modeling.megatron.parallel_state import (
  47. get_tensor_model_parallel_world_size, )
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.modeling.hf_downloader import (
  50. default_weight_loader,
  51. hf_model_weights_iterator,
  52. )
  53. from aphrodite.common.sequence import SamplerOutput
  54. from aphrodite.common.config import LoRAConfig
  # Per-layer KV cache: a (key_cache, value_cache) pair of tensors; it is
  # unpacked as ``k_cache, v_cache`` in LlamaAttention.forward.
  KVCache = Tuple[torch.Tensor, torch.Tensor]
  class LlamaMLP(nn.Module):
      """Gated feed-forward block of a LLaMA decoder layer.

      The input is projected to two ``intermediate_size`` activations (gate
      and up), combined by ``SiluAndMul`` and projected back to
      ``hidden_size`` by ``down_proj``. Gate and up live in one fused
      ``MergedColumnParallelLinear`` unless the quantization config's
      ``merge_weight()`` says the weights must stay separate.
      """

      def __init__(
          self,
          hidden_size: int,
          intermediate_size: int,
          hidden_act: str,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          # Fuse gate/up unless a quantization method explicitly opts out.
          fused = (linear_method is None
                   or bool(linear_method.quant_config.merge_weight()))
          self.merge_weight = fused
          if fused:
              self.gate_up_proj = MergedColumnParallelLinear(
                  hidden_size,
                  [intermediate_size, intermediate_size],
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              self.gate_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
              self.up_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
          self.down_proj = RowParallelLinear(
              intermediate_size,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          if hidden_act != "silu":
              raise ValueError(f"Unsupported activation: {hidden_act}. "
                               "Only silu is supported for now.")
          self.act_fn = SiluAndMul()

      def forward(self, x):
          """Project up (gate and up), apply SiLU gating, project down."""
          if not self.merge_weight:
              up, _ = self.up_proj(x)
              gate, _ = self.gate_proj(x)
              # Lay out [gate | up], the same order as the fused gate_up_proj
              # shards (shard 0 = gate_proj, shard 1 = up_proj).
              gate_up = torch.cat([gate, up], dim=-1)
          else:
              gate_up, _ = self.gate_up_proj(x)
          projected, _ = self.down_proj(self.act_fn(gate_up))
          return projected
  class LlamaAttention(nn.Module):
      """Grouped-query self-attention with rotary position embeddings.

      Query heads are split evenly across tensor-parallel (TP) ranks. KV
      heads are split across ranks when there are at least as many KV heads
      as ranks; otherwise each rank keeps a single (replicated) KV head.
      """

      def __init__(
          self,
          hidden_size: int,
          num_heads: int,
          num_kv_heads: int,
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 8192,
          linear_method: Optional[LinearMethodBase] = None,
          bias: bool = False,
          sliding_window: Optional[int] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size
          self.total_num_kv_heads = num_kv_heads
          if self.total_num_kv_heads >= tp_size:
              # At least as many KV heads as TP ranks: partition the KV heads
              # across the tensor parallel GPUs.
              assert self.total_num_kv_heads % tp_size == 0
          else:
              # Fewer KV heads than TP ranks: replicate the KV heads across
              # the tensor parallel GPUs.
              assert tp_size % self.total_num_kv_heads == 0
          self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
          self.head_dim = hidden_size // self.total_num_heads
          # Per-rank widths of the q and k/v activations.
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
          self.max_position_embeddings = max_position_embeddings

          if (linear_method is not None
                  and not linear_method.quant_config.merge_weight()):
              # The quantization method cannot use a fused QKV weight.
              self.merge_weight = False
              # Pass the *total* (un-sharded) output width, as the MLP does
              # with intermediate_size and as QKVParallelLinear/o_proj do with
              # the total head counts; the column-parallel layer shards it
              # across TP ranks. Passing the per-rank q_size/kv_size here
              # shrank each rank's output by a further factor of tp_size.
              # NOTE(review): with total_num_kv_heads < tp_size the KV heads
              # need replication, which a plain ColumnParallelLinear
              # presumably does not provide - confirm whether unmerged quant
              # formats must support that configuration.
              total_q_size = self.total_num_heads * self.head_dim
              total_kv_size = self.total_num_kv_heads * self.head_dim
              self.q_proj = ColumnParallelLinear(
                  hidden_size,
                  total_q_size,
                  bias=bias,
                  linear_method=linear_method,
              )
              self.k_proj = ColumnParallelLinear(
                  hidden_size,
                  total_kv_size,
                  bias=bias,
                  linear_method=linear_method,
              )
              self.v_proj = ColumnParallelLinear(
                  hidden_size,
                  total_kv_size,
                  bias=bias,
                  linear_method=linear_method,
              )
          else:
              self.merge_weight = True
              self.qkv_proj = QKVParallelLinear(
                  hidden_size,
                  self.head_dim,
                  self.total_num_heads,
                  self.total_num_kv_heads,
                  bias=bias,
                  linear_method=linear_method,
              )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=bias,
              linear_method=linear_method,
          )

          # NeoX-style rotary by default; a quant config may override it.
          rope_style = (None if linear_method is None else
                        linear_method.quant_config.rope_style())
          is_neox_style = True if rope_style is None else rope_style
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=max_position_embeddings,
              base=rope_theta,
              rope_scaling=rope_scaling,
              is_neox_style=is_neox_style,
          )
          self.attn = Attention(
              self.num_heads,
              self.head_dim,
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              sliding_window=sliding_window,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Project to q/k/v, apply rotary embeddings, attend, project out.

          Args:
              positions: Position ids passed to the rotary embedding.
              hidden_states: Input activations.
              kv_cache: ``(k_cache, v_cache)`` pair handed to ``Attention``.
              input_metadata: Batch metadata handed to ``Attention``.

          Returns:
              The ``o_proj`` output.
          """
          if self.merge_weight:
              qkv, _ = self.qkv_proj(hidden_states)
              # The fused output is [q | k | v] with per-rank widths.
              q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                  dim=-1)
          else:
              q, _ = self.q_proj(hidden_states)
              k, _ = self.k_proj(hidden_states)
              v, _ = self.v_proj(hidden_states)
          q, k = self.rotary_emb(positions, q, k)
          k_cache, v_cache = kv_cache
          attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class LlamaDecoderLayer(nn.Module):
      """One pre-norm LLaMA layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

      The residual stream is threaded through the two RMSNorm calls rather
      than being added explicitly in this module.
      """

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
          n_heads = config.num_attention_heads
          # Optional config fields fall back to LLaMA defaults when absent.
          self.self_attn = LlamaAttention(
              hidden_size=self.hidden_size,
              num_heads=n_heads,
              num_kv_heads=getattr(config, "num_key_value_heads", n_heads),
              rope_theta=getattr(config, "rope_theta", 10000),
              rope_scaling=getattr(config, "rope_scaling", None),
              max_position_embeddings=getattr(config,
                                              "max_position_embeddings",
                                              8192),
              linear_method=linear_method,
              bias=getattr(config, "bias", False),
              sliding_window=getattr(config, "sliding_window", None),
          )
          self.mlp = LlamaMLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
              linear_method=linear_method,
          )
          norm_eps = config.rms_norm_eps
          self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
          self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                  eps=norm_eps)
          if config.model_type == "Yi":
              # Some old Yi finetunes and quants have not been llama-fied;
              # expose the same norm modules under their legacy names too.
              self.ln1 = self.input_layernorm
              self.ln2 = self.post_attention_layernorm

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run attention and MLP sub-blocks.

          Args:
              residual: Residual stream from the previous layer, or ``None``
                  when there is none yet (LlamaModel passes ``None`` to the
                  first layer).

          Returns:
              ``(mlp_output, residual)``; the MLP output has not been folded
              into ``residual`` yet (the next norm call receives both).
          """
          if residual is not None:
              # Two-argument RMSNorm returns a (normed, residual) pair;
              # presumably it fuses the residual add - see RMSNorm.
              normed, residual = self.input_layernorm(hidden_states, residual)
          else:
              residual = hidden_states
              normed = self.input_layernorm(hidden_states)
          attn_out = self.self_attn(
              positions=positions,
              hidden_states=normed,
              kv_cache=kv_cache,
              input_metadata=input_metadata,
          )
          normed, residual = self.post_attention_layernorm(attn_out, residual)
          return self.mlp(normed), residual
  class LlamaModel(nn.Module):
      """Token embedding, a stack of LlamaDecoderLayer, and a final RMSNorm."""

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
          # LoRA reserves lora_extra_vocab_size extra embedding rows per
          # adapter slot (max_loras, treated as 1 when unset or 0).
          if lora_config:
              extra_vocab = (lora_config.lora_extra_vocab_size *
                             (lora_config.max_loras or 1))
          else:
              extra_vocab = 0
          self.vocab_size = config.vocab_size + extra_vocab
          self.org_vocab_size = config.vocab_size
          self.embed_tokens = VocabParallelEmbedding(
              self.vocab_size,
              config.hidden_size,
              linear_method=linear_method,
              org_num_embeddings=self.org_vocab_size,
          )
          self.layers = nn.ModuleList(
              LlamaDecoderLayer(config, linear_method)
              for _ in range(config.num_hidden_layers))
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every layer, and apply the final norm.

          ``kv_caches`` is indexed by layer position, so it must hold at
          least one entry per decoder layer.
          """
          hidden_states = self.embed_tokens(input_ids)
          residual = None
          for layer_idx, layer in enumerate(self.layers):
              hidden_states, residual = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  input_metadata,
                  residual,
              )
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states
  class LlamaForCausalLM(nn.Module):
      """LLaMA causal LM: LlamaModel plus LM head, samplers and weight loading."""
      packed_modules_mapping = {
          "qkv_proj": [
              "q_proj",
              "k_proj",
              "v_proj",
          ],
          "gate_up_proj": [
              "gate_proj",
              "up_proj",
          ],
      }
      # LoRA specific attributes
      supported_lora_modules = [
          "qkv_proj",
          "o_proj",
          "gate_up_proj",
          "down_proj",
          "embed_tokens",
          "lm_head",
      ]
      embedding_modules = {
          "embed_tokens": "input_embeddings",
          "lm_head": "output_embeddings",
      }
      embedding_padding_modules = ["lm_head"]

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = LlamaModel(config, linear_method, lora_config=lora_config)
          self.unpadded_vocab_size = config.vocab_size
          if lora_config:
              self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
          self.lm_head = ParallelLMHead(
              self.unpadded_vocab_size,
              config.hidden_size,
              org_num_embeddings=config.vocab_size,
              linear_method=linear_method,
              padding_size=DEFAULT_VOCAB_PADDING_SIZE
              # We need bigger padding if using lora for kernel
              # compatibility
              if not lora_config else lora_config.lora_vocab_padding_size,
          )
          self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
          self.quant_sampler = QuantSampler(self.unpadded_vocab_size,
                                            config.vocab_size)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Return the final-norm hidden states of the decoder stack."""
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     input_metadata)
          return hidden_states

      def sample(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Pick next tokens from ``hidden_states``.

          When the quantization config disables weight merging, the output
          of ``lm_head(hidden_states)`` is handed to ``quant_sampler``;
          otherwise ``sampler`` receives ``lm_head.weight`` and the hidden
          states directly.
          """
          if (self.linear_method is not None
                  and not self.linear_method.quant_config.merge_weight()):
              next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                               sampling_metadata)
          else:
              next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                         sampling_metadata)
          return next_tokens

      def load_weights(
          self,
          model_name_or_path: str,
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None,
      ):
          """Load checkpoint tensors into this model's parameters.

          Separate q/k/v and gate/up checkpoint tensors are routed into the
          fused ``qkv_proj`` / ``gate_up_proj`` parameters through each
          parameter's ``weight_loader`` with a shard id. If the quantization
          config disables weight merging, every tensor loads by its own
          name. Rotary caches and unmatched extra ``.bias`` tensors are
          skipped.
          """
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          if (self.linear_method is not None
                  and not self.linear_method.quant_config.merge_weight()):
              stacked_params_mapping = []
          params_dict = dict(self.named_parameters())
          # Old, non-llama-fied Yi checkpoints name the decoder norms ln1/ln2.
          # LlamaDecoderLayer registers ln1/ln2 as aliases of the llama norm
          # modules, but named_parameters() skips duplicate (aliased)
          # modules, so only the llama names are keys of params_dict.
          # Translate the legacy names instead of failing with a KeyError.
          legacy_norm_names = (
              (".ln1.", ".input_layernorm."),
              (".ln2.", ".post_attention_layernorm."),
          )
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision,
                  self.config):
              if "rotary_emb.inv_freq" in name:
                  continue
              if ("rotary_emb.cos_cached" in name
                      or "rotary_emb.sin_cached" in name):
                  # Models trained using ColossalAI may include these tensors in
                  # the checkpoint. Skip them.
                  continue
              if name not in params_dict:
                  for legacy, current in legacy_norm_names:
                      if legacy in name:
                          name = name.replace(legacy, current)
                          break
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)