1
0

gpt_neox.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-NeoX model compatible with HuggingFace weights.
  20. The input of the model is flattened to a 1D tensor of tokens. The model uses
  21. InputMetadata to extract the original 2D shape of the input.
  22. """
  23. from typing import List, Optional, Tuple
  24. import torch
  25. from torch import nn
  26. from transformers import GPTNeoXConfig
  27. from aphrodite.modeling.metadata import InputMetadata
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
  30. from aphrodite.modeling.layers.sampler import Sampler
  31. from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
  32. load_tensor_parallel_weights)
  33. from aphrodite.modeling.megatron.parallel_state import (
  34. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  35. from aphrodite.modeling.megatron.tensor_parallel import (
  36. VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
  37. from aphrodite.common.sequence import SamplerOutput
  38. KVCache = Tuple[torch.Tensor, torch.Tensor]
  39. class GPTNeoXAttention(nn.Module):
  40. def __init__(self, config: GPTNeoXConfig):
  41. super().__init__()
  42. self.total_num_heads = config.num_attention_heads
  43. self.hidden_size = config.hidden_size
  44. self.head_size = self.hidden_size // self.total_num_heads
  45. tensor_model_parallel_world_size = (
  46. get_tensor_model_parallel_world_size())
  47. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  48. self.num_heads = (self.total_num_heads //
  49. tensor_model_parallel_world_size)
  50. self.query_key_value = ColumnParallelLinear(
  51. config.hidden_size,
  52. 3 * config.hidden_size,
  53. gather_output=False,
  54. perform_initialization=False)
  55. self.dense = RowParallelLinear(config.hidden_size,
  56. config.hidden_size,
  57. input_is_parallel=True,
  58. perform_initialization=False)
  59. scaling = self.head_size**-0.5
  60. rotary_dim = int(self.head_size * config.rotary_pct)
  61. assert rotary_dim % 2 == 0
  62. self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
  63. scaling, rotary_dim)
  64. def forward(
  65. self,
  66. position_ids: torch.Tensor,
  67. hidden_states: torch.Tensor,
  68. kv_cache: KVCache,
  69. input_metadata: InputMetadata,
  70. cache_event: Optional[torch.cuda.Event],
  71. ) -> torch.Tensor:
  72. qkv, _ = self.query_key_value(hidden_states)
  73. q, k, v = qkv.chunk(chunks=3, dim=-1)
  74. k_cache, v_cache = kv_cache
  75. attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
  76. input_metadata, cache_event)
  77. output, _ = self.dense(attn_output)
  78. return output
  79. class GPTNeoXMLP(nn.Module):
  80. def __init__(self, config: GPTNeoXConfig):
  81. super().__init__()
  82. self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
  83. config.intermediate_size,
  84. gather_output=False,
  85. perform_initialization=False)
  86. self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
  87. config.hidden_size,
  88. input_is_parallel=True,
  89. perform_initialization=False)
  90. self.act = get_act_fn(config.hidden_act)
  91. def forward(self, hidden_states):
  92. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  93. hidden_states = self.act(hidden_states)
  94. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  95. return hidden_states
  96. class GPTNeoXLayer(nn.Module):
  97. def __init__(self, config: GPTNeoXConfig):
  98. super().__init__()
  99. self.use_parallel_residual = config.use_parallel_residual
  100. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  101. eps=config.layer_norm_eps)
  102. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  103. eps=config.layer_norm_eps)
  104. self.attention = GPTNeoXAttention(config)
  105. self.mlp = GPTNeoXMLP(config)
  106. def forward(
  107. self,
  108. position_ids: torch.Tensor,
  109. hidden_states: torch.Tensor,
  110. kv_cache: KVCache,
  111. input_metadata: InputMetadata,
  112. cache_event: Optional[torch.cuda.Event],
  113. ) -> torch.Tensor:
  114. attn_input = self.input_layernorm(hidden_states)
  115. attn_output = self.attention(
  116. position_ids=position_ids,
  117. hidden_states=attn_input,
  118. kv_cache=kv_cache,
  119. input_metadata=input_metadata,
  120. cache_event=cache_event,
  121. )
  122. if self.use_parallel_residual:
  123. # pseudocode:
  124. # x = x + attn(ln1(x)) + mlp(ln2(x))
  125. mlp_input = self.post_attention_layernorm(hidden_states)
  126. mlp_output = self.mlp(mlp_input)
  127. hidden_states = mlp_output + attn_output + hidden_states
  128. else:
  129. # pseudocode:
  130. # x = x + attn(ln1(x))
  131. # x = x + mlp(ln2(x))
  132. attn_output = attn_output + hidden_states
  133. mlp_input = self.post_attention_layernorm(attn_output)
  134. mlp_output = self.mlp(mlp_input)
  135. hidden_states = mlp_output + attn_output
  136. return hidden_states
  137. class GPTNeoXModel(nn.Module):
  138. def __init__(self, config: GPTNeoXConfig):
  139. super().__init__()
  140. self.config = config
  141. self.embed_in = VocabParallelEmbedding(config.vocab_size,
  142. config.hidden_size,
  143. perform_initialization=False)
  144. self.layers = nn.ModuleList(
  145. [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
  146. self.final_layer_norm = nn.LayerNorm(config.hidden_size,
  147. eps=config.layer_norm_eps)
  148. def forward(
  149. self,
  150. input_ids: torch.Tensor,
  151. position_ids: torch.Tensor,
  152. kv_caches: List[KVCache],
  153. input_metadata: InputMetadata,
  154. cache_events: Optional[List[torch.cuda.Event]],
  155. ) -> torch.Tensor:
  156. hidden_states = self.embed_in(input_ids)
  157. for i in range(len(self.layers)):
  158. if cache_events is None:
  159. cache_event = None
  160. else:
  161. cache_event = cache_events[i]
  162. layer = self.layers[i]
  163. hidden_states = layer(
  164. position_ids,
  165. hidden_states,
  166. kv_caches[i],
  167. input_metadata,
  168. cache_event,
  169. )
  170. hidden_states = self.final_layer_norm(hidden_states)
  171. return hidden_states
  172. class GPTNeoXForCausalLM(nn.Module):
  173. def __init__(self, config):
  174. super().__init__()
  175. self.config = config
  176. self.gpt_neox = GPTNeoXModel(config)
  177. self.embed_out = ColumnParallelLinear(config.hidden_size,
  178. config.vocab_size,
  179. bias=False,
  180. gather_output=False,
  181. perform_initialization=False)
  182. self.sampler = Sampler(config.vocab_size)
  183. def forward(
  184. self,
  185. input_ids: torch.Tensor,
  186. positions: torch.Tensor,
  187. kv_caches: List[KVCache],
  188. input_metadata: InputMetadata,
  189. cache_events: Optional[List[torch.cuda.Event]],
  190. ) -> SamplerOutput:
  191. hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
  192. input_metadata, cache_events)
  193. next_tokens = self.sampler(self.embed_out.weight, hidden_states,
  194. input_metadata)
  195. return next_tokens
  196. _column_parallel_weights = [
  197. "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
  198. "dense_h_to_4h.bias"
  199. ]
  200. _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
  201. def load_weights(self,
  202. model_name_or_path: str,
  203. cache_dir: Optional[str] = None,
  204. load_format: str = "auto",
  205. revision: Optional[str] = None):
  206. tensor_model_parallel_rank = get_tensor_model_parallel_rank()
  207. state_dict = self.state_dict()
  208. for name, loaded_weight in hf_model_weights_iterator(
  209. model_name_or_path, cache_dir, load_format, revision):
  210. if ("attention.bias" in name or "attention.masked_bias" in name
  211. or "rotary_emb.inv_freq" in name):
  212. continue
  213. param = state_dict[name]
  214. if "query_key_value" in name:
  215. # NOTE: GPT-NeoX's fused QKV has the shape of
  216. # [num_heads * 3 * head_size, hidden_size], while the
  217. # required shape is [3 * num_heads * head_size, hidden_size].
  218. # Thus, we need weight conversion.
  219. shard_size = param.shape[0]
  220. loaded_weight = loaded_weight[
  221. shard_size * tensor_model_parallel_rank:shard_size *
  222. (tensor_model_parallel_rank + 1)]
  223. num_heads = self.config.num_attention_heads
  224. hidden_size = self.config.hidden_size
  225. head_size = hidden_size // num_heads
  226. if "query_key_value.weight" in name:
  227. loaded_weight = loaded_weight.view(-1, 3, head_size,
  228. hidden_size)
  229. loaded_weight = loaded_weight.transpose(0, 1)
  230. loaded_weight = loaded_weight.reshape(-1, hidden_size)
  231. elif "query_key_value.bias" in name:
  232. loaded_weight = loaded_weight.view(-1, 3, head_size)
  233. loaded_weight = loaded_weight.transpose(0, 1)
  234. loaded_weight = loaded_weight.reshape(-1)
  235. else:
  236. raise ValueError(f"Unexpected weight name: {name}")
  237. load_tensor_parallel_weights(param, loaded_weight, name,
  238. self._column_parallel_weights,
  239. self._row_parallel_weights,
  240. tensor_model_parallel_rank)