1
0

gpt2.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  6. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-2 model compatible with HuggingFace weights."""
  20. from typing import Iterable, List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import GPT2Config
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.sequence import SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import get_act_fn
  28. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  34. VocabParallelEmbedding
  35. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  36. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  37. from aphrodite.quantization.base_config import QuantizationConfig
  38. class GPT2Attention(nn.Module):
  39. def __init__(
  40. self,
  41. config: GPT2Config,
  42. quant_config: Optional[QuantizationConfig] = None,
  43. ):
  44. super().__init__()
  45. self.hidden_size = config.hidden_size
  46. total_num_heads = config.num_attention_heads
  47. tensor_model_parallel_world_size = (
  48. get_tensor_model_parallel_world_size())
  49. assert total_num_heads % tensor_model_parallel_world_size == 0
  50. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  51. self.head_dim = self.hidden_size // total_num_heads
  52. self.scale = self.head_dim**-0.5
  53. self.c_attn = QKVParallelLinear(
  54. self.hidden_size,
  55. self.head_dim,
  56. total_num_heads,
  57. bias=True,
  58. quant_config=quant_config,
  59. )
  60. self.c_proj = RowParallelLinear(
  61. self.hidden_size,
  62. self.hidden_size,
  63. bias=True,
  64. quant_config=quant_config,
  65. )
  66. self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
  67. def forward(
  68. self,
  69. hidden_states: torch.Tensor,
  70. kv_cache: torch.Tensor,
  71. attn_metadata: AttentionMetadata,
  72. ) -> torch.Tensor:
  73. qkv, _ = self.c_attn(hidden_states)
  74. q, k, v = qkv.chunk(chunks=3, dim=-1)
  75. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  76. attn_output, _ = self.c_proj(attn_output)
  77. return attn_output
  78. class GPT2MLP(nn.Module):
  79. def __init__(
  80. self,
  81. intermediate_size: int,
  82. config: GPT2Config,
  83. quant_config: Optional[QuantizationConfig] = None,
  84. ):
  85. super().__init__()
  86. hidden_size = config.hidden_size
  87. self.c_fc = ColumnParallelLinear(
  88. hidden_size,
  89. intermediate_size,
  90. bias=True,
  91. quant_config=quant_config,
  92. )
  93. self.c_proj = RowParallelLinear(
  94. intermediate_size,
  95. hidden_size,
  96. bias=True,
  97. quant_config=quant_config,
  98. )
  99. quant_config = getattr(quant_config, "quant_config", None)
  100. self.act = get_act_fn(config.activation_function, quant_config,
  101. intermediate_size)
  102. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  103. hidden_states, _ = self.c_fc(hidden_states)
  104. hidden_states = self.act(hidden_states)
  105. hidden_states, _ = self.c_proj(hidden_states)
  106. return hidden_states
  107. class GPT2Block(nn.Module):
  108. def __init__(
  109. self,
  110. config: GPT2Config,
  111. quant_config: Optional[QuantizationConfig] = None,
  112. ):
  113. super().__init__()
  114. hidden_size = config.hidden_size
  115. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  116. hidden_size)
  117. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  118. self.attn = GPT2Attention(config, quant_config)
  119. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  120. self.mlp = GPT2MLP(inner_dim, config, quant_config)
  121. def forward(
  122. self,
  123. hidden_states: torch.Tensor,
  124. kv_cache: torch.Tensor,
  125. attn_metadata: AttentionMetadata,
  126. ) -> torch.Tensor:
  127. residual = hidden_states
  128. hidden_states = self.ln_1(hidden_states)
  129. attn_output = self.attn(
  130. hidden_states=hidden_states,
  131. kv_cache=kv_cache,
  132. attn_metadata=attn_metadata,
  133. )
  134. # residual connection
  135. hidden_states = attn_output + residual
  136. residual = hidden_states
  137. hidden_states = self.ln_2(hidden_states)
  138. feed_forward_hidden_states = self.mlp(hidden_states)
  139. # residual connection
  140. hidden_states = residual + feed_forward_hidden_states
  141. return hidden_states
  142. class GPT2Model(nn.Module):
  143. def __init__(
  144. self,
  145. config: GPT2Config,
  146. quant_config: Optional[QuantizationConfig] = None,
  147. ):
  148. super().__init__()
  149. self.config = config
  150. assert not config.add_cross_attention
  151. assert not config.scale_attn_by_inverse_layer_idx
  152. assert not config.reorder_and_upcast_attn
  153. self.embed_dim = config.hidden_size
  154. self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
  155. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  156. self.h = nn.ModuleList([
  157. GPT2Block(config, quant_config)
  158. for _ in range(config.num_hidden_layers)
  159. ])
  160. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  161. def forward(
  162. self,
  163. input_ids: torch.Tensor,
  164. position_ids: torch.Tensor,
  165. kv_caches: List[torch.Tensor],
  166. attn_metadata: AttentionMetadata,
  167. ) -> torch.Tensor:
  168. inputs_embeds = self.wte(input_ids)
  169. position_embeds = self.wpe(position_ids)
  170. hidden_states = inputs_embeds + position_embeds
  171. for i in range(len(self.h)):
  172. layer = self.h[i]
  173. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  174. hidden_states = self.ln_f(hidden_states)
  175. return hidden_states
  176. class GPT2LMHeadModel(nn.Module):
  177. def __init__(
  178. self,
  179. config: GPT2Config,
  180. quant_config: Optional[QuantizationConfig] = None,
  181. ):
  182. super().__init__()
  183. self.config = config
  184. self.quant_config = quant_config
  185. self.transformer = GPT2Model(config, quant_config)
  186. self.lm_head_weight = self.transformer.wte.weight
  187. self.logits_processor = LogitsProcessor(config.vocab_size)
  188. self.sampler = Sampler()
  189. def forward(
  190. self,
  191. input_ids: torch.Tensor,
  192. positions: torch.Tensor,
  193. kv_caches: List[torch.Tensor],
  194. attn_metadata: AttentionMetadata,
  195. ) -> torch.Tensor:
  196. hidden_states = self.transformer(input_ids, positions, kv_caches,
  197. attn_metadata)
  198. return hidden_states
  199. def compute_logits(self, hidden_states: torch.Tensor,
  200. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  201. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  202. sampling_metadata)
  203. return logits
  204. def sample(
  205. self,
  206. logits: torch.Tensor,
  207. sampling_metadata: SamplingMetadata,
  208. ) -> Optional[SamplerOutput]:
  209. next_tokens = self.sampler(logits, sampling_metadata)
  210. return next_tokens
  211. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  212. params_dict = dict(self.named_parameters(remove_duplicate=False))
  213. for name, loaded_weight in weights:
  214. if "lm_head.weight" in name:
  215. # GPT-2 ties the weights of the embedding layer and the final
  216. # linear layer.
  217. continue
  218. if ".attn.bias" in name or ".attn.masked_bias" in name:
  219. # Skip attention mask.
  220. # NOTE: "c_attn.bias" should not be skipped.
  221. continue
  222. if not name.startswith("transformer."):
  223. name = "transformer." + name
  224. param = params_dict[name]
  225. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
  226. # Because of this, we need to transpose the weights.
  227. # Note(zhuohan): the logic below might break quantized models.
  228. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
  229. if conv1d_weight_name not in name:
  230. continue
  231. if not name.endswith(".weight"):
  232. continue
  233. loaded_weight = loaded_weight.t()
  234. weight_loader = getattr(param, "weight_loader",
  235. default_weight_loader)
  236. weight_loader(param, loaded_weight)