gpt_j.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


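# Attention with a fused, tensor-parallel QKV projection. Rotary position
# embeddings follow the GPT-J convention (is_neox_style=False), applied to
# the first `rotary_dim` dimensions of each attention head.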
class GPTJAttention(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


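# GPT-J feed-forward block: fc_in (column-parallel) -> activation -> fc_out
# (row-parallel), so the intermediate dimension is sharded across
# tensor-parallel ranks and fc_out combines the partial results.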
class GPTJMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )
        quant_config = getattr(quant_config, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


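# GPT-J uses a "parallel" residual layout: attention and MLP both read the
# same LayerNorm output, and their outputs are summed with the residual in
# forward() below, rather than being applied sequentially.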
class GPTJBlock(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, quant_config)
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


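# The core transformer: vocab-parallel token embeddings, a stack of
# GPTJBlocks (each reading its own entry in kv_caches), and a final
# LayerNorm. GPT-J has no learned absolute position embeddings; positions
# enter only through the rotary embeddings inside each attention layer.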
class GPTJModel(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList(
            [GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


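# Causal LM head on top of GPTJModel. GPT-J does not tie input and output
# embeddings (asserted below), and its lm_head carries a bias, so both the
# weight and the bias are passed to the logits processor.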
class GPTJForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

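    # Maps checkpoint weight names onto the fused parameters defined above:
    # the separate q_proj/k_proj/v_proj tensors in HuggingFace checkpoints
    # are loaded into shards of qkv_proj. The gate_up_proj entries match no
    # GPT-J weight names and are effectively unused here; they appear to be
    # kept for consistency with the loaders of other model families.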
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)