gpt2.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
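# Each KV-cache entry is a (key_cache, value_cache) tensor pair for one layer;
# GPT2Attention.forward unpacks it before calling the paged attention kernel.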
KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPT2Attention(nn.Module):
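    """Multi-head self-attention for one GPT-2 block.

    The heads are split across tensor-parallel ranks: the fused QKV
    projection (`c_attn`) is column-parallel and the output projection
    (`c_proj`) is row-parallel, while the attention itself runs through
    the PagedAttention kernel against the per-layer KV cache.
    """
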
    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
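        # The fused QKV projection lays query, key, and value side by side
        # along the last dimension, so an even three-way chunk recovers the
        # per-rank q, k, and v tensors.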
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPT2MLP(nn.Module):
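    """Feed-forward network of a GPT-2 block: a column-parallel up-projection
    (`c_fc`), the configured activation, and a row-parallel down-projection
    (`c_proj`)."""
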
    def __init__(
        self,
        intermediate_size: int,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):
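    """One pre-norm transformer block: LayerNorm -> attention -> residual,
    then LayerNorm -> MLP -> residual."""
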
    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPT2Model(nn.Module):
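    """The GPT-2 transformer stack: token and position embeddings, a list of
    GPT2Block layers, and the final layer norm. Cross-attention and the
    attention-scaling variants of the HF config are not supported."""
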
    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          linear_method=linear_method)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPT2Block(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
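        # Sum token and position embeddings, run every block against its own
        # per-layer KV cache, and finish with the final layer norm.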
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPT2LMHeadModel(nn.Module):
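    """GPT-2 with a language-modeling head. `forward` returns the hidden
    states; `sample` projects them through the LM head and draws the next
    tokens; `load_weights` maps HuggingFace checkpoints onto the
    tensor-parallel modules."""
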
    def __init__(
        self,
        config: GPT2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPT2Model(config, linear_method)
        # self.lm_head_weight = self.transformer.wte.weight
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
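        # Quantized models whose weights are not merged go through the
        # quantized sampler on the full lm_head module; otherwise the regular
        # sampler works directly with the lm_head weight matrix.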
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
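        # Iterate over the HuggingFace checkpoint, copying the tied word
        # embedding into lm_head, skipping the attention-mask buffers, and
        # transposing Conv1D weights into the Linear layout used here.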
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if "wte" in name:
                # Copy word embedding to lm_head
                head_name = name.replace("transformer.wte", "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)