llama.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.quantized_linear import ParallelLinear
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.layers import VocabParallelEmbedding
from aphrodite.modeling.quantization_utils import QuantizationConfig
from aphrodite.modeling.hf_downloader import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
    get_parallel_weight)
from aphrodite.common.sequence import SamplerOutput
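
# A layer's KV cache is the (key_cache, value_cache) tensor pair managed by
# the paged attention kernels.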
KVCache = Tuple[torch.Tensor, torch.Tensor]
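

# SwiGLU feed-forward block: a fused, column-parallel gate/up projection,
# SiLU-and-multiply activation, and a row-parallel down projection.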
class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
                                                  quant_config=quant_config)
        self.down_proj = ParallelLinear.row(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            input_is_parallel=True,
                                            quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
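

# Multi-head self-attention with rotary position embeddings and a paged KV
# cache. Supports grouped-query attention: KV heads are partitioned across
# tensor-parallel ranks when there are at least as many KV heads as ranks,
# and replicated when there are fewer.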
class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = ParallelLinear.column(
            hidden_size,
            (self.total_num_heads +
             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
            self.head_dim,
            bias=False,
            gather_output=False,
            quant_config=quant_config,
        )
        self.o_proj = ParallelLinear.row(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)
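
    # Forward pass: fused QKV projection, split into this rank's Q/K/V,
    # rotary embedding plus paged attention against the layer's KV cache,
    # then the row-parallel output projection.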
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output
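

# One transformer block in pre-norm form: RMSNorm -> self-attention ->
# residual add, then RMSNorm -> SwiGLU MLP -> residual add.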
class LlamaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
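

# The decoder stack: a vocab-parallel token embedding (vocabulary padded up to
# a multiple of 64 so it shards evenly across typical tensor-parallel sizes),
# the decoder layers, and a final RMSNorm.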
class LlamaModel(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
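

# Causal LM wrapper: the decoder stack plus an (unquantized) column-parallel
# LM head and a sampler that turns the final hidden states into next tokens.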
class LlamaForCausalLM(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        # NOTE: The LM head is not quantized.
        self.lm_head = ParallelLinear.column(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             quant_config=None)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens
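
    # o_proj and down_proj are row-parallel; the column-parallel projections
    # (qkv_proj, gate_up_proj) are fused and sharded explicitly in
    # load_weights, so the column list stays empty.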
    column_parallel_layers = []
    row_parallel_layers = ["o_proj", "down_proj"]
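
    # Load a HuggingFace checkpoint: the separate q/k/v and gate/up tensors
    # are fused into qkv_proj and gate_up_proj, and every tensor is sliced
    # down to this tensor-parallel rank's shard before being copied in.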
    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        column_parallel_weights, row_parallel_weights = get_parallel_weight(
            self)
        column_weight_suffixes = (
            self.quant_config.get_col_parallel_tensor_names()
        ) if self.quant_config is not None else ["weight", "bias"]

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        q_proj_shard_size = (self.config.hidden_size // tp_size)
        num_kv_heads_replicas = max(1,
                                    tp_size // self.config.num_key_value_heads)
        num_kv_heads_per_gpu = max(1,
                                   self.config.num_key_value_heads // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              num_kv_heads_per_gpu)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            packed_dim = None
            is_transposed = False
            if self.quant_config is not None:
                packed_dim = self.quant_config.get_packed_dim(name)
                is_transposed = self.quant_config.is_transposed(name)
            if is_transposed:
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight = loaded_weight.T

            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, "qkv_proj")
                if name not in state_dict:  # pylint: disable=unsupported-membership-test
                    break
                param = state_dict[name]  # pylint: disable=unsubscriptable-object
                if is_transposed:
                    param = param.T

                # Quantized weights may be packed along one dimension; rescale
                # the shard size and offset to match the packed layout.
                if packed_dim is not None:
                    shard_dim = 0 if not is_transposed else 1
                    if packed_dim == shard_dim:
                        shard_size //= self.quant_config.pack_factor
                        offset //= self.quant_config.pack_factor

                # When KV heads are replicated, several TP ranks load the same
                # K/V shard.
                if weight_name in ["k_proj", "v_proj"]:
                    shard_id = tp_rank // num_kv_heads_replicas
                else:
                    shard_id = tp_rank
                if any(
                        name.endswith(suffix)
                        for suffix in column_weight_suffixes):
                    loaded_weight = loaded_weight[shard_size *
                                                  shard_id:shard_size *
                                                  (shard_id + 1)]
                    param_slice = param.data[offset:offset + shard_size]
                else:
                    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                    param_slice = param.data
                assert param_slice.shape == loaded_weight.shape

                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, "gate_up_proj")
                if name not in state_dict:  # pylint: disable=unsupported-membership-test
                    break
                param = state_dict[name]  # pylint: disable=unsubscriptable-object
                if is_transposed:
                    param = param.T

                shard_size = param.shape[0] // 2
                if any(
                        name.endswith(suffix)
                        for suffix in column_weight_suffixes):
                    loaded_weight = loaded_weight[shard_size *
                                                  tp_rank:shard_size *
                                                  (tp_rank + 1)]
                    param_slice = param.data[shard_size *
                                             stride_id:shard_size *
                                             (stride_id + 1)]
                else:
                    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                    param_slice = param.data
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            if name not in state_dict:  # pylint: disable=unsupported-membership-test
                continue
            param = state_dict[name]  # pylint: disable=unsubscriptable-object
            if is_transposed:
                param = param.T

            # The padded vocab tensors (embedding and LM head) go through a
            # dedicated loader that accounts for the extra padded rows.
            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         column_parallel_weights,
                                         row_parallel_weights, tp_rank)