1
0

llama.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only LLaMA model compatible with HuggingFace weights."""
  25. from typing import Any, Dict, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import LlamaConfig
  29. from aphrodite.modeling.metadata import InputMetadata
  30. from aphrodite.modeling.layers.activation import SiluAndMul
  31. from aphrodite.modeling.layers.attention import PagedAttention
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  34. MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear,
  37. ColumnParallelLinear)
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
  42. from aphrodite.modeling.megatron.parallel_state import (
  43. get_tensor_model_parallel_world_size)
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  46. hf_model_weights_iterator)
  47. from aphrodite.common.sequence import SamplerOutput
  48. from aphrodite.common.config import LoRAConfig
  # Per-layer KV cache: a (key_cache, value_cache) pair of tensors, unpacked as
  # `k_cache, v_cache = kv_cache` in LlamaAttention.forward.
  KVCache = Tuple[torch.Tensor, torch.Tensor]
  class LlamaMLP(nn.Module):
      """Feed-forward sub-layer of a LLaMA decoder layer.

      The input is projected to a gate half and an up half (each of width
      ``intermediate_size``). The concatenated ``[gate, up]`` tensor is passed
      through ``SiluAndMul``, and ``down_proj`` maps the result back to
      ``hidden_size``.

      Args:
          hidden_size: Width of the input and output activations.
          intermediate_size: Width of each of the gate and up projections.
          hidden_act: Activation name; only ``"silu"`` is accepted.
          linear_method: Optional linear method. When its quant config
              reports ``merge_weight()`` as false, gate and up use two
              separate ``ColumnParallelLinear`` layers. Otherwise they share
              one ``MergedColumnParallelLinear``.

      Raises:
          ValueError: If ``hidden_act`` is not ``"silu"``.
      """

      def __init__(
          self,
          hidden_size: int,
          intermediate_size: int,
          hidden_act: str,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          # Without a linear method the merged projection is always used;
          # otherwise the quant config decides.
          self.merge_weight = bool(linear_method is None
                                   or linear_method.quant_config.merge_weight())
          if self.merge_weight:
              self.gate_up_proj = MergedColumnParallelLinear(
                  hidden_size,
                  [intermediate_size, intermediate_size],
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              self.gate_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
              self.up_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
          self.down_proj = RowParallelLinear(
              intermediate_size,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          if hidden_act != "silu":
              raise ValueError(f"Unsupported activation: {hidden_act}. "
                               "Only silu is supported for now.")
          self.act_fn = SiluAndMul()

      def forward(self, x):
          """Run gate/up projection, activation, and down projection on ``x``."""
          if not self.merge_weight:
              up, _ = self.up_proj(x)
              gate, _ = self.gate_proj(x)
              # Same layout as the merged projection, where gate_proj is
              # shard 0 and up_proj is shard 1 (see load_weights).
              gate_up = torch.cat([gate, up], dim=-1)
          else:
              gate_up, _ = self.gate_up_proj(x)
          activated = self.act_fn(gate_up)
          projected, _ = self.down_proj(activated)
          return projected
  class LlamaAttention(nn.Module):
      """Rotary-embedded self-attention used by each LLaMA decoder layer.

      Query heads are partitioned across tensor-parallel ranks. KV heads are
      partitioned when there are at least as many of them as ranks, and
      replicated otherwise (the latter is only possible with the fused
      ``QKVParallelLinear`` path).

      Args:
          hidden_size: Model hidden width; ``head_dim`` is
              ``hidden_size // num_heads``.
          num_heads: Total number of query heads across all ranks.
          num_kv_heads: Total number of key/value heads across all ranks.
          rope_theta: Rotary base forwarded to ``get_rope``.
          rope_scaling: Optional rotary scaling config forwarded to
              ``get_rope``.
          max_position_embeddings: Maximum position forwarded to
              ``get_rope``.
          linear_method: Optional linear method. When its quant config
              reports ``merge_weight()`` as false, q/k/v use three separate
              ``ColumnParallelLinear`` layers. Otherwise a single fused
              ``QKVParallelLinear`` is used.

      Raises:
          NotImplementedError: On the unmerged q/k/v path when
              ``num_kv_heads`` is smaller than the tensor-parallel size.
              KV heads would have to be replicated, which separate
              column-parallel layers cannot do.
      """

      def __init__(
          self,
          hidden_size: int,
          num_heads: int,
          num_kv_heads: int,
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 8192,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size
          self.total_num_kv_heads = num_kv_heads
          if self.total_num_kv_heads >= tp_size:
              # Number of KV heads is greater than TP size, so we partition
              # the KV heads across multiple tensor parallel GPUs.
              assert self.total_num_kv_heads % tp_size == 0
          else:
              # Number of KV heads is less than TP size, so we replicate
              # the KV heads across multiple tensor parallel GPUs.
              assert tp_size % self.total_num_kv_heads == 0
          self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
          self.head_dim = hidden_size // self.total_num_heads
          # Per-rank widths of the q / k / v activations.
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
          self.max_position_embeddings = max_position_embeddings

          if (linear_method is not None
                  and not linear_method.quant_config.merge_weight()):
              self.merge_weight = False
              if self.total_num_kv_heads < tp_size:
                  # A column-parallel layer splits its output across ranks;
                  # it cannot give every rank a full copy of a KV head.
                  raise NotImplementedError(
                      "Unmerged q/k/v projections cannot replicate KV heads: "
                      f"num_kv_heads={self.total_num_kv_heads} is smaller "
                      f"than the tensor parallel size {tp_size}.")
              # Pass the total (unpartitioned) output width, as o_proj and
              # the MLP projections do; ColumnParallelLinear splits it across
              # ranks, so each rank yields exactly q_size / kv_size features.
              # (Passing the per-rank sizes here would shrink them by tp_size.)
              self.q_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
              self.k_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
              self.v_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              self.merge_weight = True
              self.qkv_proj = QKVParallelLinear(
                  hidden_size,
                  self.head_dim,
                  self.total_num_heads,
                  self.total_num_kv_heads,
                  bias=False,
                  linear_method=linear_method,
              )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          # Use the quant config's rope style when it specifies one;
          # otherwise default to NeoX style.
          rope_style = (None if linear_method is None else
                        linear_method.quant_config.rope_style())
          is_neox_style = True if rope_style is None else rope_style
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=max_position_embeddings,
              base=rope_theta,
              rope_scaling=rope_scaling,
              is_neox_style=is_neox_style,
          )
          self.attn = PagedAttention(self.num_heads,
                                     self.head_dim,
                                     self.scaling,
                                     num_kv_heads=self.num_kv_heads)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Compute self-attention over ``hidden_states``.

          Args:
              positions: Token positions fed to the rotary embedding.
              hidden_states: Input activations.
              kv_cache: ``(key_cache, value_cache)`` pair for this layer.
              input_metadata: Batch metadata forwarded to ``PagedAttention``.

          Returns:
              The ``o_proj`` projection of the attention output.
          """
          if self.merge_weight:
              qkv, _ = self.qkv_proj(hidden_states)
              q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                  dim=-1)
          else:
              q, _ = self.q_proj(hidden_states)
              k, _ = self.k_proj(hidden_states)
              v, _ = self.v_proj(hidden_states)
          q, k = self.rotary_emb(positions, q, k)
          k_cache, v_cache = kv_cache
          attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class LlamaDecoderLayer(nn.Module):
      """One LLaMA transformer block: normed self-attention, then normed MLP.

      Both RMSNorm layers are also called in a two-argument form,
      ``norm(hidden_states, residual)``, which returns
      ``(normed_states, residual)``. Presumably this fuses the residual add
      into the norm; TODO confirm against RMSNorm.

      Args:
          config: HuggingFace LLaMA config supplying layer sizes.
          linear_method: Optional linear method forwarded to the attention
              and MLP projections.
      """

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
          # getattr fallbacks cover configs that omit these attributes.
          rope_kwargs = dict(
              rope_theta=getattr(config, "rope_theta", 10000),
              rope_scaling=getattr(config, "rope_scaling", None),
              max_position_embeddings=getattr(config,
                                              "max_position_embeddings",
                                              8192),
          )
          self.self_attn = LlamaAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
              num_kv_heads=config.num_key_value_heads,
              linear_method=linear_method,
              **rope_kwargs,
          )
          self.mlp = LlamaMLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
              linear_method=linear_method,
          )
          eps = config.rms_norm_eps
          self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
          self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run the block and return ``(mlp_output, residual)``.

          ``residual`` is ``None`` for the first layer. In that case the raw
          input becomes the residual stream.
          """
          # Self attention.
          if residual is not None:
              hidden_states, residual = self.input_layernorm(
                  hidden_states, residual)
          else:
              residual = hidden_states
              hidden_states = self.input_layernorm(hidden_states)
          attn_out = self.self_attn(
              positions=positions,
              hidden_states=hidden_states,
              kv_cache=kv_cache,
              input_metadata=input_metadata,
          )
          # Fully connected.
          normed, residual = self.post_attention_layernorm(attn_out, residual)
          return self.mlp(normed), residual
  class LlamaModel(nn.Module):
      """LLaMA backbone: token embedding, decoder stack, and final RMSNorm.

      Args:
          config: HuggingFace LLaMA config.
          linear_method: Optional linear method forwarded to the embedding
              and every decoder layer.
          lora_config: When given, the embedding table is enlarged by
              ``lora_extra_vocab_size * (max_loras or 1)`` rows.
      """

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
          if lora_config:
              lora_vocab = (lora_config.lora_extra_vocab_size *
                            (lora_config.max_loras or 1))
          else:
              lora_vocab = 0
          self.org_vocab_size = config.vocab_size
          self.vocab_size = self.org_vocab_size + lora_vocab
          self.embed_tokens = VocabParallelEmbedding(
              self.vocab_size,
              config.hidden_size,
              linear_method=linear_method,
              org_num_embeddings=config.vocab_size,
          )
          self.layers = nn.ModuleList(
              LlamaDecoderLayer(config, linear_method)
              for _ in range(config.num_hidden_layers))
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every decoder layer, and apply the final norm.

          ``kv_caches[i]`` is handed to layer ``i``.
          """
          hidden_states = self.embed_tokens(input_ids)
          residual = None
          for layer_idx, layer in enumerate(self.layers):
              hidden_states, residual = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  input_metadata,
                  residual,
              )
          normed, _ = self.norm(hidden_states, residual)
          return normed
  class LlamaForCausalLM(nn.Module):
      """LLaMA backbone plus LM head and sampler, with HF weight loading.

      Args:
          config: HuggingFace LLaMA config.
          linear_method: Optional linear method used by every projection.
          lora_config: When given, the LM head gets
              ``lora_extra_vocab_size`` extra rows and LoRA's vocab padding.
      """

      supports_lora = True

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = LlamaModel(config, linear_method, lora_config=lora_config)
          extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
          unpadded_vocab_size = config.vocab_size + extra_vocab
          # We need bigger padding if using lora for kernel compatibility.
          padding_size = (lora_config.lora_vocab_padding_size
                          if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
          self.lm_head = ParallelLMHead(
              unpadded_vocab_size,
              config.hidden_size,
              linear_method=linear_method,
              org_num_embeddings=config.vocab_size,
              padding_size=padding_size,
          )
          self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Return the backbone's final hidden states (no LM head applied)."""
          return self.model(input_ids, positions, kv_caches, input_metadata)

      def sample(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Apply the LM head to ``hidden_states`` and sample next tokens."""
          logits = self.lm_head(hidden_states)
          return self.sampler(logits, sampling_metadata)

      def load_weights(self,
                       model_name_or_path: str,
                       cache_dir: Optional[str] = None,
                       load_format: str = "auto",
                       revision: Optional[str] = None):
          """Load HuggingFace checkpoint tensors into this module's parameters.

          When weights are merged, separate ``q/k/v_proj`` and
          ``gate/up_proj`` checkpoint tensors are routed to the matching shard
          of the fused ``qkv_proj`` / ``gate_up_proj`` parameter. Cached
          rotary tensors are skipped, as are ``.bias`` tensors with no
          matching parameter.
          """
          merged = (self.linear_method is None
                    or self.linear_method.quant_config.merge_weight())
          # (param_name, shard_name, shard_id)
          stacked_params_mapping = [
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ] if merged else []
          # inv_freq is never loaded; cos/sin caches may appear in checkpoints
          # of models trained using ColossalAI.
          skipped_fragments = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                               "rotary_emb.sin_cached")
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision):
              if any(fragment in name for fragment in skipped_fragments):
                  continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name in name:
                      name = name.replace(weight_name, param_name)
                      # Skip loading extra bias for GPTQ models.
                      if name in params_dict or not name.endswith(".bias"):
                          param = params_dict[name]
                          param.weight_loader(param, loaded_weight, shard_id)
                          break
              else:
                  # Skip loading extra bias for GPTQ models.
                  if name in params_dict or not name.endswith(".bias"):
                      param = params_dict[name]
                      loader = getattr(param, "weight_loader",
                                       default_weight_loader)
                      loader(param, loaded_weight)