gpt_j.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size, )
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            # Quantization schemes that cannot merge weights keep separate
            # q/k/v projections so each checkpoint tensor loads unchanged.
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                config.hidden_size,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                config.hidden_size,
                self.head_size,
                self.total_num_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.out_proj = RowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.chunk(chunks=3, dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        # GPT-J rotates q/k in the interleaved (non-NeoX) layout; see
        # `is_neox_style=False` passed to get_rope above.
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
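

# Illustrative reference only, not used by the model code above: a minimal
# standalone sketch of the interleaved (GPT-J-style, `is_neox_style=False`)
# rotation that `get_rope` is configured to apply to q/k. The function name
# is hypothetical; the production kernel differs in caching, dtype handling,
# and fusion, so treat this purely as documentation of the math.
def _gptj_style_rotary_reference(
    x: torch.Tensor,
    positions: torch.Tensor,
    rotary_dim: int,
    base: float = 10000.0,
) -> torch.Tensor:
    """Rotate the first `rotary_dim` dims of each head in interleaved pairs.

    x: [num_tokens, num_heads, head_size], positions: [num_tokens].
    """
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    # One frequency per (even, odd) pair of rotary dimensions.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    angles = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    cos = angles.cos()[:, None, :]  # broadcast over heads
    sin = angles.sin()[:, None, :]
    x1, x2 = rot[..., 0::2].float(), rot[..., 1::2].float()
    # Re-interleave the rotated pairs: even slots get x1*cos - x2*sin,
    # odd slots get x1*sin + x2*cos.
    rotated = torch.stack(
        (x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
    return torch.cat((rotated.to(x.dtype), rest), dim=-1)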


class GPTJMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            linear_method=linear_method,
        )
        self.fc_out = RowParallelLinear(
            intermediate_size,
            hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        inner_dim = (4 * config.n_embd
                     if config.n_inner is None else config.n_inner)
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config, linear_method)
        self.mlp = GPTJMLP(inner_dim, config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        # GPT-J uses a parallel residual: attention and MLP both consume the
        # same LayerNorm output, and their results are summed with the input.
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          linear_method=linear_method)
        self.h = nn.ModuleList(
            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):

    def __init__(
        self,
        config: GPTJConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.n_embd,
            bias=True,
            linear_method=linear_method,
        )
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata,
                                             self.lm_head.bias)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            # Unmerged (quantized) weights keep their original module names,
            # so no renaming into stacked parameters is needed.
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
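

if __name__ == "__main__":
    # Minimal sketch, illustration only and not part of the model: how the
    # for/else logic in `load_weights` maps HuggingFace checkpoint names onto
    # the fused `qkv_proj` parameter. The example names assume the standard
    # HF GPT-J layout ("transformer.h.<i>.attn.<q|k|v>_proj.weight").
    _mapping = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v")]
    for _name in ("transformer.h.0.attn.k_proj.weight",
                  "transformer.wte.weight"):
        for _param_name, _weight_name, _shard_id in _mapping:
            if _weight_name not in _name:
                continue
            # Matched a stacked parameter: rename and record the shard id.
            print(f"{_name} -> {_name.replace(_weight_name, _param_name)} "
                  f"(shard {_shard_id!r})")
            break
        else:
            # No mapping entry matched: the weight loads under its own name.
            print(f"{_name} -> loaded under its original name")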