gpt_j.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-J model compatible with HuggingFace weights."""
  20. from typing import List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import GPTJConfig
  24. from aphrodite.modeling.metadata import InputMetadata
  25. from aphrodite.modeling.layers.activation import get_act_fn
  26. from aphrodite.modeling.layers.attention import PagedAttention
  27. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  28. LinearMethodBase,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.rotary_embedding import get_rope
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  34. VocabParallelEmbedding, ParallelLMHead)
  35. from aphrodite.modeling.megatron.parallel_state import (
  36. get_tensor_model_parallel_world_size)
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  39. hf_model_weights_iterator)
  40. from aphrodite.common.sequence import SamplerOutput
  41. KVCache = Tuple[torch.Tensor, torch.Tensor]
  42. class GPTJAttention(nn.Module):
  43. def __init__(
  44. self,
  45. config: GPTJConfig,
  46. linear_method: Optional[LinearMethodBase] = None,
  47. ):
  48. super().__init__()
  49. self.total_num_heads = config.num_attention_heads
  50. self.hidden_size = config.hidden_size
  51. self.head_size = self.hidden_size // self.total_num_heads
  52. if linear_method is not None and not linear_method.quant_config.merge_weight(
  53. ):
  54. self.merge_weight = False
  55. self.q_proj = ColumnParallelLinear(config.hidden_size,
  56. config.hidden_size,
  57. bias=False,
  58. linear_method=linear_method)
  59. self.k_proj = ColumnParallelLinear(config.hidden_size,
  60. config.hidden_size,
  61. bias=False,
  62. linear_method=linear_method)
  63. self.v_proj = ColumnParallelLinear(config.hidden_size,
  64. config.hidden_size,
  65. bias=False,
  66. linear_method=linear_method)
  67. else:
  68. self.merge_weight = True
  69. self.qkv_proj = QKVParallelLinear(
  70. config.hidden_size,
  71. self.head_dim,
  72. self.total_num_heads,
  73. bias=False,
  74. linear_method=linear_method,
  75. )
  76. self.out_proj = RowParallelLinear(
  77. config.hidden_size,
  78. config.hidden_size,
  79. bias=False,
  80. linear_method=linear_method,
  81. )
  82. tp_world_size = get_tensor_model_parallel_world_size()
  83. assert self.total_num_heads % tp_world_size == 0
  84. self.num_heads = self.total_num_heads // tp_world_size
  85. scaling = self.head_size**-0.5
  86. assert getattr(config, "rotary", True)
  87. assert config.rotary_dim % 2 == 0
  88. rope_theta = getattr(config, "rope_theta", 10000)
  89. max_position_embeddings = getattr(config, "max_position_embeddings",
  90. 8192)
  91. self.rotary_emb = get_rope(
  92. self.head_size,
  93. rotary_dim=config.rotary_dim,
  94. max_position=max_position_embeddings,
  95. base=rope_theta,
  96. is_neox_style=False,
  97. )
  98. self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
  99. def forward(
  100. self,
  101. position_ids: torch.Tensor,
  102. hidden_states: torch.Tensor,
  103. kv_cache: KVCache,
  104. input_metadata: InputMetadata,
  105. ) -> torch.Tensor:
  106. if self.merge_weight:
  107. qkv, _ = self.qkv_proj(hidden_states)
  108. q, k, v = qkv.chunk(chunks=3, dim=-1)
  109. else:
  110. q, _ = self.q_proj(hidden_states)
  111. k, _ = self.k_proj(hidden_states)
  112. v, _ = self.v_proj(hidden_states)
  113. q, k = self.rotary_emb(position_ids, q, k)
  114. k_cache, v_cache = kv_cache
  115. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  116. attn_output, _ = self.out_proj(attn_output)
  117. return attn_output
  118. class GPTJMLP(nn.Module):
  119. def __init__(
  120. self,
  121. intermediate_size: int,
  122. config: GPTJConfig,
  123. linear_method: Optional[LinearMethodBase] = None,
  124. ):
  125. super().__init__()
  126. hidden_size = config.n_embd
  127. self.fc_in = ColumnParallelLinear(
  128. hidden_size,
  129. intermediate_size,
  130. linear_method=linear_method,
  131. )
  132. self.fc_out = RowParallelLinear(
  133. intermediate_size,
  134. hidden_size,
  135. linear_method=linear_method,
  136. )
  137. quant_config = getattr(linear_method, "quant_config", None)
  138. self.act = get_act_fn(config.activation_function, quant_config,
  139. intermediate_size)
  140. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  141. hidden_states, _ = self.fc_in(hidden_states)
  142. hidden_states = self.act(hidden_states)
  143. hidden_states, _ = self.fc_out(hidden_states)
  144. return hidden_states
  145. class GPTJBlock(nn.Module):
  146. def __init__(
  147. self,
  148. config: GPTJConfig,
  149. linear_method: Optional[LinearMethodBase] = None,
  150. ):
  151. super().__init__()
  152. inner_dim = (4 * config.n_embd
  153. if config.n_inner is None else config.n_inner)
  154. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  155. self.attn = GPTJAttention(config, linear_method)
  156. self.mlp = GPTJMLP(inner_dim, config, linear_method)
  157. def forward(
  158. self,
  159. position_ids: torch.Tensor,
  160. hidden_states: torch.Tensor,
  161. kv_cache: KVCache,
  162. input_metadata: InputMetadata,
  163. ) -> torch.Tensor:
  164. residual = hidden_states
  165. hidden_states = self.ln_1(hidden_states)
  166. attn_output = self.attn(
  167. position_ids=position_ids,
  168. hidden_states=hidden_states,
  169. kv_cache=kv_cache,
  170. input_metadata=input_metadata,
  171. )
  172. mlp_output = self.mlp(hidden_states)
  173. hidden_states = attn_output + mlp_output + residual
  174. return hidden_states
  175. class GPTJModel(nn.Module):
  176. def __init__(
  177. self,
  178. config: GPTJConfig,
  179. linear_method: Optional[LinearMethodBase] = None,
  180. ):
  181. super().__init__()
  182. self.config = config
  183. self.embed_dim = config.n_embd
  184. self.wte = VocabParallelEmbedding(
  185. config.vocab_size,
  186. self.embed_dim,
  187. linear_method=linear_method,
  188. )
  189. self.h = nn.ModuleList(
  190. [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
  191. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  192. def forward(
  193. self,
  194. input_ids: torch.Tensor,
  195. position_ids: torch.Tensor,
  196. kv_caches: List[KVCache],
  197. input_metadata: InputMetadata,
  198. ) -> torch.Tensor:
  199. hidden_states = self.wte(input_ids)
  200. for i in range(len(self.h)):
  201. layer = self.h[i]
  202. hidden_states = layer(
  203. position_ids,
  204. hidden_states,
  205. kv_caches[i],
  206. input_metadata,
  207. )
  208. hidden_states = self.ln_f(hidden_states)
  209. return hidden_states
  210. class GPTJForCausalLM(nn.Module):
  211. def __init__(
  212. self,
  213. config: GPTJConfig,
  214. linear_method: Optional[LinearMethodBase] = None,
  215. ):
  216. super().__init__()
  217. self.config = config
  218. self.linear_method = linear_method
  219. assert not config.tie_word_embeddings
  220. self.transformer = GPTJModel(config, linear_method)
  221. self.lm_head = ParallelLMHead(
  222. config.vocab_size,
  223. config.n_embd,
  224. bias=True,
  225. linear_method=linear_method,
  226. )
  227. self.sampler = Sampler(config.vocab_size)
  228. def forward(
  229. self,
  230. input_ids: torch.Tensor,
  231. positions: torch.Tensor,
  232. kv_caches: List[KVCache],
  233. input_metadata: InputMetadata,
  234. ) -> torch.Tensor:
  235. hidden_states = self.transformer(input_ids, positions, kv_caches,
  236. input_metadata)
  237. return hidden_states
  238. def sample(
  239. self,
  240. hidden_states: torch.Tensor,
  241. sampling_metadata: SamplingMetadata,
  242. ) -> Optional[SamplerOutput]:
  243. next_tokens = self.sampler(self.lm_head(hidden_states),
  244. sampling_metadata, self.lm_head.bias)
  245. return next_tokens
  246. def load_weights(self,
  247. model_name_or_path: str,
  248. cache_dir: Optional[str] = None,
  249. load_format: str = "auto",
  250. revision: Optional[str] = None):
  251. stacked_params_mapping = [
  252. # (param_name, shard_name, shard_id)
  253. ("qkv_proj", "q_proj", "q"),
  254. ("qkv_proj", "k_proj", "k"),
  255. ("qkv_proj", "v_proj", "v"),
  256. ("gate_up_proj", "gate_proj", 0),
  257. ("gate_up_proj", "up_proj", 1),
  258. ]
  259. if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
  260. ):
  261. stacked_params_mapping = []
  262. params_dict = dict(self.named_parameters())
  263. for name, loaded_weight in hf_model_weights_iterator(
  264. model_name_or_path, cache_dir, load_format, revision):
  265. if "attn.bias" in name or "attn.masked_bias" in name:
  266. continue
  267. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  268. if weight_name not in name:
  269. continue
  270. name = name.replace(weight_name, param_name)
  271. # Skip loading extra bias for GPTQ models.
  272. if name.endswith(".bias") and name not in params_dict:
  273. continue
  274. param = params_dict[name]
  275. weight_loader = param.weight_loader
  276. weight_loader(param, loaded_weight, shard_id)
  277. break
  278. else:
  279. # Skip loading extra bias for GPTQ models.
  280. if name.endswith(".bias") and name not in params_dict:
  281. continue
  282. param = params_dict[name]
  283. weight_loader = getattr(param, "weight_loader",
  284. default_weight_loader)
  285. weight_loader(param, loaded_weight)