gpt_bigcode.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""

from typing import List, Optional

import torch
from torch import nn
from transformers import GPTBigCodeConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput


class GPTBigCodeAttention(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = (total_num_heads //
                          self.tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.num_kv_heads)
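
    # forward() splits the fused QKV output column-wise: the query slice is
    # this rank's shard of the attention heads (hidden_size / world_size wide),
    # and the key/value slices are each kv_dim = head_dim * num_kv_heads wide
    # (a single shared KV head when multi-query attention is enabled).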
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim, self.kv_dim
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
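

# Standard GPT-2-style feed-forward block: up-projection (c_fc), the
# configured activation, then down-projection (c_proj), sharded
# column-/row-parallel across tensor-parallel ranks.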
class GPTBigMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states
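

# A single pre-LayerNorm transformer block: ln_1 -> attention -> residual add,
# then ln_2 -> MLP -> residual add.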
class GPTBigCodeBlock(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(config, linear_method)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states
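

# The transformer trunk: token embeddings plus learned absolute position
# embeddings, a stack of GPTBigCodeBlock layers, and a final LayerNorm.
# Configs with cross-attention (add_cross_attention) are rejected.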
class GPTBigCodeModel(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          linear_method=linear_method)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList([
            GPTBigCodeBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
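

# Causal-LM wrapper around GPTBigCodeModel: adds a vocab-parallel LM head,
# a logits processor, and a sampler for next-token selection.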
class GPTBigCodeForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTBigCodeConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = GPTBigCodeModel(config, linear_method)
        # self.lm_head_weight = self.transformer.wte.weight
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
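
    # Weight-loading notes:
    #   * checkpoints that tie input and output embeddings only store
    #     "transformer.wte", so that weight is also copied into lm_head here;
    #   * the ".attn.bias" entries in the checkpoint are precomputed causal
    #     masks and are skipped ("c_attn.bias" is a real bias and is kept).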
    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                continue
            if "wte" in name:
                # Copy word embedding to lm_head
                head_name = name.replace("transformer.wte", "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
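

# Usage sketch (illustrative only): this class is normally reached through the
# engine's high-level entrypoint rather than constructed directly, since the
# constructor expects an initialized tensor-parallel group and forward() needs
# engine-managed KV caches and attention metadata. Assuming Aphrodite exposes
# the vLLM-style `LLM`/`SamplingParams` API, serving a GPT-BigCode checkpoint
# would look roughly like:
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="bigcode/gpt_bigcode-santacoder")
#     params = SamplingParams(temperature=0.0, max_tokens=64)
#     outputs = llm.generate(["def fibonacci(n):"], params)
#     print(outputs[0].outputs[0].text)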