gpt_j.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only GPT-J model compatible with HuggingFace weights.
  20. The input of the model is flattened to a 1D tensor of tokens. The model uses
  21. InputMetadata to extract the original 2D shape of the input.
  22. """
  23. from typing import List, Optional, Tuple
  24. import torch
  25. from torch import nn
  26. from transformers import GPTJConfig
  27. from aphrodite.modeling.metadata import InputMetadata
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
  30. from aphrodite.modeling.layers.sampler import Sampler
  31. from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
  32. load_tensor_parallel_weights)
  33. from aphrodite.modeling.megatron.parallel_state import (
  34. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  35. from aphrodite.modeling.megatron.layers import (VocabParallelEmbedding,
  36. ColumnParallelLinear,
  37. RowParallelLinear)
  38. from aphrodite.common.sequence import SamplerOutput
# Per-layer attention cache: a pair of tensors that GPTJAttention.forward
# unpacks as ``k_cache, v_cache = kv_cache``. The tensor layout is defined by
# PagedAttentionWithRoPE (not visible here).
KVCache = Tuple[torch.Tensor, torch.Tensor]
  40. class GPTJAttention(nn.Module):
  41. def __init__(self, config: GPTJConfig):
  42. super().__init__()
  43. self.total_num_heads = config.num_attention_heads
  44. self.hidden_size = config.hidden_size
  45. self.head_size = self.hidden_size // self.total_num_heads
  46. self.qkv_proj = ColumnParallelLinear(
  47. config.hidden_size,
  48. 3 * config.hidden_size,
  49. bias=False,
  50. gather_output=False,
  51. )
  52. self.out_proj = RowParallelLinear(
  53. config.hidden_size,
  54. config.hidden_size,
  55. bias=False,
  56. input_is_parallel=True,
  57. )
  58. tp_world_size = get_tensor_model_parallel_world_size()
  59. assert self.total_num_heads % tp_world_size == 0
  60. self.num_heads = self.total_num_heads // tp_world_size
  61. scaling = self.head_size**-0.5
  62. assert getattr(config, "rotary", True)
  63. assert config.rotary_dim % 2 == 0
  64. rope_theta = getattr(config, "rope_theta", 10000)
  65. max_position_embeddings = getattr(config, "max_position_embeddings",
  66. 8192)
  67. self.attn = PagedAttentionWithRoPE(
  68. self.num_heads,
  69. self.head_size,
  70. scaling,
  71. config.rotary_dim,
  72. base=rope_theta,
  73. max_position=max_position_embeddings,
  74. is_neox_style=False)
  75. self.warmup = False
  76. def forward(
  77. self,
  78. position_ids: torch.Tensor,
  79. hidden_states: torch.Tensor,
  80. kv_cache: KVCache,
  81. input_metadata: InputMetadata,
  82. cache_event: Optional[torch.cuda.Event],
  83. ) -> torch.Tensor:
  84. qkv, _ = self.qkv_proj(hidden_states)
  85. q, k, v = qkv.chunk(chunks=3, dim=-1)
  86. k_cache, v_cache = kv_cache
  87. attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
  88. input_metadata, cache_event)
  89. attn_output, _ = self.out_proj(attn_output)
  90. return attn_output
  91. class GPTJMLP(nn.Module):
  92. def __init__(self, intermediate_size: int, config: GPTJConfig):
  93. super().__init__()
  94. hidden_size = config.n_embd
  95. self.fc_in = ColumnParallelLinear(
  96. hidden_size,
  97. intermediate_size,
  98. gather_output=False,
  99. )
  100. self.fc_out = RowParallelLinear(
  101. intermediate_size,
  102. hidden_size,
  103. input_is_parallel=True,
  104. )
  105. self.act = get_act_fn(config.activation_function)
  106. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  107. hidden_states, _ = self.fc_in(hidden_states)
  108. hidden_states = self.act(hidden_states)
  109. hidden_states, _ = self.fc_out(hidden_states)
  110. return hidden_states
  111. class GPTJBlock(nn.Module):
  112. def __init__(self, config: GPTJConfig):
  113. super().__init__()
  114. if config.n_inner is None:
  115. inner_dim = 4 * config.n_embd
  116. else:
  117. inner_dim = config.n_inner
  118. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  119. self.attn = GPTJAttention(config)
  120. self.mlp = GPTJMLP(inner_dim, config)
  121. def forward(
  122. self,
  123. position_ids: torch.Tensor,
  124. hidden_states: torch.Tensor,
  125. kv_cache: KVCache,
  126. input_metadata: InputMetadata,
  127. cache_event: Optional[torch.cuda.Event],
  128. ) -> torch.Tensor:
  129. residual = hidden_states
  130. hidden_states = self.ln_1(hidden_states)
  131. attn_output = self.attn(
  132. position_ids=position_ids,
  133. hidden_states=hidden_states,
  134. kv_cache=kv_cache,
  135. input_metadata=input_metadata,
  136. cache_event=cache_event,
  137. )
  138. mlp_output = self.mlp(hidden_states)
  139. hidden_states = attn_output + mlp_output + residual
  140. return hidden_states
  141. class GPTJModel(nn.Module):
  142. def __init__(self, config: GPTJConfig):
  143. super().__init__()
  144. self.config = config
  145. self.embed_dim = config.n_embd
  146. self.wte = VocabParallelEmbedding(
  147. config.vocab_size,
  148. self.embed_dim,
  149. )
  150. self.h = nn.ModuleList(
  151. [GPTJBlock(config) for _ in range(config.n_layer)])
  152. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  153. def forward(
  154. self,
  155. input_ids: torch.Tensor,
  156. position_ids: torch.Tensor,
  157. kv_caches: List[KVCache],
  158. input_metadata: InputMetadata,
  159. cache_events: Optional[List[torch.cuda.Event]],
  160. ) -> torch.Tensor:
  161. hidden_states = self.wte(input_ids)
  162. for i in range(len(self.h)):
  163. if cache_events is None:
  164. cache_event = None
  165. else:
  166. cache_event = cache_events[i]
  167. layer = self.h[i]
  168. hidden_states = layer(
  169. position_ids,
  170. hidden_states,
  171. kv_caches[i],
  172. input_metadata,
  173. cache_event,
  174. )
  175. hidden_states = self.ln_f(hidden_states)
  176. return hidden_states
  177. class GPTJForCausalLM(nn.Module):
  178. def __init__(self, config: GPTJConfig):
  179. super().__init__()
  180. self.config = config
  181. assert not config.tie_word_embeddings
  182. self.transformer = GPTJModel(config)
  183. self.lm_head = ColumnParallelLinear(
  184. config.n_embd,
  185. config.vocab_size,
  186. gather_output=False,
  187. )
  188. self.sampler = Sampler(config.vocab_size)
  189. def forward(
  190. self,
  191. input_ids: torch.Tensor,
  192. positions: torch.Tensor,
  193. kv_caches: List[KVCache],
  194. input_metadata: InputMetadata,
  195. cache_events: Optional[List[torch.cuda.Event]],
  196. ) -> SamplerOutput:
  197. hidden_states = self.transformer(input_ids, positions, kv_caches,
  198. input_metadata, cache_events)
  199. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  200. input_metadata, self.lm_head.bias)
  201. return next_tokens
  202. _column_parallel_weights = [
  203. "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
  204. "lm_head.bias"
  205. ]
  206. _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
  207. def load_weights(self,
  208. model_name_or_path: str,
  209. cache_dir: Optional[str] = None,
  210. load_format: str = "auto",
  211. revision: Optional[str] = None):
  212. tp_rank = get_tensor_model_parallel_rank()
  213. state_dict = self.state_dict()
  214. for name, loaded_weight in hf_model_weights_iterator(
  215. model_name_or_path, cache_dir, load_format, revision):
  216. if "attn.bias" in name or "attn.masked_bias" in name:
  217. continue
  218. is_attention_weight = False
  219. for stride_id, att_weight_name in enumerate(
  220. ["q_proj", "k_proj", "v_proj"]):
  221. if att_weight_name not in name:
  222. continue
  223. # pylint: disable=unsubscriptable-object
  224. param = state_dict[name.replace(att_weight_name, "qkv_proj")]
  225. shard_size = param.shape[1]
  226. loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
  227. (tp_rank + 1)]
  228. param_slice = param.data[shard_size * stride_id:shard_size *
  229. (stride_id + 1)]
  230. assert param_slice.shape == loaded_weight.shape
  231. param_slice.copy_(loaded_weight)
  232. is_attention_weight = True
  233. break
  234. if is_attention_weight:
  235. continue
  236. # pylint: disable=unsubscriptable-object
  237. param = state_dict[name]
  238. load_tensor_parallel_weights(param, loaded_weight, name,
  239. self._column_parallel_weights,
  240. self._row_parallel_weights, tp_rank)