1
0

gpt_neox.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
  20. from typing import Iterable, List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import GPTNeoXConfig
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.sequence import SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import get_act_fn
  28. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  29. LinearMethodBase,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. ParallelLMHead, VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. class GPTNeoXAttention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPTNeoXConfig,
  43. linear_method: Optional[LinearMethodBase] = None,
  44. ):
  45. super().__init__()
  46. self.total_num_heads = config.num_attention_heads
  47. self.hidden_size = config.hidden_size
  48. self.head_size = self.hidden_size // self.total_num_heads
  49. self.bias = getattr(config, "attention_bias", True)
  50. tensor_model_parallel_world_size = (
  51. get_tensor_model_parallel_world_size())
  52. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  53. self.num_heads = (self.total_num_heads //
  54. tensor_model_parallel_world_size)
  55. self.query_key_value = QKVParallelLinear(
  56. config.hidden_size,
  57. self.head_size,
  58. self.total_num_heads,
  59. bias=self.bias,
  60. linear_method=linear_method,
  61. )
  62. self.dense = RowParallelLinear(
  63. config.hidden_size,
  64. config.hidden_size,
  65. bias=self.bias,
  66. linear_method=linear_method,
  67. )
  68. scaling = self.head_size**-0.5
  69. rotary_dim = int(self.head_size * config.rotary_pct)
  70. assert rotary_dim % 2 == 0
  71. rope_theta = getattr(config, "rope_theta", 10000)
  72. max_position_embeddings = getattr(config, "max_position_embeddings",
  73. 8192)
  74. self.rotary_emb = get_rope(
  75. self.head_size,
  76. rotary_dim=rotary_dim,
  77. max_position=max_position_embeddings,
  78. base=rope_theta,
  79. )
  80. self.attn = Attention(self.num_heads, self.head_size, scaling)
  81. def forward(
  82. self,
  83. position_ids: torch.Tensor,
  84. hidden_states: torch.Tensor,
  85. kv_cache: torch.Tensor,
  86. attn_metadata: AttentionMetadata,
  87. ) -> torch.Tensor:
  88. qkv, _ = self.query_key_value(hidden_states)
  89. q, k, v = qkv.chunk(chunks=3, dim=-1)
  90. q, k = self.rotary_emb(position_ids, q, k)
  91. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  92. output, _ = self.dense(attn_output)
  93. return output
  94. class GPTNeoXMLP(nn.Module):
  95. def __init__(
  96. self,
  97. config: GPTNeoXConfig,
  98. linear_method: Optional[LinearMethodBase] = None,
  99. ):
  100. super().__init__()
  101. self.dense_h_to_4h = ColumnParallelLinear(
  102. config.hidden_size,
  103. config.intermediate_size,
  104. linear_method=linear_method,
  105. )
  106. self.dense_4h_to_h = RowParallelLinear(
  107. config.intermediate_size,
  108. config.hidden_size,
  109. linear_method=linear_method,
  110. )
  111. quant_config = getattr(linear_method, "quant_config", None)
  112. self.act = get_act_fn(config.hidden_act, quant_config,
  113. config.intermediate_size)
  114. def forward(self, hidden_states):
  115. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  116. hidden_states = self.act(hidden_states)
  117. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  118. return hidden_states
  119. class GPTNeoXLayer(nn.Module):
  120. def __init__(
  121. self,
  122. config: GPTNeoXConfig,
  123. linear_method: Optional[LinearMethodBase] = None,
  124. ):
  125. super().__init__()
  126. self.use_parallel_residual = config.use_parallel_residual
  127. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  128. eps=config.layer_norm_eps)
  129. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  130. eps=config.layer_norm_eps)
  131. self.attention = GPTNeoXAttention(config, linear_method)
  132. self.mlp = GPTNeoXMLP(config, linear_method)
  133. def forward(
  134. self,
  135. position_ids: torch.Tensor,
  136. hidden_states: torch.Tensor,
  137. kv_cache: torch.Tensor,
  138. attn_metadata: AttentionMetadata,
  139. ) -> torch.Tensor:
  140. attn_input = self.input_layernorm(hidden_states)
  141. attn_output = self.attention(
  142. position_ids=position_ids,
  143. hidden_states=attn_input,
  144. kv_cache=kv_cache,
  145. attn_metadata=attn_metadata,
  146. )
  147. if self.use_parallel_residual:
  148. # pseudocode:
  149. # x = x + attn(ln1(x)) + mlp(ln2(x))
  150. mlp_input = self.post_attention_layernorm(hidden_states)
  151. mlp_output = self.mlp(mlp_input)
  152. hidden_states = mlp_output + attn_output + hidden_states
  153. else:
  154. # pseudocode:
  155. # x = x + attn(ln1(x))
  156. # x = x + mlp(ln2(x))
  157. attn_output = attn_output + hidden_states
  158. mlp_input = self.post_attention_layernorm(attn_output)
  159. mlp_output = self.mlp(mlp_input)
  160. hidden_states = mlp_output + attn_output
  161. return hidden_states
  162. class GPTNeoXModel(nn.Module):
  163. def __init__(
  164. self,
  165. config: GPTNeoXConfig,
  166. linear_method: Optional[LinearMethodBase] = None,
  167. ):
  168. super().__init__()
  169. self.config = config
  170. self.embed_in = VocabParallelEmbedding(
  171. config.vocab_size,
  172. config.hidden_size,
  173. )
  174. self.layers = nn.ModuleList([
  175. GPTNeoXLayer(config, linear_method)
  176. for _ in range(config.num_hidden_layers)
  177. ])
  178. self.final_layer_norm = nn.LayerNorm(config.hidden_size,
  179. eps=config.layer_norm_eps)
  180. def forward(
  181. self,
  182. input_ids: torch.Tensor,
  183. position_ids: torch.Tensor,
  184. kv_caches: List[torch.Tensor],
  185. attn_metadata: AttentionMetadata,
  186. ) -> torch.Tensor:
  187. hidden_states = self.embed_in(input_ids)
  188. for i in range(len(self.layers)):
  189. layer = self.layers[i]
  190. hidden_states = layer(
  191. position_ids,
  192. hidden_states,
  193. kv_caches[i],
  194. attn_metadata,
  195. )
  196. hidden_states = self.final_layer_norm(hidden_states)
  197. return hidden_states
  198. class GPTNeoXForCausalLM(nn.Module):
  199. def __init__(
  200. self,
  201. config,
  202. linear_method: Optional[LinearMethodBase] = None,
  203. ):
  204. super().__init__()
  205. self.config = config
  206. self.linear_method = linear_method
  207. self.gpt_neox = GPTNeoXModel(config, linear_method)
  208. self.embed_out = ParallelLMHead(
  209. config.vocab_size,
  210. config.hidden_size,
  211. )
  212. self.logits_processor = LogitsProcessor(config.vocab_size)
  213. self.sampler = Sampler()
  214. def forward(
  215. self,
  216. input_ids: torch.Tensor,
  217. positions: torch.Tensor,
  218. kv_caches: List[torch.Tensor],
  219. attn_metadata: AttentionMetadata,
  220. ) -> torch.Tensor:
  221. hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
  222. attn_metadata)
  223. return hidden_states
  224. def compute_logits(self, hidden_states: torch.Tensor,
  225. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  226. logits = self.logits_processor(self.embed_out.weight, hidden_states,
  227. sampling_metadata)
  228. return logits
  229. def sample(
  230. self,
  231. logits: torch.Tensor,
  232. sampling_metadata: SamplingMetadata,
  233. ) -> Optional[SamplerOutput]:
  234. next_tokens = self.sampler(logits, sampling_metadata)
  235. return next_tokens
  236. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  237. params_dict = dict(self.named_parameters())
  238. for name, loaded_weight in weights:
  239. if ("attention.bias" in name or "attention.masked_bias" in name
  240. or "rotary_emb.inv_freq" in name):
  241. continue
  242. if ("rotary_emb.cos_cached" in name
  243. or "rotary_emb.sin_cached" in name):
  244. # Models trained using OpenRLHF may include
  245. # these tensors in the checkpoint. Skip them.
  246. continue
  247. param = params_dict[name]
  248. if "query_key_value" in name:
  249. # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
  250. # (num_heads * 3 * head_size), while the
  251. # required shape is (3 * num_heads * head_size).
  252. # Thus, we need weight conversion.
  253. output_dim = getattr(param, "output_dim", None)
  254. num_heads = self.config.num_attention_heads
  255. if output_dim is not None:
  256. loaded_weight_shape = loaded_weight.shape
  257. loaded_weight = loaded_weight.view(
  258. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  259. loaded_weight_shape[output_dim + 1:])
  260. loaded_weight = loaded_weight.transpose(
  261. output_dim, output_dim + 1)
  262. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  263. weight_loader = getattr(param, "weight_loader",
  264. default_weight_loader)
  265. weight_loader(param, loaded_weight)