# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
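
# Module layout: GPTJAttention and GPTJMLP implement the two sublayers of a
# GPT-J block, GPTJBlock combines them over a shared layernorm with a parallel
# residual, GPTJModel stacks config.n_layer blocks between the token embedding
# and the final layernorm, and GPTJForCausalLM adds the LM head, logits
# computation, sampling, and checkpoint loading.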


class GPTJAttention(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # GPT-J rotates only the first rotary_dim dims of each head and uses
        # the interleaved (non-NeoX) rotary convention of the HF checkpoint.
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # The fused projection returns q, k, and v concatenated along the
        # last dimension.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=quant_config,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, cache_config, quant_config)
        self.mlp = GPTJMLP(inner_dim, config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states
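
# GPTJBlock uses GPT-J's parallel residual: the attention and MLP branches read
# the same ln_1 output and are summed with the untouched block input,
#
#     h = x + Attn(LN_1(x)) + MLP(LN_1(x))
#
# rather than GPT-2's sequential two-layernorm layout.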


class GPTJModel(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.h = nn.ModuleList([
            GPTJBlock(config, cache_config, quant_config)
            for _ in range(config.n_layer)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            # HF GPT-J checkpoints carry causal-mask buffers that have no
            # counterpart in this implementation.
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the weight directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
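

# Example of feeding load_weights from a HuggingFace GPT-J checkpoint (a
# sketch, assuming a single-file safetensors checkpoint; the file name is
# illustrative):
#
#     from safetensors.torch import load_file
#
#     state_dict = load_file("model.safetensors")
#     model.load_weights(state_dict.items())
#
# A checkpoint name such as "transformer.h.0.attn.q_proj.weight" is rewritten
# via stacked_params_mapping to "transformer.h.0.attn.qkv_proj.weight" and
# passed to that parameter's weight_loader with shard_id "q", so the separate
# q/k/v tensors end up in the fused QKVParallelLinear weight.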