1
0

gpt_neox.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
  19. from typing import Iterable, List, Optional, Tuple
  20. import torch
  21. from torch import nn
  22. from transformers import GPTNeoXConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.sequence import SamplerOutput
  25. from aphrodite.distributed import get_tensor_model_parallel_world_size
  26. from aphrodite.modeling.layers.activation import get_act_fn
  27. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  28. QKVParallelLinear,
  29. RowParallelLinear)
  30. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  31. from aphrodite.modeling.layers.rotary_embedding import get_rope
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  34. ParallelLMHead, VocabParallelEmbedding)
  35. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  36. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  37. from aphrodite.quantization.base_config import QuantizationConfig
  38. class GPTNeoXAttention(nn.Module):
  39. def __init__(
  40. self,
  41. config: GPTNeoXConfig,
  42. quant_config: Optional[QuantizationConfig] = None,
  43. ):
  44. super().__init__()
  45. self.total_num_heads = config.num_attention_heads
  46. self.hidden_size = config.hidden_size
  47. self.head_size = self.hidden_size // self.total_num_heads
  48. self.bias = getattr(config, "attention_bias", True)
  49. tensor_model_parallel_world_size = (
  50. get_tensor_model_parallel_world_size())
  51. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  52. self.num_heads = (self.total_num_heads //
  53. tensor_model_parallel_world_size)
  54. self.query_key_value = QKVParallelLinear(
  55. config.hidden_size,
  56. self.head_size,
  57. self.total_num_heads,
  58. bias=self.bias,
  59. quant_config=quant_config,
  60. )
  61. self.dense = RowParallelLinear(
  62. config.hidden_size,
  63. config.hidden_size,
  64. bias=self.bias,
  65. quant_config=quant_config,
  66. )
  67. scaling = self.head_size**-0.5
  68. rotary_dim = int(self.head_size * config.rotary_pct)
  69. assert rotary_dim % 2 == 0
  70. rope_theta = getattr(config, "rope_theta", 10000)
  71. max_position_embeddings = getattr(config, "max_position_embeddings",
  72. 8192)
  73. self.rotary_emb = get_rope(
  74. self.head_size,
  75. rotary_dim=rotary_dim,
  76. max_position=max_position_embeddings,
  77. base=rope_theta,
  78. )
  79. self.attn = Attention(self.num_heads, self.head_size, scaling)
  80. def forward(
  81. self,
  82. position_ids: torch.Tensor,
  83. hidden_states: torch.Tensor,
  84. kv_cache: torch.Tensor,
  85. attn_metadata: AttentionMetadata,
  86. ) -> torch.Tensor:
  87. qkv, _ = self.query_key_value(hidden_states)
  88. q, k, v = qkv.chunk(chunks=3, dim=-1)
  89. q, k = self.rotary_emb(position_ids, q, k)
  90. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  91. output, _ = self.dense(attn_output)
  92. return output
  93. class GPTNeoXMLP(nn.Module):
  94. def __init__(
  95. self,
  96. config: GPTNeoXConfig,
  97. quant_config: Optional[QuantizationConfig] = None,
  98. ):
  99. super().__init__()
  100. self.dense_h_to_4h = ColumnParallelLinear(
  101. config.hidden_size,
  102. config.intermediate_size,
  103. quant_config=quant_config,
  104. )
  105. self.dense_4h_to_h = RowParallelLinear(
  106. config.intermediate_size,
  107. config.hidden_size,
  108. quant_config=quant_config,
  109. )
  110. quant_config = getattr(quant_config, "quant_config", None)
  111. self.act = get_act_fn(config.hidden_act, quant_config,
  112. config.intermediate_size)
  113. def forward(self, hidden_states):
  114. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  115. hidden_states = self.act(hidden_states)
  116. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  117. return hidden_states
  118. class GPTNeoXLayer(nn.Module):
  119. def __init__(
  120. self,
  121. config: GPTNeoXConfig,
  122. quant_config: Optional[QuantizationConfig] = None,
  123. ):
  124. super().__init__()
  125. self.use_parallel_residual = config.use_parallel_residual
  126. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  127. eps=config.layer_norm_eps)
  128. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  129. eps=config.layer_norm_eps)
  130. self.attention = GPTNeoXAttention(config, quant_config)
  131. self.mlp = GPTNeoXMLP(config, quant_config)
  132. def forward(
  133. self,
  134. position_ids: torch.Tensor,
  135. hidden_states: torch.Tensor,
  136. kv_cache: torch.Tensor,
  137. attn_metadata: AttentionMetadata,
  138. ) -> torch.Tensor:
  139. attn_input = self.input_layernorm(hidden_states)
  140. attn_output = self.attention(
  141. position_ids=position_ids,
  142. hidden_states=attn_input,
  143. kv_cache=kv_cache,
  144. attn_metadata=attn_metadata,
  145. )
  146. if self.use_parallel_residual:
  147. # pseudocode:
  148. # x = x + attn(ln1(x)) + mlp(ln2(x))
  149. mlp_input = self.post_attention_layernorm(hidden_states)
  150. mlp_output = self.mlp(mlp_input)
  151. hidden_states = mlp_output + attn_output + hidden_states
  152. else:
  153. # pseudocode:
  154. # x = x + attn(ln1(x))
  155. # x = x + mlp(ln2(x))
  156. attn_output = attn_output + hidden_states
  157. mlp_input = self.post_attention_layernorm(attn_output)
  158. mlp_output = self.mlp(mlp_input)
  159. hidden_states = mlp_output + attn_output
  160. return hidden_states
  161. class GPTNeoXModel(nn.Module):
  162. def __init__(
  163. self,
  164. config: GPTNeoXConfig,
  165. quant_config: Optional[QuantizationConfig] = None,
  166. ):
  167. super().__init__()
  168. self.config = config
  169. self.embed_in = VocabParallelEmbedding(
  170. config.vocab_size,
  171. config.hidden_size,
  172. )
  173. self.layers = nn.ModuleList([
  174. GPTNeoXLayer(config, quant_config)
  175. for _ in range(config.num_hidden_layers)
  176. ])
  177. self.final_layer_norm = nn.LayerNorm(config.hidden_size,
  178. eps=config.layer_norm_eps)
  179. def forward(
  180. self,
  181. input_ids: torch.Tensor,
  182. position_ids: torch.Tensor,
  183. kv_caches: List[torch.Tensor],
  184. attn_metadata: AttentionMetadata,
  185. ) -> torch.Tensor:
  186. hidden_states = self.embed_in(input_ids)
  187. for i in range(len(self.layers)):
  188. layer = self.layers[i]
  189. hidden_states = layer(
  190. position_ids,
  191. hidden_states,
  192. kv_caches[i],
  193. attn_metadata,
  194. )
  195. hidden_states = self.final_layer_norm(hidden_states)
  196. return hidden_states
  197. class GPTNeoXForCausalLM(nn.Module):
  198. def __init__(
  199. self,
  200. config,
  201. quant_config: Optional[QuantizationConfig] = None,
  202. ):
  203. super().__init__()
  204. self.config = config
  205. self.quant_config = quant_config
  206. self.gpt_neox = GPTNeoXModel(config, quant_config)
  207. self.embed_out = ParallelLMHead(
  208. config.vocab_size,
  209. config.hidden_size,
  210. )
  211. self.logits_processor = LogitsProcessor(config.vocab_size)
  212. self.sampler = Sampler()
  213. def forward(
  214. self,
  215. input_ids: torch.Tensor,
  216. positions: torch.Tensor,
  217. kv_caches: List[torch.Tensor],
  218. attn_metadata: AttentionMetadata,
  219. ) -> torch.Tensor:
  220. hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
  221. attn_metadata)
  222. return hidden_states
  223. def compute_logits(self, hidden_states: torch.Tensor,
  224. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  225. logits = self.logits_processor(self.embed_out.weight, hidden_states,
  226. sampling_metadata)
  227. return logits
  228. def sample(
  229. self,
  230. logits: torch.Tensor,
  231. sampling_metadata: SamplingMetadata,
  232. ) -> Optional[SamplerOutput]:
  233. next_tokens = self.sampler(logits, sampling_metadata)
  234. return next_tokens
  235. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  236. params_dict = dict(self.named_parameters())
  237. for name, loaded_weight in weights:
  238. if ("attention.bias" in name or "attention.masked_bias" in name
  239. or "rotary_emb.inv_freq" in name):
  240. continue
  241. if ("rotary_emb.cos_cached" in name
  242. or "rotary_emb.sin_cached" in name):
  243. # Models trained using OpenRLHF may include
  244. # these tensors in the checkpoint. Skip them.
  245. continue
  246. param = params_dict[name]
  247. if "query_key_value" in name:
  248. # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
  249. # (num_heads * 3 * head_size), while the
  250. # required shape is (3 * num_heads * head_size).
  251. # Thus, we need weight conversion.
  252. output_dim = getattr(param, "output_dim", None)
  253. num_heads = self.config.num_attention_heads
  254. if output_dim is not None:
  255. loaded_weight_shape = loaded_weight.shape
  256. loaded_weight = loaded_weight.view(
  257. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  258. loaded_weight_shape[output_dim + 1:])
  259. loaded_weight = loaded_weight.transpose(
  260. output_dim, output_dim + 1)
  261. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  262. weight_loader = getattr(param, "weight_loader",
  263. default_weight_loader)
  264. weight_loader(param, loaded_weight)