# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.hf_downloader import (
    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
    hf_model_weights_iterator)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
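# A KVCache is the (key_cache, value_cache) tensor pair for a single decoder
# layer; the model takes one pair per layer and hands it to paged attention.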


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # The gate and up projections are fused into a single column-parallel
        # matmul; SiluAndMul splits the result and applies silu(gate) * up.
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        # Query and key/value heads are split evenly across the tensor
        # parallel ranks (num_kv_heads < num_heads means grouped-query
        # attention).
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        # Fused QKV projection: the output packs all query heads followed by
        # the key heads and then the value heads.
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           base=self.rope_theta,
                                           rotary_dim=self.head_dim,
                                           num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # Round the vocab size up to a multiple of 64 for the parallel
        # embedding table.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size, config.hidden_size, perform_initialization=False)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        # Round the vocab size up to a multiple of 64 to match the padded
        # embedding table.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens
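
    # Weight names sharded along the output dimension (column parallel) and
    # along the input dimension (row parallel) when loading checkpoints.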
    _column_parallel_weights = [
        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            # Copy this rank's q/k/v shard into the fused qkv_proj parameter.
            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[offset:offset + shard_size]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            # Copy this rank's gate/up shard into the fused gate_up_proj
            # parameter.
            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
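
# Usage sketch (illustrative only): the serving engine constructs
# LlamaForCausalLM(config), calls load_weights(model_name_or_path), and then
# drives forward() with flattened token ids, their positions, one KVCache per
# layer, the InputMetadata describing the original sequence layout, and
# optional CUDA cache events; the Sampler returns the next tokens as a
# SamplerOutput.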