gpt_neox.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
  20. from typing import List, Optional
  21. import torch
  22. from torch import nn
  23. from transformers import GPTNeoXConfig
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.modeling.layers.activation import get_act_fn
  26. from aphrodite.modeling.layers.linear import (
  27. ColumnParallelLinear,
  28. LinearMethodBase,
  29. QKVParallelLinear,
  30. RowParallelLinear,
  31. )
  32. from aphrodite.modeling.layers.rotary_embedding import get_rope
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding,
  37. ParallelLMHead,
  38. )
  39. from aphrodite.distributed import (
  40. get_tensor_model_parallel_world_size, )
  41. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  42. from aphrodite.modeling.hf_downloader import (
  43. default_weight_loader,
  44. hf_model_weights_iterator,
  45. )
  46. from aphrodite.common.sequence import SamplerOutput
  class GPTNeoXAttention(nn.Module):
      """GPT-NeoX self-attention with a fused QKV projection and rotary embeddings.

      Q, K and V are produced by a single ``query_key_value`` projection,
      Q and K are rotated by ``rotary_emb``, attention is computed by
      ``attn``, and the result is projected back by ``dense``. Attention
      heads are divided evenly across tensor-parallel ranks.

      Args:
          config: HuggingFace GPT-NeoX configuration.
          linear_method: Optional (quantized) linear method forwarded to the
              parallel linear layers. When it exposes a ``quant_config`` whose
              ``rope_style()`` is not ``None``, that value selects the rotary
              style; otherwise NeoX-style rotary embeddings are used.
      """

      def __init__(
          self,
          config: GPTNeoXConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.total_num_heads = config.num_attention_heads
          self.hidden_size = config.hidden_size
          self.head_size = self.hidden_size // self.total_num_heads
          # Configs without an `attention_bias` field default to using biases.
          self.bias = getattr(config, "attention_bias", True)

          tensor_model_parallel_world_size = (
              get_tensor_model_parallel_world_size())
          assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
              f"num_attention_heads ({self.total_num_heads}) must be divisible "
              f"by the tensor parallel size "
              f"({tensor_model_parallel_world_size})")
          # Number of attention heads owned by this tensor-parallel rank.
          self.num_heads = (self.total_num_heads //
                            tensor_model_parallel_world_size)

          self.query_key_value = QKVParallelLinear(
              config.hidden_size,
              self.head_size,
              self.total_num_heads,
              bias=self.bias,
              linear_method=linear_method,
          )
          self.dense = RowParallelLinear(
              config.hidden_size,
              config.hidden_size,
              bias=self.bias,
              linear_method=linear_method,
          )

          scaling = self.head_size**-0.5
          # `rotary_pct` of the head size is handed to get_rope as rotary_dim.
          rotary_dim = int(self.head_size * config.rotary_pct)
          assert rotary_dim % 2 == 0, (
              f"rotary_dim ({rotary_dim}) must be even")
          rope_theta = getattr(config, "rope_theta", 10000)
          max_position_embeddings = getattr(config, "max_position_embeddings",
                                            8192)
          # Default to NeoX-style rotary embeddings unless the quantization
          # config overrides it. Look `quant_config` up defensively (as
          # GPTNeoXMLP does) and query `rope_style()` only once.
          quant_config = getattr(linear_method, "quant_config", None)
          rope_style = (quant_config.rope_style()
                        if quant_config is not None else None)
          is_neox_style = True if rope_style is None else rope_style
          self.rotary_emb = get_rope(
              self.head_size,
              rotary_dim=rotary_dim,
              max_position=max_position_embeddings,
              base=rope_theta,
              is_neox_style=is_neox_style,
          )
          self.attn = Attention(self.num_heads, self.head_size, scaling)

      def forward(
          self,
          position_ids: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Compute self-attention for one layer.

          The fused projection output is split into three equal chunks
          (Q, K, V) along the last dimension. Q and K are rotated using
          ``position_ids``, attention runs against ``kv_cache`` and
          ``attn_metadata``, and the result is projected by ``dense``.
          ``dense``'s second return value is discarded.
          """
          qkv, _ = self.query_key_value(hidden_states)
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          q, k = self.rotary_emb(position_ids, q, k)
          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          output, _ = self.dense(attn_output)
          return output
  class GPTNeoXMLP(nn.Module):
      """GPT-NeoX feed-forward sub-layer: up-projection, activation, down-projection.

      ``dense_h_to_4h`` (column-parallel) maps ``hidden_size`` to
      ``intermediate_size``; ``dense_4h_to_h`` (row-parallel) maps it back.
      The activation is built from ``config.hidden_act`` together with the
      linear method's ``quant_config`` (or ``None`` when unavailable).
      """

      def __init__(
          self,
          config: GPTNeoXConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          hidden_dim = config.hidden_size
          inner_dim = config.intermediate_size

          self.dense_h_to_4h = ColumnParallelLinear(
              hidden_dim, inner_dim, linear_method=linear_method)
          self.dense_4h_to_h = RowParallelLinear(
              inner_dim, hidden_dim, linear_method=linear_method)

          # Unquantized runs pass linear_method=None, so fall back to None.
          self.act = get_act_fn(
              config.hidden_act,
              getattr(linear_method, "quant_config", None),
              inner_dim,
          )

      def forward(self, hidden_states):
          """Apply up-projection -> activation -> down-projection.

          The parallel linear layers return a pair; only the first element
          is used.
          """
          projected, _ = self.dense_h_to_4h(hidden_states)
          reduced, _ = self.dense_4h_to_h(self.act(projected))
          return reduced
  class GPTNeoXLayer(nn.Module):
      """A single GPT-NeoX transformer block.

      The residual layout is chosen by ``config.use_parallel_residual``:

          parallel:    x = x + attn(ln1(x)) + mlp(ln2(x))
          sequential:  x = x + attn(ln1(x))
                       x = x + mlp(ln2(x))
      """

      def __init__(
          self,
          config: GPTNeoXConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.use_parallel_residual = config.use_parallel_residual
          width = config.hidden_size
          eps = config.layer_norm_eps
          self.input_layernorm = nn.LayerNorm(width, eps=eps)
          self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
          self.attention = GPTNeoXAttention(config, linear_method)
          self.mlp = GPTNeoXMLP(config, linear_method)

      def forward(
          self,
          position_ids: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run attention and MLP with residual connections; return the new hidden states."""
          normed = self.input_layernorm(hidden_states)
          attn_output = self.attention(
              position_ids=position_ids,
              hidden_states=normed,
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )

          if not self.use_parallel_residual:
              # Sequential: the MLP sees the post-attention residual stream.
              residual = attn_output + hidden_states
              mlp_output = self.mlp(self.post_attention_layernorm(residual))
              return mlp_output + residual

          # Parallel: both branches read the same block input.
          mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
          return mlp_output + attn_output + hidden_states
  class GPTNeoXModel(nn.Module):
      """GPT-NeoX decoder trunk.

      Consists of the token embedding ``embed_in``, ``num_hidden_layers``
      GPTNeoXLayer blocks, and ``final_layer_norm``. ``forward`` returns the
      normalized hidden states (no LM head is applied here).
      """

      def __init__(
          self,
          config: GPTNeoXConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.embed_in = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
              linear_method=linear_method,
          )
          blocks = [
              GPTNeoXLayer(config, linear_method)
              for _ in range(config.num_hidden_layers)
          ]
          self.layers = nn.ModuleList(blocks)
          self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          position_ids: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids`` and run every layer, then the final LayerNorm.

          ``kv_caches`` is indexed by layer position: layer ``n`` receives
          ``kv_caches[n]``.
          """
          hidden_states = self.embed_in(input_ids)
          for layer_idx, block in enumerate(self.layers):
              hidden_states = block(
                  position_ids,
                  hidden_states,
                  kv_caches[layer_idx],
                  attn_metadata,
              )
          return self.final_layer_norm(hidden_states)
  class GPTNeoXForCausalLM(nn.Module):
      """GPT-NeoX causal language model for inference.

      Combines the decoder trunk ``gpt_neox`` with its own output head
      ``embed_out``, a ``LogitsProcessor`` and a ``Sampler``. ``forward``
      returns hidden states only; logits and token selection are the
      separate ``compute_logits`` and ``sample`` steps.
      """

      def __init__(
          self,
          config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.gpt_neox = GPTNeoXModel(config, linear_method)
          self.embed_out = ParallelLMHead(
              config.vocab_size,
              config.hidden_size,
              linear_method=linear_method,
          )
          self.logits_processor = LogitsProcessor(
              config.vocab_size,
              config.tokenizer_vocab_size,
          )
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run the decoder trunk and return its final hidden states."""
          return self.gpt_neox(input_ids, positions, kv_caches, attn_metadata)

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Turn hidden states into logits via the logits processor, using ``embed_out`` as the head."""
          return self.logits_processor(self.embed_out, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Select next tokens from ``logits`` with the model's sampler."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(
          self,
          model_name_or_path: str,
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None,
      ):
          """Load HuggingFace checkpoint tensors into this model's parameters.

          Tensors whose names contain ``attention.bias``,
          ``attention.masked_bias``, ``rotary_emb.inv_freq``,
          ``rotary_emb.cos_cached`` or ``rotary_emb.sin_cached`` are skipped.
          Every other name must match a parameter of this model (otherwise a
          ``KeyError`` is raised). Fused ``query_key_value`` tensors are
          regrouped before being handed to the parameter's weight loader.
          """
          skip_markers = (
              "attention.bias",
              "attention.masked_bias",
              "rotary_emb.inv_freq",
              # Models trained using OpenRLHF may include
              # these tensors in the checkpoint. Skip them.
              "rotary_emb.cos_cached",
              "rotary_emb.sin_cached",
          )
          num_heads = self.config.num_attention_heads

          def _regroup_qkv(weight, dim):
              # GPT-NeoX's fused QKV lays out `dim` as (num_heads, 3, head_size),
              # while the parameter expects (3, num_heads, head_size): split
              # that axis, swap the two leading factors, and flatten it back.
              full_shape = weight.shape
              split_shape = (*full_shape[:dim], num_heads, 3, -1,
                             *full_shape[dim + 1:])
              swapped = weight.view(split_shape).transpose(dim, dim + 1)
              return swapped.reshape(full_shape)

          params_by_name = dict(self.named_parameters())
          checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                                 load_format, revision,
                                                 self.config)
          for name, loaded_weight in checkpoint:
              if any(marker in name for marker in skip_markers):
                  continue
              param = params_by_name[name]
              if "query_key_value" in name:
                  axis = getattr(param, "output_dim", None)
                  if axis is not None:
                      loaded_weight = _regroup_qkv(loaded_weight, axis)
              loader = getattr(param, "weight_loader", default_weight_loader)
              loader(param, loaded_weight)