1
0

gpt_neox.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
  20. from typing import List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import GPTNeoXConfig
  24. from aphrodite.modeling.metadata import InputMetadata
  25. from aphrodite.modeling.layers.activation import get_act_fn
  26. from aphrodite.modeling.layers.attention import PagedAttention
  27. from aphrodite.modeling.layers.linear import (
  28. ColumnParallelLinear,
  29. LinearMethodBase,
  30. QKVParallelLinear,
  31. RowParallelLinear,
  32. )
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding,
  37. ParallelLMHead,
  38. )
  39. from aphrodite.modeling.megatron.parallel_state import (
  40. get_tensor_model_parallel_world_size, )
  41. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  42. from aphrodite.modeling.hf_downloader import (
  43. default_weight_loader,
  44. hf_model_weights_iterator,
  45. )
  46. from aphrodite.common.sequence import SamplerOutput
  47. KVCache = Tuple[torch.Tensor, torch.Tensor]
  48. class GPTNeoXAttention(nn.Module):
  49. def __init__(
  50. self,
  51. config: GPTNeoXConfig,
  52. linear_method: Optional[LinearMethodBase] = None,
  53. ):
  54. super().__init__()
  55. self.total_num_heads = config.num_attention_heads
  56. self.hidden_size = config.hidden_size
  57. self.head_size = self.hidden_size // self.total_num_heads
  58. self.bias = getattr(config, "attention_bias", True)
  59. tensor_model_parallel_world_size = (
  60. get_tensor_model_parallel_world_size())
  61. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  62. self.num_heads = (self.total_num_heads //
  63. tensor_model_parallel_world_size)
  64. self.query_key_value = QKVParallelLinear(
  65. config.hidden_size,
  66. self.head_size,
  67. self.total_num_heads,
  68. bias=self.bias,
  69. linear_method=linear_method,
  70. )
  71. self.dense = RowParallelLinear(
  72. config.hidden_size,
  73. config.hidden_size,
  74. bias=self.bias,
  75. linear_method=linear_method,
  76. )
  77. scaling = self.head_size**-0.5
  78. rotary_dim = int(self.head_size * config.rotary_pct)
  79. assert rotary_dim % 2 == 0
  80. rope_theta = getattr(config, "rope_theta", 10000)
  81. max_position_embeddings = getattr(config, "max_position_embeddings",
  82. 8192)
  83. is_neox_style = (True if linear_method is None
  84. or linear_method.quant_config.rope_style() is None
  85. else linear_method.quant_config.rope_style())
  86. self.rotary_emb = get_rope(
  87. self.head_size,
  88. rotary_dim=rotary_dim,
  89. max_position=max_position_embeddings,
  90. base=rope_theta,
  91. is_neox_style=is_neox_style,
  92. )
  93. self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
  94. def forward(
  95. self,
  96. position_ids: torch.Tensor,
  97. hidden_states: torch.Tensor,
  98. kv_cache: KVCache,
  99. input_metadata: InputMetadata,
  100. ) -> torch.Tensor:
  101. qkv, _ = self.query_key_value(hidden_states)
  102. q, k, v = qkv.chunk(chunks=3, dim=-1)
  103. q, k = self.rotary_emb(position_ids, q, k)
  104. k_cache, v_cache = kv_cache
  105. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  106. output, _ = self.dense(attn_output)
  107. return output
  108. class GPTNeoXMLP(nn.Module):
  109. def __init__(
  110. self,
  111. config: GPTNeoXConfig,
  112. linear_method: Optional[LinearMethodBase] = None,
  113. ):
  114. super().__init__()
  115. self.dense_h_to_4h = ColumnParallelLinear(
  116. config.hidden_size,
  117. config.intermediate_size,
  118. linear_method=linear_method,
  119. )
  120. self.dense_4h_to_h = RowParallelLinear(
  121. config.intermediate_size,
  122. config.hidden_size,
  123. linear_method=linear_method,
  124. )
  125. quant_config = getattr(linear_method, "quant_config", None)
  126. self.act = get_act_fn(config.hidden_act, quant_config,
  127. config.intermediate_size)
  128. def forward(self, hidden_states):
  129. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  130. hidden_states = self.act(hidden_states)
  131. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  132. return hidden_states
  133. class GPTNeoXLayer(nn.Module):
  134. def __init__(
  135. self,
  136. config: GPTNeoXConfig,
  137. linear_method: Optional[LinearMethodBase] = None,
  138. ):
  139. super().__init__()
  140. self.use_parallel_residual = config.use_parallel_residual
  141. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  142. eps=config.layer_norm_eps)
  143. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  144. eps=config.layer_norm_eps)
  145. self.attention = GPTNeoXAttention(config, linear_method)
  146. self.mlp = GPTNeoXMLP(config, linear_method)
  147. def forward(
  148. self,
  149. position_ids: torch.Tensor,
  150. hidden_states: torch.Tensor,
  151. kv_cache: KVCache,
  152. input_metadata: InputMetadata,
  153. ) -> torch.Tensor:
  154. attn_input = self.input_layernorm(hidden_states)
  155. attn_output = self.attention(
  156. position_ids=position_ids,
  157. hidden_states=attn_input,
  158. kv_cache=kv_cache,
  159. input_metadata=input_metadata,
  160. )
  161. if self.use_parallel_residual:
  162. # pseudocode:
  163. # x = x + attn(ln1(x)) + mlp(ln2(x))
  164. mlp_input = self.post_attention_layernorm(hidden_states)
  165. mlp_output = self.mlp(mlp_input)
  166. hidden_states = mlp_output + attn_output + hidden_states
  167. else:
  168. # pseudocode:
  169. # x = x + attn(ln1(x))
  170. # x = x + mlp(ln2(x))
  171. attn_output = attn_output + hidden_states
  172. mlp_input = self.post_attention_layernorm(attn_output)
  173. mlp_output = self.mlp(mlp_input)
  174. hidden_states = mlp_output + attn_output
  175. return hidden_states
  176. class GPTNeoXModel(nn.Module):
  177. def __init__(
  178. self,
  179. config: GPTNeoXConfig,
  180. linear_method: Optional[LinearMethodBase] = None,
  181. ):
  182. super().__init__()
  183. self.config = config
  184. self.embed_in = VocabParallelEmbedding(config.vocab_size,
  185. config.hidden_size,
  186. linear_method=linear_method)
  187. self.layers = nn.ModuleList([
  188. GPTNeoXLayer(config, linear_method)
  189. for _ in range(config.num_hidden_layers)
  190. ])
  191. self.final_layer_norm = nn.LayerNorm(config.hidden_size,
  192. eps=config.layer_norm_eps)
  193. def forward(
  194. self,
  195. input_ids: torch.Tensor,
  196. position_ids: torch.Tensor,
  197. kv_caches: List[KVCache],
  198. input_metadata: InputMetadata,
  199. ) -> torch.Tensor:
  200. hidden_states = self.embed_in(input_ids)
  201. for i in range(len(self.layers)):
  202. layer = self.layers[i]
  203. hidden_states = layer(
  204. position_ids,
  205. hidden_states,
  206. kv_caches[i],
  207. input_metadata,
  208. )
  209. hidden_states = self.final_layer_norm(hidden_states)
  210. return hidden_states
  211. class GPTNeoXForCausalLM(nn.Module):
  212. def __init__(
  213. self,
  214. config,
  215. linear_method: Optional[LinearMethodBase] = None,
  216. ):
  217. super().__init__()
  218. self.config = config
  219. self.linear_method = linear_method
  220. self.gpt_neox = GPTNeoXModel(config, linear_method)
  221. self.embed_out = ParallelLMHead(config.vocab_size,
  222. config.hidden_size,
  223. linear_method=linear_method)
  224. self.sampler = Sampler(config.vocab_size)
  225. self.quant_sampler = QuantSampler(config.vocab_size)
  226. def forward(
  227. self,
  228. input_ids: torch.Tensor,
  229. positions: torch.Tensor,
  230. kv_caches: List[KVCache],
  231. input_metadata: InputMetadata,
  232. ) -> torch.Tensor:
  233. hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
  234. input_metadata)
  235. return hidden_states
  236. def sample(
  237. self,
  238. hidden_states: torch.Tensor,
  239. sampling_metadata: SamplingMetadata,
  240. ) -> Optional[SamplerOutput]:
  241. if (self.linear_method is not None
  242. and not self.linear_method.quant_config.merge_weight()):
  243. next_tokens = self.quant_sampler(self.embed_out(hidden_states),
  244. sampling_metadata)
  245. else:
  246. next_tokens = self.sampler(self.embed_out.weight, hidden_states,
  247. sampling_metadata)
  248. return next_tokens
  249. def load_weights(
  250. self,
  251. model_name_or_path: str,
  252. cache_dir: Optional[str] = None,
  253. load_format: str = "auto",
  254. revision: Optional[str] = None,
  255. ):
  256. params_dict = dict(self.named_parameters())
  257. for name, loaded_weight in hf_model_weights_iterator(
  258. model_name_or_path, cache_dir, load_format, revision,
  259. self.config):
  260. if ("attention.bias" in name or "attention.masked_bias" in name
  261. or "rotary_emb.inv_freq" in name):
  262. continue
  263. param = params_dict[name]
  264. if "query_key_value" in name:
  265. # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
  266. # (num_heads * 3 * head_size), while the
  267. # required shape is (3 * num_heads * head_size).
  268. # Thus, we need weight conversion.
  269. output_dim = getattr(param, "output_dim", None)
  270. num_heads = self.config.num_attention_heads
  271. if output_dim is not None:
  272. loaded_weight_shape = loaded_weight.shape
  273. loaded_weight = loaded_weight.view(
  274. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  275. loaded_weight_shape[output_dim + 1:])
  276. loaded_weight = loaded_weight.transpose(
  277. output_dim, output_dim + 1)
  278. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  279. weight_loader = getattr(param, "weight_loader",
  280. default_weight_loader)
  281. weight_loader(param, loaded_weight)