gpt_j.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  18. """Inference-only GPT-J model compatible with HuggingFace weights."""
  19. from typing import Iterable, List, Optional, Tuple
  20. import torch
  21. from torch import nn
  22. from transformers import GPTJConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import get_act_fn
  28. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  32. from aphrodite.modeling.layers.rotary_embedding import get_rope
  33. from aphrodite.modeling.layers.sampler import Sampler
  34. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  35. ParallelLMHead, VocabParallelEmbedding)
  36. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.quantization.base_config import QuantizationConfig
  39. class GPTJAttention(nn.Module):
  40. def __init__(
  41. self,
  42. config: GPTJConfig,
  43. cache_config: Optional[CacheConfig] = None,
  44. quant_config: Optional[QuantizationConfig] = None,
  45. ):
  46. super().__init__()
  47. self.total_num_heads = config.num_attention_heads
  48. self.hidden_size = config.hidden_size
  49. self.head_size = self.hidden_size // self.total_num_heads
  50. self.qkv_proj = QKVParallelLinear(
  51. config.hidden_size,
  52. self.head_size,
  53. self.total_num_heads,
  54. bias=False,
  55. quant_config=quant_config,
  56. )
  57. self.out_proj = RowParallelLinear(
  58. config.hidden_size,
  59. config.hidden_size,
  60. bias=False,
  61. quant_config=quant_config,
  62. )
  63. tp_world_size = get_tensor_model_parallel_world_size()
  64. assert self.total_num_heads % tp_world_size == 0
  65. self.num_heads = self.total_num_heads // tp_world_size
  66. scaling = self.head_size**-0.5
  67. assert getattr(config, "rotary", True)
  68. assert config.rotary_dim % 2 == 0
  69. rope_theta = getattr(config, "rope_theta", 10000)
  70. max_position_embeddings = getattr(config, "max_position_embeddings",
  71. 8192)
  72. self.rotary_emb = get_rope(
  73. self.head_size,
  74. rotary_dim=config.rotary_dim,
  75. max_position=max_position_embeddings,
  76. base=rope_theta,
  77. is_neox_style=False,
  78. )
  79. self.attn = Attention(self.num_heads,
  80. self.head_size,
  81. scaling,
  82. cache_config=cache_config,
  83. quant_config=quant_config)
  84. def forward(
  85. self,
  86. position_ids: torch.Tensor,
  87. hidden_states: torch.Tensor,
  88. kv_cache: torch.Tensor,
  89. attn_metadata: AttentionMetadata,
  90. ) -> torch.Tensor:
  91. qkv, _ = self.qkv_proj(hidden_states)
  92. q, k, v = qkv.chunk(chunks=3, dim=-1)
  93. q, k = self.rotary_emb(position_ids, q, k)
  94. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  95. attn_output, _ = self.out_proj(attn_output)
  96. return attn_output
  97. class GPTJMLP(nn.Module):
  98. def __init__(
  99. self,
  100. intermediate_size: int,
  101. config: GPTJConfig,
  102. quant_config: Optional[QuantizationConfig] = None,
  103. ):
  104. super().__init__()
  105. hidden_size = config.n_embd
  106. self.fc_in = ColumnParallelLinear(
  107. hidden_size,
  108. intermediate_size,
  109. quant_config=quant_config,
  110. )
  111. self.fc_out = RowParallelLinear(
  112. intermediate_size,
  113. hidden_size,
  114. quant_config=quant_config,
  115. )
  116. self.act = get_act_fn(config.activation_function, quant_config,
  117. intermediate_size)
  118. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  119. hidden_states, _ = self.fc_in(hidden_states)
  120. hidden_states = self.act(hidden_states)
  121. hidden_states, _ = self.fc_out(hidden_states)
  122. return hidden_states
  123. class GPTJBlock(nn.Module):
  124. def __init__(
  125. self,
  126. config: GPTJConfig,
  127. cache_config: Optional[CacheConfig] = None,
  128. quant_config: Optional[QuantizationConfig] = None,
  129. ):
  130. super().__init__()
  131. inner_dim = (4 * config.n_embd
  132. if config.n_inner is None else config.n_inner)
  133. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  134. self.attn = GPTJAttention(config, cache_config, quant_config)
  135. self.mlp = GPTJMLP(inner_dim, config, quant_config)
  136. def forward(
  137. self,
  138. position_ids: torch.Tensor,
  139. hidden_states: torch.Tensor,
  140. kv_cache: torch.Tensor,
  141. attn_metadata: AttentionMetadata,
  142. ) -> torch.Tensor:
  143. residual = hidden_states
  144. hidden_states = self.ln_1(hidden_states)
  145. attn_output = self.attn(
  146. position_ids=position_ids,
  147. hidden_states=hidden_states,
  148. kv_cache=kv_cache,
  149. attn_metadata=attn_metadata,
  150. )
  151. mlp_output = self.mlp(hidden_states)
  152. hidden_states = attn_output + mlp_output + residual
  153. return hidden_states
  154. class GPTJModel(nn.Module):
  155. def __init__(
  156. self,
  157. config: GPTJConfig,
  158. cache_config: Optional[CacheConfig] = None,
  159. quant_config: Optional[QuantizationConfig] = None,
  160. ):
  161. super().__init__()
  162. self.config = config
  163. self.embed_dim = config.n_embd
  164. self.wte = VocabParallelEmbedding(
  165. config.vocab_size,
  166. self.embed_dim,
  167. )
  168. self.h = nn.ModuleList([
  169. GPTJBlock(config, cache_config, quant_config)
  170. for _ in range(config.n_layer)
  171. ])
  172. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  173. def forward(
  174. self,
  175. input_ids: torch.Tensor,
  176. position_ids: torch.Tensor,
  177. kv_caches: List[torch.Tensor],
  178. attn_metadata: AttentionMetadata,
  179. ) -> torch.Tensor:
  180. hidden_states = self.wte(input_ids)
  181. for i in range(len(self.h)):
  182. layer = self.h[i]
  183. hidden_states = layer(
  184. position_ids,
  185. hidden_states,
  186. kv_caches[i],
  187. attn_metadata,
  188. )
  189. hidden_states = self.ln_f(hidden_states)
  190. return hidden_states
  191. class GPTJForCausalLM(nn.Module):
  192. def __init__(
  193. self,
  194. config: GPTJConfig,
  195. cache_config: Optional[CacheConfig] = None,
  196. quant_config: Optional[QuantizationConfig] = None,
  197. ):
  198. super().__init__()
  199. self.config = config
  200. self.quant_config = quant_config
  201. assert not config.tie_word_embeddings
  202. self.transformer = GPTJModel(config, cache_config, quant_config)
  203. self.lm_head = ParallelLMHead(
  204. config.vocab_size,
  205. config.n_embd,
  206. bias=True,
  207. quant_config=quant_config,
  208. )
  209. self.logits_processor = LogitsProcessor(config.vocab_size)
  210. self.sampler = Sampler()
  211. def forward(
  212. self,
  213. input_ids: torch.Tensor,
  214. positions: torch.Tensor,
  215. kv_caches: List[torch.Tensor],
  216. attn_metadata: AttentionMetadata,
  217. intermediate_tensors: Optional[IntermediateTensors] = None,
  218. ) -> torch.Tensor:
  219. hidden_states = self.transformer(input_ids, positions, kv_caches,
  220. attn_metadata)
  221. return hidden_states
  222. def compute_logits(self, hidden_states: torch.Tensor,
  223. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  224. logits = self.logits_processor(self.lm_head, hidden_states,
  225. sampling_metadata, self.lm_head.bias)
  226. return logits
  227. def sample(
  228. self,
  229. logits: torch.Tensor,
  230. sampling_metadata: SamplingMetadata,
  231. ) -> Optional[SamplerOutput]:
  232. next_tokens = self.sampler(logits, sampling_metadata)
  233. return next_tokens
  234. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  235. stacked_params_mapping = [
  236. # (param_name, shard_name, shard_id)
  237. ("qkv_proj", "q_proj", "q"),
  238. ("qkv_proj", "k_proj", "k"),
  239. ("qkv_proj", "v_proj", "v"),
  240. ("gate_up_proj", "gate_proj", 0),
  241. ("gate_up_proj", "up_proj", 1),
  242. ]
  243. params_dict = dict(self.named_parameters())
  244. for name, loaded_weight in weights:
  245. if "attn.bias" in name or "attn.masked_bias" in name:
  246. continue
  247. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  248. if weight_name not in name:
  249. continue
  250. name = name.replace(weight_name, param_name)
  251. # Skip loading extra bias for GPTQ models.
  252. if name.endswith(".bias") and name not in params_dict:
  253. continue
  254. param = params_dict[name]
  255. weight_loader = param.weight_loader
  256. weight_loader(param, loaded_weight, shard_id)
  257. break
  258. else:
  259. # Skip loading extra bias for GPTQ models.
  260. if name.endswith(".bias") and name not in params_dict:
  261. continue
  262. param = params_dict[name]
  263. weight_loader = getattr(param, "weight_loader",
  264. default_weight_loader)
  265. weight_loader(param, loaded_weight)