gpt_bigcode.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class GPTBigCodeAttention(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads
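        # With multi-query attention (config.multi_query, the usual setting
        # for GPTBigCode/StarCoder checkpoints), a single key/value head of
        # width head_dim is shared by every query head, so kv_dim == head_dim;
        # otherwise this reduces to ordinary multi-head attention.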
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
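        # Split the fused projection per tensor-parallel rank: the query slice
        # spans num_heads * head_dim columns on this rank, while k and v are
        # kv_dim wide each (a single shared head under multi-query attention).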
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)
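
    # Standard GPT-2 style MLP: c_fc expands to intermediate_size, the
    # configured activation is applied, and c_proj projects back down to
    # hidden_size.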
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
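
    # Pre-LayerNorm transformer block: LayerNorm -> attention -> residual add,
    # then LayerNorm -> MLP -> residual add.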
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
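        # Learned absolute position embeddings (GPT-2 style). wpe is a plain
        # nn.Embedding, so each tensor-parallel rank holds a full copy, while
        # wte is sharded along the vocabulary dimension.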
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
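        # Run every transformer block, each with its own slice of the KV cache.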
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.wte
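        # The output projection is tied to the input token embedding, so any
        # standalone lm_head.weight in the checkpoint is skipped in
        # load_weights below.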
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
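        # intermediate_tensors is accepted for interface compatibility with
        # pipeline-parallel models but is not used by this implementation.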
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
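        # named_parameters(remove_duplicate=False) also yields the tied
        # "lm_head.weight" alias alongside "transformer.wte.weight".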
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "lm_head.weight" in name:
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
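

if __name__ == "__main__":
    # Minimal shape sketch, not part of the model and never run by the engine:
    # it only illustrates how the fused c_attn output is split into q, k, v
    # under multi-query attention, assuming a tensor-parallel world size of 1
    # and StarCoder-like dimensions (hidden_size=6144, 48 heads of dim 128).
    # The concrete numbers here are illustrative assumptions.
    hidden_size, num_heads = 6144, 48
    head_dim = hidden_size // num_heads  # 128
    kv_dim = head_dim  # a single shared KV head under MQA
    qkv = torch.randn(2, 7, hidden_size + 2 * kv_dim)
    q, k, v = qkv.split([hidden_size, kv_dim, kv_dim], dim=-1)
    assert q.shape[-1] == hidden_size
    assert k.shape[-1] == v.shape[-1] == kv_dim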