gpt_j.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
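
# Tensor-parallel multi-head attention for GPT-J. When the active quantization
# method cannot handle merged weights, q/k/v stay as three separate
# column-parallel projections; otherwise a single fused QKV projection is used.
# Rotary position embeddings are applied in the GPT-J (non-NeoX, interleaved)
# style before calling the shared Attention backend.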
class GPTJAttention(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                config.hidden_size,
                self.head_size,
                self.total_num_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.chunk(chunks=3, dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
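
# Feed-forward block: a column-parallel up-projection to intermediate_size,
# the configured activation (config.activation_function, "gelu_new" by default
# for GPT-J), and a row-parallel projection back to the hidden size.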
class GPTJMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states
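
# One transformer layer. GPT-J uses a "parallel" block layout: a single
# LayerNorm (ln_1) feeds both the attention and the MLP, and their outputs are
# summed with the residual in one step (attn_output + mlp_output + residual)
# instead of being applied sequentially.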
class GPTJBlock(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states
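
# The transformer stack: vocab-parallel token embedding (wte), config.n_layer
# GPTJBlocks, and a final LayerNorm (ln_f). There are no learned absolute
# position embeddings; positional information comes entirely from the rotary
# embeddings inside each attention layer.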
class GPTJModel(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          linear_method=linear_method)
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
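
# Causal-LM wrapper used at inference time: runs the transformer, projects
# hidden states to vocabulary logits through a parallel LM head (with bias,
# since GPT-J does not tie input and output embeddings), and delegates token
# selection to the Sampler. load_weights maps HuggingFace checkpoint tensors
# onto the (possibly fused) parallel parameters defined above.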
class GPTJForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            linear_method=linear_method,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
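
# Minimal usage sketch (illustrative only): the serving engine normally builds
# this module and calls load_weights itself, after initializing the
# distributed / tensor-parallel state that the parallel layers above require.
# The checkpoint name below is an assumption for the example, not something
# this file depends on.
#
#   from transformers import GPTJConfig
#
#   config = GPTJConfig.from_pretrained("EleutherAI/gpt-j-6b")
#   model = GPTJForCausalLM(config)
#   model.load_weights("EleutherAI/gpt-j-6b")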