gpt_bigcode.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 CTranslate2, and Michael Feil
  6. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  7. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only GPTBigCode model compatible with HuggingFace weights."""
  21. from typing import Iterable, List, Optional, Tuple
  22. import torch
  23. from torch import nn
  24. from transformers import GPTBigCodeConfig
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.sampler import Sampler
  34. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  35. VocabParallelEmbedding
  36. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.quantization.base_config import QuantizationConfig
  39. class GPTBigCodeAttention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPTBigCodeConfig,
  43. quant_config: Optional[QuantizationConfig] = None,
  44. ):
  45. super().__init__()
  46. self.hidden_size = config.hidden_size
  47. total_num_heads = config.num_attention_heads
  48. self.tensor_model_parallel_world_size = (
  49. get_tensor_model_parallel_world_size())
  50. assert total_num_heads % self.tensor_model_parallel_world_size == 0
  51. self.num_heads = (total_num_heads //
  52. self.tensor_model_parallel_world_size)
  53. self.head_dim = self.hidden_size // total_num_heads
  54. self.scale = self.head_dim**-0.5
  55. self.multi_query = config.multi_query
  56. if self.multi_query:
  57. total_num_kv_heads = 1
  58. self.num_kv_heads = 1
  59. else:
  60. total_num_kv_heads = total_num_heads
  61. self.num_kv_heads = self.num_heads
  62. self.kv_dim = self.head_dim * self.num_kv_heads
  63. self.c_attn = QKVParallelLinear(
  64. self.hidden_size,
  65. self.head_dim,
  66. total_num_heads,
  67. total_num_kv_heads,
  68. bias=True,
  69. quant_config=quant_config,
  70. )
  71. self.c_proj = RowParallelLinear(
  72. self.hidden_size,
  73. self.hidden_size,
  74. bias=True,
  75. quant_config=quant_config,
  76. )
  77. self.attn = Attention(self.num_heads,
  78. self.head_dim,
  79. scale=self.scale,
  80. num_kv_heads=self.num_kv_heads)
  81. def forward(
  82. self,
  83. hidden_states: torch.Tensor,
  84. kv_cache: torch.Tensor,
  85. attn_metadata: AttentionMetadata,
  86. ) -> torch.Tensor:
  87. qkv, _ = self.c_attn(hidden_states)
  88. q, k, v = qkv.split(
  89. [
  90. self.hidden_size // self.tensor_model_parallel_world_size,
  91. self.kv_dim, self.kv_dim
  92. ],
  93. dim=-1,
  94. )
  95. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  96. attn_output, _ = self.c_proj(attn_output)
  97. return attn_output
  98. class GPTBigMLP(nn.Module):
  99. def __init__(
  100. self,
  101. intermediate_size: int,
  102. config: GPTBigCodeConfig,
  103. quant_config: Optional[QuantizationConfig] = None,
  104. ):
  105. super().__init__()
  106. hidden_size = config.hidden_size
  107. self.c_fc = ColumnParallelLinear(
  108. hidden_size,
  109. intermediate_size,
  110. bias=True,
  111. quant_config=quant_config,
  112. )
  113. self.c_proj = RowParallelLinear(
  114. intermediate_size,
  115. hidden_size,
  116. bias=True,
  117. quant_config=quant_config,
  118. )
  119. quant_config = getattr(quant_config, "quant_config", None)
  120. self.act = get_act_fn(config.activation_function, quant_config,
  121. intermediate_size)
  122. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  123. hidden_states, _ = self.c_fc(hidden_states)
  124. hidden_states = self.act(hidden_states)
  125. hidden_states, _ = self.c_proj(hidden_states)
  126. return hidden_states
  127. class GPTBigCodeBlock(nn.Module):
  128. def __init__(
  129. self,
  130. config: GPTBigCodeConfig,
  131. quant_config: Optional[QuantizationConfig] = None,
  132. ):
  133. super().__init__()
  134. hidden_size = config.hidden_size
  135. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  136. hidden_size)
  137. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  138. self.attn = GPTBigCodeAttention(config, quant_config)
  139. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  140. self.mlp = GPTBigMLP(inner_dim, config, quant_config)
  141. def forward(
  142. self,
  143. hidden_states: torch.Tensor,
  144. kv_cache: torch.Tensor,
  145. attn_metadata: AttentionMetadata,
  146. ) -> torch.Tensor:
  147. residual = hidden_states
  148. hidden_states = self.ln_1(hidden_states)
  149. attn_output = self.attn(
  150. hidden_states=hidden_states,
  151. kv_cache=kv_cache,
  152. attn_metadata=attn_metadata,
  153. )
  154. # residual connection
  155. hidden_states = attn_output + residual
  156. residual = hidden_states
  157. hidden_states = self.ln_2(hidden_states)
  158. feed_forward_hidden_states = self.mlp(hidden_states)
  159. # residual connection
  160. hidden_states = residual + feed_forward_hidden_states
  161. return hidden_states
  162. class GPTBigCodeModel(nn.Module):
  163. def __init__(
  164. self,
  165. config: GPTBigCodeConfig,
  166. quant_config: Optional[QuantizationConfig] = None,
  167. ):
  168. super().__init__()
  169. self.config = config
  170. assert not config.add_cross_attention
  171. self.embed_dim = config.hidden_size
  172. self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
  173. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  174. self.h = nn.ModuleList([
  175. GPTBigCodeBlock(config, quant_config)
  176. for _ in range(config.num_hidden_layers)
  177. ])
  178. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  179. def forward(
  180. self,
  181. input_ids: torch.Tensor,
  182. position_ids: torch.Tensor,
  183. kv_caches: List[torch.Tensor],
  184. attn_metadata: AttentionMetadata,
  185. ) -> torch.Tensor:
  186. inputs_embeds = self.wte(input_ids)
  187. position_embeds = self.wpe(position_ids)
  188. hidden_states = inputs_embeds + position_embeds
  189. for i in range(len(self.h)):
  190. layer = self.h[i]
  191. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  192. hidden_states = self.ln_f(hidden_states)
  193. return hidden_states
  194. class GPTBigCodeForCausalLM(nn.Module):
  195. def __init__(
  196. self,
  197. config: GPTBigCodeConfig,
  198. quant_config: Optional[QuantizationConfig] = None,
  199. ):
  200. super().__init__()
  201. self.config = config
  202. self.quant_config = quant_config
  203. self.transformer = GPTBigCodeModel(config, quant_config)
  204. self.lm_head_weight = self.transformer.wte.weight
  205. self.logits_processor = LogitsProcessor(config.vocab_size)
  206. self.sampler = Sampler()
  207. def forward(
  208. self,
  209. input_ids: torch.Tensor,
  210. positions: torch.Tensor,
  211. kv_caches: List[torch.Tensor],
  212. attn_metadata: AttentionMetadata,
  213. ) -> torch.Tensor:
  214. hidden_states = self.transformer(input_ids, positions, kv_caches,
  215. attn_metadata)
  216. return hidden_states
  217. def compute_logits(self, hidden_states: torch.Tensor,
  218. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  219. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  220. sampling_metadata)
  221. return logits
  222. def sample(
  223. self,
  224. logits: torch.Tensor,
  225. sampling_metadata: SamplingMetadata,
  226. ) -> Optional[SamplerOutput]:
  227. next_tokens = self.sampler(logits, sampling_metadata)
  228. return next_tokens
  229. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  230. params_dict = dict(self.named_parameters(remove_duplicate=False))
  231. for name, loaded_weight in weights:
  232. if "lm_head.weight" in name:
  233. continue
  234. if ".attn.bias" in name:
  235. # Skip attention mask.
  236. # NOTE: "c_attn.bias" should not be skipped.
  237. continue
  238. param = params_dict[name]
  239. weight_loader = getattr(param, "weight_loader",
  240. default_weight_loader)
  241. weight_loader(param, loaded_weight)