gpt2.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  7. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only GPT-2 model compatible with HuggingFace weights."""
  21. from typing import Iterable, List, Optional, Tuple
  22. import torch
  23. from torch import nn
  24. from transformers import GPT2Config
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. LinearMethodBase,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  36. VocabParallelEmbedding
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. class GPT2Attention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPT2Config,
  43. linear_method: Optional[LinearMethodBase] = None,
  44. ):
  45. super().__init__()
  46. self.hidden_size = config.hidden_size
  47. total_num_heads = config.num_attention_heads
  48. tensor_model_parallel_world_size = (
  49. get_tensor_model_parallel_world_size())
  50. assert total_num_heads % tensor_model_parallel_world_size == 0
  51. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  52. self.head_dim = self.hidden_size // total_num_heads
  53. self.scale = self.head_dim**-0.5
  54. self.c_attn = QKVParallelLinear(
  55. self.hidden_size,
  56. self.head_dim,
  57. total_num_heads,
  58. bias=True,
  59. linear_method=linear_method,
  60. )
  61. self.c_proj = RowParallelLinear(
  62. self.hidden_size,
  63. self.hidden_size,
  64. bias=True,
  65. linear_method=linear_method,
  66. )
  67. self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
  68. def forward(
  69. self,
  70. hidden_states: torch.Tensor,
  71. kv_cache: torch.Tensor,
  72. attn_metadata: AttentionMetadata,
  73. ) -> torch.Tensor:
  74. qkv, _ = self.c_attn(hidden_states)
  75. q, k, v = qkv.chunk(chunks=3, dim=-1)
  76. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  77. attn_output, _ = self.c_proj(attn_output)
  78. return attn_output
  79. class GPT2MLP(nn.Module):
  80. def __init__(
  81. self,
  82. intermediate_size: int,
  83. config: GPT2Config,
  84. linear_method: Optional[LinearMethodBase] = None,
  85. ):
  86. super().__init__()
  87. hidden_size = config.hidden_size
  88. self.c_fc = ColumnParallelLinear(
  89. hidden_size,
  90. intermediate_size,
  91. bias=True,
  92. linear_method=linear_method,
  93. )
  94. self.c_proj = RowParallelLinear(
  95. intermediate_size,
  96. hidden_size,
  97. bias=True,
  98. linear_method=linear_method,
  99. )
  100. quant_config = getattr(linear_method, "quant_config", None)
  101. self.act = get_act_fn(config.activation_function, quant_config,
  102. intermediate_size)
  103. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  104. hidden_states, _ = self.c_fc(hidden_states)
  105. hidden_states = self.act(hidden_states)
  106. hidden_states, _ = self.c_proj(hidden_states)
  107. return hidden_states
  108. class GPT2Block(nn.Module):
  109. def __init__(
  110. self,
  111. config: GPT2Config,
  112. linear_method: Optional[LinearMethodBase] = None,
  113. ):
  114. super().__init__()
  115. hidden_size = config.hidden_size
  116. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  117. hidden_size)
  118. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  119. self.attn = GPT2Attention(config, linear_method)
  120. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  121. self.mlp = GPT2MLP(inner_dim, config, linear_method)
  122. def forward(
  123. self,
  124. hidden_states: torch.Tensor,
  125. kv_cache: torch.Tensor,
  126. attn_metadata: AttentionMetadata,
  127. ) -> torch.Tensor:
  128. residual = hidden_states
  129. hidden_states = self.ln_1(hidden_states)
  130. attn_output = self.attn(
  131. hidden_states=hidden_states,
  132. kv_cache=kv_cache,
  133. attn_metadata=attn_metadata,
  134. )
  135. # residual connection
  136. hidden_states = attn_output + residual
  137. residual = hidden_states
  138. hidden_states = self.ln_2(hidden_states)
  139. feed_forward_hidden_states = self.mlp(hidden_states)
  140. # residual connection
  141. hidden_states = residual + feed_forward_hidden_states
  142. return hidden_states
  143. class GPT2Model(nn.Module):
  144. def __init__(
  145. self,
  146. config: GPT2Config,
  147. linear_method: Optional[LinearMethodBase] = None,
  148. ):
  149. super().__init__()
  150. self.config = config
  151. assert not config.add_cross_attention
  152. assert not config.scale_attn_by_inverse_layer_idx
  153. assert not config.reorder_and_upcast_attn
  154. self.embed_dim = config.hidden_size
  155. self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
  156. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  157. self.h = nn.ModuleList([
  158. GPT2Block(config, linear_method)
  159. for _ in range(config.num_hidden_layers)
  160. ])
  161. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  162. def forward(
  163. self,
  164. input_ids: torch.Tensor,
  165. position_ids: torch.Tensor,
  166. kv_caches: List[torch.Tensor],
  167. attn_metadata: AttentionMetadata,
  168. ) -> torch.Tensor:
  169. inputs_embeds = self.wte(input_ids)
  170. position_embeds = self.wpe(position_ids)
  171. hidden_states = inputs_embeds + position_embeds
  172. for i in range(len(self.h)):
  173. layer = self.h[i]
  174. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  175. hidden_states = self.ln_f(hidden_states)
  176. return hidden_states
  177. class GPT2LMHeadModel(nn.Module):
  178. def __init__(
  179. self,
  180. config: GPT2Config,
  181. linear_method: Optional[LinearMethodBase] = None,
  182. ):
  183. super().__init__()
  184. self.config = config
  185. self.linear_method = linear_method
  186. self.transformer = GPT2Model(config, linear_method)
  187. self.lm_head_weight = self.transformer.wte.weight
  188. self.logits_processor = LogitsProcessor(config.vocab_size)
  189. self.sampler = Sampler()
  190. def forward(
  191. self,
  192. input_ids: torch.Tensor,
  193. positions: torch.Tensor,
  194. kv_caches: List[torch.Tensor],
  195. attn_metadata: AttentionMetadata,
  196. ) -> torch.Tensor:
  197. hidden_states = self.transformer(input_ids, positions, kv_caches,
  198. attn_metadata)
  199. return hidden_states
  200. def compute_logits(self, hidden_states: torch.Tensor,
  201. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  202. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  203. sampling_metadata)
  204. return logits
  205. def sample(
  206. self,
  207. logits: torch.Tensor,
  208. sampling_metadata: SamplingMetadata,
  209. ) -> Optional[SamplerOutput]:
  210. next_tokens = self.sampler(logits, sampling_metadata)
  211. return next_tokens
  212. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  213. params_dict = dict(self.named_parameters(remove_duplicate=False))
  214. for name, loaded_weight in weights:
  215. if "lm_head.weight" in name:
  216. # GPT-2 ties the weights of the embedding layer and the final
  217. # linear layer.
  218. continue
  219. if ".attn.bias" in name or ".attn.masked_bias" in name:
  220. # Skip attention mask.
  221. # NOTE: "c_attn.bias" should not be skipped.
  222. continue
  223. if not name.startswith("transformer."):
  224. name = "transformer." + name
  225. param = params_dict[name]
  226. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
  227. # Because of this, we need to transpose the weights.
  228. # Note(zhuohan): the logic below might break quantized models.
  229. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
  230. if conv1d_weight_name not in name:
  231. continue
  232. if not name.endswith(".weight"):
  233. continue
  234. loaded_weight = loaded_weight.t()
  235. weight_loader = getattr(param, "weight_loader",
  236. default_weight_loader)
  237. weight_loader(param, loaded_weight)