1
0

gpt_bigcode.py 11 KB


  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2023 CTranslate2, and Michael Feil
  7. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  8. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. """Inference-only GPTBigCode model compatible with HuggingFace weights."""
  22. from typing import List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import GPTBigCodeConfig
  26. from aphrodite.modeling.metadata import InputMetadata
  27. from aphrodite.modeling.layers.activation import get_act_fn
  28. from aphrodite.modeling.layers.attention import PagedAttention
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. LinearMethodBase,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  34. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  35. VocabParallelEmbedding, ParallelLMHead)
  36. from aphrodite.modeling.megatron.parallel_state import (
  37. get_tensor_model_parallel_world_size)
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  40. hf_model_weights_iterator)
  41. from aphrodite.common.sequence import SamplerOutput
# (key_cache, value_cache) tensor pair handed to a single attention layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  43. class GPTBigCodeAttention(nn.Module):
  44. def __init__(
  45. self,
  46. config: GPTBigCodeConfig,
  47. linear_method: Optional[LinearMethodBase] = None,
  48. ):
  49. super().__init__()
  50. self.hidden_size = config.hidden_size
  51. total_num_heads = config.num_attention_heads
  52. self.tensor_model_parallel_world_size = (
  53. get_tensor_model_parallel_world_size())
  54. assert total_num_heads % self.tensor_model_parallel_world_size == 0
  55. self.num_heads = (total_num_heads //
  56. self.tensor_model_parallel_world_size)
  57. self.head_dim = self.hidden_size // total_num_heads
  58. self.scale = self.head_dim**-0.5
  59. self.multi_query = config.multi_query
  60. if self.multi_query:
  61. total_num_kv_heads = 1
  62. self.num_kv_heads = 1
  63. else:
  64. total_num_kv_heads = total_num_heads
  65. self.num_kv_heads = self.num_heads
  66. self.kv_dim = self.head_dim * self.num_kv_heads
  67. self.c_attn = QKVParallelLinear(
  68. self.hidden_size,
  69. self.head_dim,
  70. total_num_heads,
  71. total_num_kv_heads,
  72. bias=True,
  73. linear_method=linear_method,
  74. )
  75. self.c_proj = RowParallelLinear(
  76. self.hidden_size,
  77. self.hidden_size,
  78. bias=True,
  79. linear_method=linear_method,
  80. )
  81. self.attn = PagedAttention(self.num_heads,
  82. self.head_dim,
  83. scale=self.scale,
  84. num_kv_heads=self.num_kv_heads)
  85. def forward(
  86. self,
  87. hidden_states: torch.Tensor,
  88. kv_cache: KVCache,
  89. input_metadata: InputMetadata,
  90. ) -> torch.Tensor:
  91. qkv, _ = self.c_attn(hidden_states)
  92. q, k, v = qkv.split(
  93. [
  94. self.hidden_size // self.tensor_model_parallel_world_size,
  95. self.kv_dim, self.kv_dim
  96. ],
  97. dim=-1,
  98. )
  99. key_cache, value_cache = kv_cache
  100. attn_output = self.attn(q, k, v, key_cache, value_cache,
  101. input_metadata)
  102. attn_output, _ = self.c_proj(attn_output)
  103. return attn_output
  104. class GPTBigMLP(nn.Module):
  105. def __init__(
  106. self,
  107. intermediate_size: int,
  108. config: GPTBigCodeConfig,
  109. linear_method: Optional[LinearMethodBase] = None,
  110. ):
  111. super().__init__()
  112. hidden_size = config.hidden_size
  113. self.c_fc = ColumnParallelLinear(
  114. hidden_size,
  115. intermediate_size,
  116. bias=True,
  117. linear_method=linear_method,
  118. )
  119. self.c_proj = RowParallelLinear(
  120. intermediate_size,
  121. hidden_size,
  122. bias=True,
  123. linear_method=linear_method,
  124. )
  125. quant_config = getattr(linear_method, "quant_config", None)
  126. self.act = get_act_fn(config.activation_function, quant_config,
  127. intermediate_size)
  128. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  129. hidden_states, _ = self.c_fc(hidden_states)
  130. hidden_states = self.act(hidden_states)
  131. hidden_states, _ = self.c_proj(hidden_states)
  132. return hidden_states
  133. class GPTBigCodeBlock(nn.Module):
  134. def __init__(
  135. self,
  136. config: GPTBigCodeConfig,
  137. linear_method: Optional[LinearMethodBase] = None,
  138. ):
  139. super().__init__()
  140. hidden_size = config.hidden_size
  141. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  142. hidden_size)
  143. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  144. self.attn = GPTBigCodeAttention(config, linear_method)
  145. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  146. self.mlp = GPTBigMLP(inner_dim, config, linear_method)
  147. def forward(
  148. self,
  149. hidden_states: torch.Tensor,
  150. kv_cache: KVCache,
  151. input_metadata: InputMetadata,
  152. ) -> torch.Tensor:
  153. residual = hidden_states
  154. hidden_states = self.ln_1(hidden_states)
  155. attn_output = self.attn(
  156. hidden_states=hidden_states,
  157. kv_cache=kv_cache,
  158. input_metadata=input_metadata,
  159. )
  160. # residual connection
  161. hidden_states = attn_output + residual
  162. residual = hidden_states
  163. hidden_states = self.ln_2(hidden_states)
  164. feed_forward_hidden_states = self.mlp(hidden_states)
  165. # residual connection
  166. hidden_states = residual + feed_forward_hidden_states
  167. return hidden_states
  168. class GPTBigCodeModel(nn.Module):
  169. def __init__(
  170. self,
  171. config: GPTBigCodeConfig,
  172. linear_method: Optional[LinearMethodBase] = None,
  173. ):
  174. super().__init__()
  175. self.config = config
  176. assert not config.add_cross_attention
  177. self.embed_dim = config.hidden_size
  178. self.wte = VocabParallelEmbedding(config.vocab_size,
  179. self.embed_dim,
  180. linear_method=linear_method)
  181. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  182. self.h = nn.ModuleList([
  183. GPTBigCodeBlock(config, linear_method)
  184. for _ in range(config.num_hidden_layers)
  185. ])
  186. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  187. def forward(
  188. self,
  189. input_ids: torch.Tensor,
  190. position_ids: torch.Tensor,
  191. kv_caches: List[KVCache],
  192. input_metadata: InputMetadata,
  193. ) -> torch.Tensor:
  194. inputs_embeds = self.wte(input_ids)
  195. position_embeds = self.wpe(position_ids)
  196. hidden_states = inputs_embeds + position_embeds
  197. for i in range(len(self.h)):
  198. layer = self.h[i]
  199. hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
  200. hidden_states = self.ln_f(hidden_states)
  201. return hidden_states
  202. class GPTBigCodeForCausalLM(nn.Module):
  203. def __init__(
  204. self,
  205. config: GPTBigCodeConfig,
  206. linear_method: Optional[LinearMethodBase] = None,
  207. ):
  208. super().__init__()
  209. self.config = config
  210. self.linear_method = linear_method
  211. self.transformer = GPTBigCodeModel(config, linear_method)
  212. # self.lm_head_weight = self.transformer.wte.weight
  213. self.lm_head = ParallelLMHead(config.vocab_size,
  214. config.hidden_size,
  215. linear_method=linear_method)
  216. self.sampler = Sampler(config.vocab_size)
  217. self.quant_sampler = QuantSampler(config.vocab_size)
  218. def forward(
  219. self,
  220. input_ids: torch.Tensor,
  221. positions: torch.Tensor,
  222. kv_caches: List[KVCache],
  223. input_metadata: InputMetadata,
  224. ) -> torch.Tensor:
  225. hidden_states = self.transformer(input_ids, positions, kv_caches,
  226. input_metadata)
  227. return hidden_states
  228. def sample(
  229. self,
  230. hidden_states: torch.Tensor,
  231. sampling_metadata: SamplingMetadata,
  232. ) -> Optional[SamplerOutput]:
  233. if (self.linear_method is not None
  234. and not self.linear_method.quant_config.merge_weight()):
  235. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  236. sampling_metadata)
  237. else:
  238. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  239. sampling_metadata)
  240. return next_tokens
  241. def load_weights(self,
  242. model_name_or_path: str,
  243. cache_dir: Optional[str] = None,
  244. load_format: str = "auto",
  245. revision: Optional[str] = None):
  246. params_dict = dict(self.named_parameters(remove_duplicate=False))
  247. for name, loaded_weight in hf_model_weights_iterator(
  248. model_name_or_path, cache_dir, load_format, revision,
  249. self.config):
  250. if "lm_head" in name and name not in params_dict:
  251. continue
  252. if "wte" in name:
  253. # Copy word embedding to lm_head
  254. head_name = name.replace("transformer.wte", "lm_head")
  255. if head_name in params_dict:
  256. lm_head_param = params_dict[head_name]
  257. weight_loader = getattr(lm_head_param, "weight_loader",
  258. default_weight_loader)
  259. weight_loader(lm_head_param, loaded_weight)
  260. if ".attn.bias" in name:
  261. # Skip attention mask.
  262. # NOTE: "c_attn.bias" should not be skipped.
  263. continue
  264. param = params_dict[name]
  265. weight_loader = getattr(param, "weight_loader",
  266. default_weight_loader)
  267. weight_loader(param, loaded_weight)