1
0

gpt_neox.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
  19. from typing import Iterable, List, Optional, Tuple
  20. import torch
  21. from torch import nn
  22. from transformers import GPTNeoXConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import get_act_fn
  28. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  32. from aphrodite.modeling.layers.rotary_embedding import get_rope
  33. from aphrodite.modeling.layers.sampler import Sampler
  34. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  35. ParallelLMHead, VocabParallelEmbedding)
  36. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.quantization.base_config import QuantizationConfig
  39. class GPTNeoXAttention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPTNeoXConfig,
  43. cache_config: Optional[CacheConfig] = None,
  44. quant_config: Optional[QuantizationConfig] = None,
  45. ):
  46. super().__init__()
  47. self.total_num_heads = config.num_attention_heads
  48. self.hidden_size = config.hidden_size
  49. self.head_size = self.hidden_size // self.total_num_heads
  50. self.bias = getattr(config, "attention_bias", True)
  51. tensor_model_parallel_world_size = (
  52. get_tensor_model_parallel_world_size())
  53. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  54. self.num_heads = (self.total_num_heads //
  55. tensor_model_parallel_world_size)
  56. self.query_key_value = QKVParallelLinear(
  57. config.hidden_size,
  58. self.head_size,
  59. self.total_num_heads,
  60. bias=self.bias,
  61. quant_config=quant_config,
  62. )
  63. self.dense = RowParallelLinear(
  64. config.hidden_size,
  65. config.hidden_size,
  66. bias=self.bias,
  67. quant_config=quant_config,
  68. )
  69. scaling = self.head_size**-0.5
  70. rotary_dim = int(self.head_size * config.rotary_pct)
  71. assert rotary_dim % 2 == 0
  72. rope_theta = getattr(config, "rope_theta", 10000)
  73. max_position_embeddings = getattr(config, "max_position_embeddings",
  74. 8192)
  75. self.rotary_emb = get_rope(
  76. self.head_size,
  77. rotary_dim=rotary_dim,
  78. max_position=max_position_embeddings,
  79. base=rope_theta,
  80. )
  81. self.attn = Attention(self.num_heads,
  82. self.head_size,
  83. scaling,
  84. cache_config=cache_config,
  85. quant_config=quant_config)
  86. def forward(
  87. self,
  88. position_ids: torch.Tensor,
  89. hidden_states: torch.Tensor,
  90. kv_cache: torch.Tensor,
  91. attn_metadata: AttentionMetadata,
  92. ) -> torch.Tensor:
  93. qkv, _ = self.query_key_value(hidden_states)
  94. q, k, v = qkv.chunk(chunks=3, dim=-1)
  95. q, k = self.rotary_emb(position_ids, q, k)
  96. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  97. output, _ = self.dense(attn_output)
  98. return output
  99. class GPTNeoXMLP(nn.Module):
  100. def __init__(
  101. self,
  102. config: GPTNeoXConfig,
  103. quant_config: Optional[QuantizationConfig] = None,
  104. ):
  105. super().__init__()
  106. self.dense_h_to_4h = ColumnParallelLinear(
  107. config.hidden_size,
  108. config.intermediate_size,
  109. quant_config=quant_config,
  110. )
  111. self.dense_4h_to_h = RowParallelLinear(
  112. config.intermediate_size,
  113. config.hidden_size,
  114. quant_config=quant_config,
  115. )
  116. self.act = get_act_fn(config.hidden_act, quant_config,
  117. config.intermediate_size)
  118. def forward(self, hidden_states):
  119. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  120. hidden_states = self.act(hidden_states)
  121. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  122. return hidden_states
  123. class GPTNeoXLayer(nn.Module):
  124. def __init__(
  125. self,
  126. config: GPTNeoXConfig,
  127. cache_config: Optional[CacheConfig] = None,
  128. quant_config: Optional[QuantizationConfig] = None,
  129. ):
  130. super().__init__()
  131. self.use_parallel_residual = config.use_parallel_residual
  132. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  133. eps=config.layer_norm_eps)
  134. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  135. eps=config.layer_norm_eps)
  136. self.attention = GPTNeoXAttention(config, cache_config, quant_config)
  137. self.mlp = GPTNeoXMLP(config, quant_config)
  138. def forward(
  139. self,
  140. position_ids: torch.Tensor,
  141. hidden_states: torch.Tensor,
  142. kv_cache: torch.Tensor,
  143. attn_metadata: AttentionMetadata,
  144. ) -> torch.Tensor:
  145. attn_input = self.input_layernorm(hidden_states)
  146. attn_output = self.attention(
  147. position_ids=position_ids,
  148. hidden_states=attn_input,
  149. kv_cache=kv_cache,
  150. attn_metadata=attn_metadata,
  151. )
  152. if self.use_parallel_residual:
  153. # pseudocode:
  154. # x = x + attn(ln1(x)) + mlp(ln2(x))
  155. mlp_input = self.post_attention_layernorm(hidden_states)
  156. mlp_output = self.mlp(mlp_input)
  157. hidden_states = mlp_output + attn_output + hidden_states
  158. else:
  159. # pseudocode:
  160. # x = x + attn(ln1(x))
  161. # x = x + mlp(ln2(x))
  162. attn_output = attn_output + hidden_states
  163. mlp_input = self.post_attention_layernorm(attn_output)
  164. mlp_output = self.mlp(mlp_input)
  165. hidden_states = mlp_output + attn_output
  166. return hidden_states
  167. class GPTNeoXModel(nn.Module):
  168. def __init__(
  169. self,
  170. config: GPTNeoXConfig,
  171. cache_config: Optional[CacheConfig] = None,
  172. quant_config: Optional[QuantizationConfig] = None,
  173. ):
  174. super().__init__()
  175. self.config = config
  176. self.embed_in = VocabParallelEmbedding(
  177. config.vocab_size,
  178. config.hidden_size,
  179. )
  180. self.layers = nn.ModuleList([
  181. GPTNeoXLayer(config, cache_config, quant_config)
  182. for _ in range(config.num_hidden_layers)
  183. ])
  184. self.final_layer_norm = nn.LayerNorm(config.hidden_size,
  185. eps=config.layer_norm_eps)
  186. def forward(
  187. self,
  188. input_ids: torch.Tensor,
  189. position_ids: torch.Tensor,
  190. kv_caches: List[torch.Tensor],
  191. attn_metadata: AttentionMetadata,
  192. ) -> torch.Tensor:
  193. hidden_states = self.embed_in(input_ids)
  194. for i in range(len(self.layers)):
  195. layer = self.layers[i]
  196. hidden_states = layer(
  197. position_ids,
  198. hidden_states,
  199. kv_caches[i],
  200. attn_metadata,
  201. )
  202. hidden_states = self.final_layer_norm(hidden_states)
  203. return hidden_states
  204. class GPTNeoXForCausalLM(nn.Module):
  205. def __init__(
  206. self,
  207. config,
  208. cache_config: Optional[CacheConfig] = None,
  209. quant_config: Optional[QuantizationConfig] = None,
  210. ):
  211. super().__init__()
  212. self.config = config
  213. self.quant_config = quant_config
  214. self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
  215. self.embed_out = ParallelLMHead(
  216. config.vocab_size,
  217. config.hidden_size,
  218. quant_config=quant_config,
  219. )
  220. self.logits_processor = LogitsProcessor(config.vocab_size)
  221. self.sampler = Sampler()
  222. def forward(
  223. self,
  224. input_ids: torch.Tensor,
  225. positions: torch.Tensor,
  226. kv_caches: List[torch.Tensor],
  227. attn_metadata: AttentionMetadata,
  228. intermediate_tensors: Optional[IntermediateTensors] = None,
  229. ) -> torch.Tensor:
  230. hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
  231. attn_metadata)
  232. return hidden_states
  233. def compute_logits(
  234. self,
  235. hidden_states: torch.Tensor,
  236. sampling_metadata: SamplingMetadata,
  237. ) -> Optional[torch.Tensor]:
  238. logits = self.logits_processor(self.embed_out, hidden_states,
  239. sampling_metadata)
  240. return logits
  241. def sample(
  242. self,
  243. logits: torch.Tensor,
  244. sampling_metadata: SamplingMetadata,
  245. ) -> Optional[SamplerOutput]:
  246. next_tokens = self.sampler(logits, sampling_metadata)
  247. return next_tokens
  248. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  249. params_dict = dict(self.named_parameters())
  250. for name, loaded_weight in weights:
  251. if ("attention.bias" in name or "attention.masked_bias" in name
  252. or "rotary_emb.inv_freq" in name):
  253. continue
  254. if ("rotary_emb.cos_cached" in name
  255. or "rotary_emb.sin_cached" in name):
  256. # Models trained using OpenRLHF may include
  257. # these tensors in the checkpoint. Skip them.
  258. continue
  259. param = params_dict[name]
  260. if "query_key_value" in name:
  261. # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
  262. # (num_heads * 3 * head_size), while the
  263. # required shape is (3 * num_heads * head_size).
  264. # Thus, we need weight conversion.
  265. output_dim = getattr(param, "output_dim", None)
  266. num_heads = self.config.num_attention_heads
  267. if output_dim is not None:
  268. loaded_weight_shape = loaded_weight.shape
  269. loaded_weight = loaded_weight.view(
  270. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  271. loaded_weight_shape[output_dim + 1:])
  272. loaded_weight = loaded_weight.transpose(
  273. output_dim, output_dim + 1)
  274. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  275. weight_loader = getattr(param, "weight_loader",
  276. default_weight_loader)
  277. weight_loader(param, loaded_weight)