# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""

from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
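

# GPT-J self-attention with a fused QKV projection. The QKV and output
# projections are sharded across tensor-parallel ranks (column- and
# row-parallel respectively), and rotary position embeddings are applied in
# the original GPT-J interleaved style (is_neox_style=False).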
class GPTJAttention(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = ColumnParallelLinear(config.hidden_size,
                                             3 * config.hidden_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(config.hidden_size,
                                          config.hidden_size,
                                          bias=False,
                                          input_is_parallel=True,
                                          perform_initialization=False)

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_size,
                                           scaling,
                                           config.rotary_dim,
                                           is_neox_style=False)
        self.warmup = False

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output
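

# Position-wise feed-forward network. fc_in is column-parallel and fc_out is
# row-parallel, so the intermediate activations stay sharded and only fc_out
# reduces across tensor-parallel ranks.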
class GPTJMLP(nn.Module):

    def __init__(self, intermediate_size: int, config: GPTJConfig):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(hidden_size,
                                          intermediate_size,
                                          gather_output=False,
                                          perform_initialization=False)
        self.fc_out = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states
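

# A single transformer layer. Unlike GPT-2, GPT-J runs the attention and MLP
# branches in parallel on the same LayerNorm output and adds both to the
# residual, which is why the block has a single LayerNorm (ln_1).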
class GPTJBlock(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states
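

# The transformer backbone: token embeddings, a stack of GPTJBlocks (each with
# its own KV cache and optional cache event), and a final LayerNorm.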
class GPTJModel(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          perform_initialization=False)
        self.h = nn.ModuleList(
            [GPTJBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
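

# Causal LM head on top of GPTJModel. GPT-J does not tie the output projection
# to the input embeddings, so lm_head is a separate column-parallel linear
# whose weight and bias are handed to the sampler for logit computation.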
class GPTJForCausalLM(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config)
        self.lm_head = ColumnParallelLinear(config.n_embd,
                                            config.vocab_size,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata, self.lm_head.bias)
        return next_tokens
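
    # Weight names matched here are partitioned when loading checkpoints:
    # column-parallel weights are split along the output dimension and
    # row-parallel weights along the input dimension.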
    _column_parallel_weights = [
        "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
        "lm_head.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
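
    # HF checkpoints store separate q_proj/k_proj/v_proj weights, whereas this
    # model uses a single fused qkv_proj. load_weights copies each projection's
    # tensor-parallel shard into the matching stride of qkv_proj and skips the
    # precomputed causal-mask buffers (attn.bias / attn.masked_bias).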
    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                # Each of Q, K and V occupies one third of this rank's
                # qkv_proj rows, so the shard size comes from the sharded
                # output dimension.
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)