gpt2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  6. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-2 model compatible with HuggingFace weights."""
  20. from typing import Iterable, List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import GPT2Config
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.config import CacheConfig
  26. from aphrodite.common.sequence import IntermediateTensors
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  34. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  35. VocabParallelEmbedding)
  36. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.quantization.base_config import QuantizationConfig
  39. class GPT2Attention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPT2Config,
  43. cache_config: Optional[CacheConfig] = None,
  44. quant_config: Optional[QuantizationConfig] = None,
  45. ):
  46. super().__init__()
  47. self.hidden_size = config.hidden_size
  48. total_num_heads = config.num_attention_heads
  49. tensor_model_parallel_world_size = (
  50. get_tensor_model_parallel_world_size())
  51. assert total_num_heads % tensor_model_parallel_world_size == 0
  52. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  53. self.head_dim = self.hidden_size // total_num_heads
  54. self.scale = self.head_dim**-0.5
  55. self.c_attn = QKVParallelLinear(
  56. self.hidden_size,
  57. self.head_dim,
  58. total_num_heads,
  59. bias=True,
  60. quant_config=quant_config,
  61. )
  62. self.c_proj = RowParallelLinear(
  63. self.hidden_size,
  64. self.hidden_size,
  65. bias=True,
  66. quant_config=quant_config,
  67. )
  68. self.attn = Attention(self.num_heads,
  69. self.head_dim,
  70. scale=self.scale,
  71. cache_config=cache_config,
  72. quant_config=quant_config)
  73. def forward(
  74. self,
  75. hidden_states: torch.Tensor,
  76. kv_cache: torch.Tensor,
  77. attn_metadata: AttentionMetadata,
  78. ) -> torch.Tensor:
  79. qkv, _ = self.c_attn(hidden_states)
  80. q, k, v = qkv.chunk(chunks=3, dim=-1)
  81. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  82. attn_output, _ = self.c_proj(attn_output)
  83. return attn_output
  84. class GPT2MLP(nn.Module):
  85. def __init__(
  86. self,
  87. intermediate_size: int,
  88. config: GPT2Config,
  89. quant_config: Optional[QuantizationConfig] = None,
  90. ):
  91. super().__init__()
  92. hidden_size = config.hidden_size
  93. self.c_fc = ColumnParallelLinear(
  94. hidden_size,
  95. intermediate_size,
  96. bias=True,
  97. quant_config=quant_config,
  98. )
  99. self.c_proj = RowParallelLinear(
  100. intermediate_size,
  101. hidden_size,
  102. bias=True,
  103. quant_config=quant_config,
  104. )
  105. self.act = get_act_fn(config.activation_function, quant_config,
  106. intermediate_size)
  107. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  108. hidden_states, _ = self.c_fc(hidden_states)
  109. hidden_states = self.act(hidden_states)
  110. hidden_states, _ = self.c_proj(hidden_states)
  111. return hidden_states
  112. class GPT2Block(nn.Module):
  113. def __init__(
  114. self,
  115. config: GPT2Config,
  116. cache_config: Optional[CacheConfig] = None,
  117. quant_config: Optional[QuantizationConfig] = None,
  118. ):
  119. super().__init__()
  120. hidden_size = config.hidden_size
  121. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  122. hidden_size)
  123. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  124. self.attn = GPT2Attention(config, cache_config, quant_config)
  125. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  126. self.mlp = GPT2MLP(inner_dim, config, quant_config)
  127. def forward(
  128. self,
  129. hidden_states: torch.Tensor,
  130. kv_cache: torch.Tensor,
  131. attn_metadata: AttentionMetadata,
  132. ) -> torch.Tensor:
  133. residual = hidden_states
  134. hidden_states = self.ln_1(hidden_states)
  135. attn_output = self.attn(
  136. hidden_states=hidden_states,
  137. kv_cache=kv_cache,
  138. attn_metadata=attn_metadata,
  139. )
  140. # residual connection
  141. hidden_states = attn_output + residual
  142. residual = hidden_states
  143. hidden_states = self.ln_2(hidden_states)
  144. feed_forward_hidden_states = self.mlp(hidden_states)
  145. # residual connection
  146. hidden_states = residual + feed_forward_hidden_states
  147. return hidden_states
  148. class GPT2Model(nn.Module):
  149. def __init__(
  150. self,
  151. config: GPT2Config,
  152. cache_config: Optional[CacheConfig] = None,
  153. quant_config: Optional[QuantizationConfig] = None,
  154. ):
  155. super().__init__()
  156. self.config = config
  157. assert not config.add_cross_attention
  158. assert not config.scale_attn_by_inverse_layer_idx
  159. assert not config.reorder_and_upcast_attn
  160. self.embed_dim = config.hidden_size
  161. self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
  162. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  163. self.h = nn.ModuleList([
  164. GPT2Block(config, cache_config, quant_config)
  165. for _ in range(config.num_hidden_layers)
  166. ])
  167. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  168. def forward(
  169. self,
  170. input_ids: torch.Tensor,
  171. position_ids: torch.Tensor,
  172. kv_caches: List[torch.Tensor],
  173. attn_metadata: AttentionMetadata,
  174. ) -> torch.Tensor:
  175. inputs_embeds = self.wte(input_ids)
  176. position_embeds = self.wpe(position_ids)
  177. hidden_states = inputs_embeds + position_embeds
  178. for i in range(len(self.h)):
  179. layer = self.h[i]
  180. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  181. hidden_states = self.ln_f(hidden_states)
  182. return hidden_states
  183. class GPT2LMHeadModel(nn.Module):
  184. def __init__(
  185. self,
  186. config: GPT2Config,
  187. cache_config: Optional[CacheConfig] = None,
  188. quant_config: Optional[QuantizationConfig] = None,
  189. ):
  190. super().__init__()
  191. self.config = config
  192. self.quant_config = quant_config
  193. self.transformer = GPT2Model(config, cache_config, quant_config)
  194. self.lm_head = self.transformer.wte
  195. self.logits_processor = LogitsProcessor(config.vocab_size)
  196. self.sampler = Sampler()
  197. def forward(
  198. self,
  199. input_ids: torch.Tensor,
  200. positions: torch.Tensor,
  201. kv_caches: List[torch.Tensor],
  202. attn_metadata: AttentionMetadata,
  203. intermediate_tensors: Optional[IntermediateTensors] = None,
  204. ) -> torch.Tensor:
  205. hidden_states = self.transformer(input_ids, positions, kv_caches,
  206. attn_metadata)
  207. return hidden_states
  208. def compute_logits(
  209. self,
  210. hidden_states: torch.Tensor,
  211. sampling_metadata: SamplingMetadata,
  212. ) -> Optional[torch.Tensor]:
  213. logits = self.logits_processor(self.lm_head, hidden_states,
  214. sampling_metadata)
  215. return logits
  216. def sample(
  217. self,
  218. logits: torch.Tensor,
  219. sampling_metadata: SamplingMetadata,
  220. ) -> Optional[SamplerOutput]:
  221. next_tokens = self.sampler(logits, sampling_metadata)
  222. return next_tokens
  223. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  224. params_dict = dict(self.named_parameters(remove_duplicate=False))
  225. for name, loaded_weight in weights:
  226. if "lm_head.weight" in name:
  227. # GPT-2 ties the weights of the embedding layer and the final
  228. # linear layer.
  229. continue
  230. if ".attn.bias" in name or ".attn.masked_bias" in name:
  231. # Skip attention mask.
  232. # NOTE: "c_attn.bias" should not be skipped.
  233. continue
  234. if not name.startswith("transformer."):
  235. name = "transformer." + name
  236. param = params_dict[name]
  237. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
  238. # Because of this, we need to transpose the weights.
  239. # Note(zhuohan): the logic below might break quantized models.
  240. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
  241. if conv1d_weight_name not in name:
  242. continue
  243. if not name.endswith(".weight"):
  244. continue
  245. loaded_weight = loaded_weight.t()
  246. weight_loader = getattr(param, "weight_loader",
  247. default_weight_loader)
  248. weight_loader(param, loaded_weight)