gpt2.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  7. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only GPT-2 model compatible with HuggingFace weights."""
  21. from typing import List, Optional
  22. import torch
  23. from torch import nn
  24. from transformers import GPT2Config
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.modeling.layers.activation import get_act_fn
  27. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  28. LinearMethodBase,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  34. VocabParallelEmbedding, ParallelLMHead)
  35. from aphrodite.distributed import (get_tensor_model_parallel_world_size)
  36. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  37. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  38. hf_model_weights_iterator)
  39. from aphrodite.common.sequence import SamplerOutput
  class GPT2Attention(nn.Module):
      """Multi-head self-attention for one GPT-2 layer.

      ``c_attn`` is the fused Q/K/V projection and ``c_proj`` is the
      row-parallel output projection. Each tensor-parallel rank holds
      ``num_attention_heads // tp_world_size`` heads.
      """

      def __init__(
          self,
          config: GPT2Config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.hidden_size = config.hidden_size
          total_num_heads = config.num_attention_heads
          tensor_model_parallel_world_size = (
              get_tensor_model_parallel_world_size())
          # Explicit check instead of ``assert`` so it still fires under
          # ``python -O``; this is a user-facing configuration error.
          if total_num_heads % tensor_model_parallel_world_size != 0:
              raise ValueError(
                  f"Number of attention heads ({total_num_heads}) must be "
                  f"divisible by the tensor parallel world size "
                  f"({tensor_model_parallel_world_size}).")
          # Heads owned by this rank.
          self.num_heads = total_num_heads // tensor_model_parallel_world_size
          self.head_dim = self.hidden_size // total_num_heads
          # Scaled dot-product attention factor: 1 / sqrt(head_dim).
          self.scale = self.head_dim**-0.5
          self.c_attn = QKVParallelLinear(
              self.hidden_size,
              self.head_dim,
              total_num_heads,
              bias=True,
              linear_method=linear_method,
          )
          self.c_proj = RowParallelLinear(
              self.hidden_size,
              self.hidden_size,
              bias=True,
              linear_method=linear_method,
          )
          self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)

      def forward(
          self,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run attention over ``hidden_states``.

          Args:
              hidden_states: Input activations for this layer.
              kv_cache: This layer's KV cache, passed through to ``Attention``.
              attn_metadata: Batch attention metadata, passed through as-is.

          Returns:
              The attention output after the ``c_proj`` projection.
          """
          qkv, _ = self.c_attn(hidden_states)
          # The fused projection is split into three equal chunks along the
          # last dimension: query, key, value.
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          attn_output, _ = self.c_proj(attn_output)
          return attn_output
  class GPT2MLP(nn.Module):
      """Feed-forward sub-layer of a GPT-2 block.

      Computes ``c_proj(act(c_fc(x)))``. ``c_fc`` is a column-parallel
      up-projection to ``intermediate_size``. ``c_proj`` is a row-parallel
      down-projection back to ``config.hidden_size``.
      """

      def __init__(
          self,
          intermediate_size: int,
          config: GPT2Config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          model_width = config.hidden_size
          self.c_fc = ColumnParallelLinear(model_width,
                                           intermediate_size,
                                           bias=True,
                                           linear_method=linear_method)
          self.c_proj = RowParallelLinear(intermediate_size,
                                          model_width,
                                          bias=True,
                                          linear_method=linear_method)
          # ``None`` when there is no linear method, or when the method has
          # no ``quant_config`` attribute.
          maybe_quant_config = getattr(linear_method, "quant_config", None)
          self.act = get_act_fn(config.activation_function,
                                maybe_quant_config, intermediate_size)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Apply up-projection, activation, then down-projection."""
          expanded, _ = self.c_fc(hidden_states)
          projected, _ = self.c_proj(self.act(expanded))
          return projected
  class GPT2Block(nn.Module):
      """One GPT-2 transformer layer using pre-LayerNorm residuals.

      Computes ``x = attn(ln_1(x)) + x`` and then ``x = x + mlp(ln_2(x))``.
      """

      def __init__(
          self,
          config: GPT2Config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          hidden_size = config.hidden_size
          # ``n_inner`` overrides the default MLP width of 4 * hidden_size.
          if config.n_inner is None:
              inner_dim = 4 * hidden_size
          else:
              inner_dim = config.n_inner
          norm_eps = config.layer_norm_epsilon
          self.ln_1 = nn.LayerNorm(hidden_size, eps=norm_eps)
          self.attn = GPT2Attention(config, linear_method)
          self.ln_2 = nn.LayerNorm(hidden_size, eps=norm_eps)
          self.mlp = GPT2MLP(inner_dim, config, linear_method)

      def forward(
          self,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Apply the attention and MLP sub-layers, each with a residual add.

          ``kv_cache`` and ``attn_metadata`` are forwarded to the attention
          sub-layer unchanged.
          """
          attn_out = self.attn(
              hidden_states=self.ln_1(hidden_states),
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )
          # First residual connection (around attention).
          after_attn = attn_out + hidden_states
          mlp_out = self.mlp(self.ln_2(after_attn))
          # Second residual connection (around the MLP).
          return after_attn + mlp_out
  class GPT2Model(nn.Module):
      """GPT-2 decoder stack without the LM head.

      Sums token (``wte``) and learned position (``wpe``) embeddings, runs
      every ``GPT2Block`` in ``h``, and applies the final LayerNorm
      ``ln_f``. It returns hidden states, not logits.
      """

      def __init__(
          self,
          config: GPT2Config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          # Reject GPT-2 variants this implementation does not handle.
          # Explicit raises (instead of ``assert``) keep the checks active
          # under ``python -O``, where they would otherwise be stripped and
          # the model would silently compute wrong outputs.
          unsupported = [
              flag for flag in ("add_cross_attention",
                                "scale_attn_by_inverse_layer_idx",
                                "reorder_and_upcast_attn")
              if getattr(config, flag)
          ]
          if unsupported:
              raise ValueError("Unsupported GPT-2 config option(s) enabled: "
                               f"{', '.join(unsupported)}")
          self.embed_dim = config.hidden_size
          self.wte = VocabParallelEmbedding(config.vocab_size,
                                            self.embed_dim,
                                            linear_method=linear_method)
          self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
          self.h = nn.ModuleList([
              GPT2Block(config, linear_method)
              for _ in range(config.num_hidden_layers)
          ])
          self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

      def forward(
          self,
          input_ids: torch.Tensor,
          position_ids: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids`` and run them through all layers.

          Args:
              input_ids: Token ids to embed with ``wte``.
              position_ids: Position indices to embed with ``wpe``.
              kv_caches: One KV cache per layer, indexed by layer number.
              attn_metadata: Attention metadata shared by all layers.

          Returns:
              Hidden states after the final LayerNorm.
          """
          hidden_states = self.wte(input_ids) + self.wpe(position_ids)
          # Index kv_caches by layer so a too-short list still fails loudly.
          for i, layer in enumerate(self.h):
              hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
          return self.ln_f(hidden_states)
  class GPT2LMHeadModel(nn.Module):
      """GPT-2 decoder plus language-modelling head, for inference.

      ``forward`` returns final hidden states. ``compute_logits`` projects
      them through ``lm_head`` via the logits processor, and ``sample`` picks
      next tokens from those logits.
      """

      def __init__(
          self,
          config: GPT2Config,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.transformer = GPT2Model(config, linear_method)
          # GPT-2 ties lm_head to the token embedding. Here lm_head is its own
          # parameter, and ``load_weights`` fills it from ``wte`` when the
          # checkpoint has no explicit lm_head tensor.
          self.lm_head = ParallelLMHead(config.vocab_size,
                                        config.hidden_size,
                                        linear_method=linear_method)
          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Return the decoder's final hidden states (not logits)."""
          hidden_states = self.transformer(input_ids, positions, kv_caches,
                                           attn_metadata)
          return hidden_states

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Project hidden states to vocabulary logits using ``lm_head``."""
          logits = self.logits_processor(self.lm_head, hidden_states,
                                         sampling_metadata)
          return logits

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Select next tokens from ``logits`` using ``self.sampler``."""
          next_tokens = self.sampler(logits, sampling_metadata)
          return next_tokens

      def load_weights(self,
                       model_name_or_path: str,
                       cache_dir: Optional[str] = None,
                       load_format: str = "auto",
                       revision: Optional[str] = None):
          """Load HuggingFace GPT-2 checkpoint tensors into this model.

          Checkpoint names are accepted with or without the ``transformer.``
          prefix (e.g. ``wte.weight`` and ``transformer.wte.weight``). The
          token embedding is also copied into ``lm_head`` because GPT-2 ties
          the two. An explicit ``lm_head`` tensor, if present, is loaded
          directly.
          """
          params_dict = dict(self.named_parameters(remove_duplicate=False))
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision,
                  self.config):
              if "lm_head" in name:
                  # lm_head lives outside ``transformer.``, so it must not be
                  # prefixed below. Prefixing would look up a non-existent
                  # ``transformer.lm_head.*`` key. Tensors with no matching
                  # parameter are skipped; GPT-2 ties lm_head to wte anyway.
                  if name in params_dict:
                      head_param = params_dict[name]
                      head_loader = getattr(head_param, "weight_loader",
                                            default_weight_loader)
                      head_loader(head_param, loaded_weight)
                  continue
              if ".attn.bias" in name or ".attn.masked_bias" in name:
                  # Skip attention mask.
                  # NOTE: "c_attn.bias" should not be skipped.
                  continue
              # Normalize the prefix BEFORE the wte -> lm_head copy. Otherwise
              # an unprefixed ``wte.weight`` never maps to ``lm_head.weight``
              # and lm_head stays unloaded.
              if not name.startswith("transformer."):
                  name = "transformer." + name
              if "wte" in name:
                  # Copy word embedding to lm_head (tied weights).
                  head_name = name.replace("transformer.wte", "lm_head")
                  if head_name in params_dict:
                      lm_head_param = params_dict[head_name]
                      weight_loader = getattr(lm_head_param, "weight_loader",
                                              default_weight_loader)
                      weight_loader(lm_head_param, loaded_weight)
              param = params_dict[name]
              # The HF's GPT-2 implementation uses Conv1D instead of Linear,
              # so these projection weights are stored transposed.
              # Note(zhuohan): the logic below might break quantized models.
              if name.endswith(".weight") and any(
                      conv1d_name in name
                      for conv1d_name in ("c_attn", "c_proj", "c_fc")):
                  loaded_weight = loaded_weight.t()
              weight_loader = getattr(param, "weight_loader",
                                      default_weight_loader)
              weight_loader(param, loaded_weight)