# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""

from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTNeoXConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.layers import (VocabParallelEmbedding,
                                                ColumnParallelLinear,
                                                RowParallelLinear)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
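
# Each KVCache is a (key_cache, value_cache) tensor pair; the paged attention
# kernel reads from and appends to these blocks in place during decoding.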


class GPTNeoXAttention(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        self.query_key_value = ColumnParallelLinear(
            config.hidden_size,
            3 * config.hidden_size,
            gather_output=False,
        )
        self.dense = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            input_is_parallel=True,
        )

        scaling = self.head_size**-0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_size,
            scaling,
            rotary_dim,
            base=rope_theta,
            max_position=max_position_embeddings)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.dense(attn_output)
        return output
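
# Tensor-parallel sharding, worked through: with total_num_heads = 32,
# head_size = 128, and a world size of 4, each rank keeps 8 heads, its
# query_key_value projection emits a 3 * 8 * 128 = 3072-wide local slice
# (gather_output=False), and the row-parallel dense layer all-reduces the
# partial outputs across ranks.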


class GPTNeoXMLP(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            gather_output=False,
        )
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            input_is_parallel=True,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states
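
# Pairing a column-parallel up-projection (gather_output=False) with a
# row-parallel down-projection (input_is_parallel=True) keeps the 4h
# activations sharded, so the MLP needs only one all-reduce per forward pass.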


class GPTNeoXLayer(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states
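
# use_parallel_residual defaults to True in HuggingFace's GPTNeoXConfig: both
# branches read the same layer input, so attention and the MLP can be computed
# from one residual stream rather than sequentially, as in GPT-J and the
# Pythia checkpoints.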


class GPTNeoXModel(nn.Module):

    def __init__(self, config: GPTNeoXConfig):
        super().__init__()
        self.config = config
        self.embed_in = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
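
# cache_events appears to let the scheduler overlap KV-cache movement with
# compute: when provided, layer i's attention waits on cache_events[i] before
# touching its cache, so earlier layers can run while later caches are still
# being swapped.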


class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            # pylint: disable=unsubscriptable-object
            param = state_dict[name]
            if "query_key_value" in name:
                # NOTE: GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if "query_key_value.weight" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "query_key_value.bias" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected weight name: {name}")
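                # Worked example of the reordering (tensor parallel size 1,
                # num_heads = 2, head_size = 1): HF stores the fused rows
                # per head as [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1]; the
                # view -> transpose(0, 1) -> reshape above regroups them per
                # projection as [q_h0, q_h1, k_h0, k_h1, v_h0, v_h1], which
                # is the layout qkv.chunk(3, dim=-1) assumes in forward().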
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
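

# Usage sketch (assumes the serving engine has built an InputMetadata and
# allocated the paged KV caches; "EleutherAI/pythia-160m" stands in for any
# GPT-NeoX-architecture checkpoint):
#
#   config = GPTNeoXConfig.from_pretrained("EleutherAI/pythia-160m")
#   model = GPTNeoXForCausalLM(config)
#   model.load_weights("EleutherAI/pythia-160m")
#   output = model(input_ids, positions, kv_caches, input_metadata, None)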