opt.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
  7. # reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only OPT model compatible with HuggingFace weights."""
  21. from typing import Iterable, List, Optional, Tuple
  22. import torch
  23. from torch import nn
  24. from transformers import OPTConfig
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. LinearMethodBase,
  31. QKVParallelLinear,
  32. ReplicatedLinear,
  33. RowParallelLinear)
  34. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  35. from aphrodite.modeling.layers.sampler import Sampler
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  37. VocabParallelEmbedding
  38. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. class OPTLearnedPositionalEmbedding(nn.Embedding):
  41. def __init__(self, num_embeddings: int, embedding_dim: int):
  42. # OPT is set up so that if padding_idx is specified then offset the
  43. # embedding ids by 2 and adjust num_embeddings appropriately. Other
  44. # models don't have this hack
  45. self.offset = 2
  46. super().__init__(num_embeddings + self.offset, embedding_dim)
  47. def forward(self, positions: torch.Tensor):
  48. return super().forward(positions + self.offset)
  49. class OPTAttention(nn.Module):
  50. def __init__(
  51. self,
  52. embed_dim: int,
  53. num_heads: int,
  54. bias: bool = True,
  55. linear_method: Optional[LinearMethodBase] = None,
  56. ) -> None:
  57. super().__init__()
  58. self.embed_dim = embed_dim
  59. tensor_model_parallel_world_size = (
  60. get_tensor_model_parallel_world_size())
  61. total_num_heads = num_heads
  62. assert num_heads % tensor_model_parallel_world_size == 0
  63. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  64. self.head_dim = embed_dim // total_num_heads
  65. self.scaling = self.head_dim**-0.5
  66. self.qkv_proj = QKVParallelLinear(
  67. embed_dim,
  68. self.head_dim,
  69. total_num_heads,
  70. bias=bias,
  71. linear_method=linear_method,
  72. )
  73. self.out_proj = RowParallelLinear(
  74. embed_dim,
  75. embed_dim,
  76. bias=bias,
  77. linear_method=linear_method,
  78. )
  79. self.attn = Attention(self.num_heads,
  80. self.head_dim,
  81. scale=self.scaling)
  82. def forward(
  83. self,
  84. hidden_states: torch.Tensor,
  85. kv_cache: torch.Tensor,
  86. attn_metadata: AttentionMetadata,
  87. ) -> torch.Tensor:
  88. qkv, _ = self.qkv_proj(hidden_states)
  89. q, k, v = qkv.chunk(chunks=3, dim=-1)
  90. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  91. output, _ = self.out_proj(attn_output)
  92. return output
  93. class OPTDecoderLayer(nn.Module):
  94. def __init__(
  95. self,
  96. config: OPTConfig,
  97. linear_method: Optional[LinearMethodBase] = None,
  98. ):
  99. super().__init__()
  100. self.config = config
  101. self.embed_dim = config.hidden_size
  102. self.self_attn = OPTAttention(
  103. embed_dim=self.embed_dim,
  104. num_heads=config.num_attention_heads,
  105. bias=config.enable_bias,
  106. linear_method=linear_method,
  107. )
  108. self.do_layer_norm_before = config.do_layer_norm_before
  109. self.self_attn_layer_norm = nn.LayerNorm(
  110. self.embed_dim,
  111. elementwise_affine=config.layer_norm_elementwise_affine)
  112. self.fc1 = ColumnParallelLinear(
  113. self.embed_dim,
  114. config.ffn_dim,
  115. bias=config.enable_bias,
  116. linear_method=linear_method,
  117. )
  118. quant_config = getattr(linear_method, "quant_config", None)
  119. self.activation_fn = get_act_fn(config.activation_function,
  120. quant_config, config.ffn_dim)
  121. self.fc2 = RowParallelLinear(
  122. config.ffn_dim,
  123. self.embed_dim,
  124. bias=config.enable_bias,
  125. linear_method=linear_method,
  126. )
  127. self.final_layer_norm = nn.LayerNorm(
  128. self.embed_dim,
  129. elementwise_affine=config.layer_norm_elementwise_affine)
  130. def forward(
  131. self,
  132. hidden_states: torch.Tensor,
  133. kv_cache: torch.Tensor,
  134. attn_metadata: AttentionMetadata,
  135. ) -> torch.Tensor:
  136. # Self Attention
  137. residual = hidden_states
  138. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  139. if self.do_layer_norm_before:
  140. hidden_states = self.self_attn_layer_norm(hidden_states)
  141. hidden_states = self.self_attn(hidden_states=hidden_states,
  142. kv_cache=kv_cache,
  143. attn_metadata=attn_metadata)
  144. hidden_states = residual + hidden_states
  145. # 350m applies layer norm AFTER attention
  146. if not self.do_layer_norm_before:
  147. hidden_states = self.self_attn_layer_norm(hidden_states)
  148. # Fully Connected
  149. residual = hidden_states
  150. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  151. if self.do_layer_norm_before:
  152. hidden_states = self.final_layer_norm(hidden_states)
  153. hidden_states, _ = self.fc1(hidden_states)
  154. hidden_states = self.activation_fn(hidden_states)
  155. hidden_states, _ = self.fc2(hidden_states)
  156. hidden_states = residual + hidden_states
  157. # 350m applies layer norm AFTER attention
  158. if not self.do_layer_norm_before:
  159. hidden_states = self.final_layer_norm(hidden_states)
  160. return hidden_states
  161. class OPTDecoder(nn.Module):
  162. def __init__(
  163. self,
  164. config: OPTConfig,
  165. linear_method: Optional[LinearMethodBase] = None,
  166. ):
  167. super().__init__()
  168. self.config = config
  169. self.padding_idx = config.pad_token_id
  170. self.max_target_positions = config.max_position_embeddings
  171. self.vocab_size = config.vocab_size
  172. self.embed_tokens = VocabParallelEmbedding(
  173. config.vocab_size,
  174. config.word_embed_proj_dim,
  175. )
  176. # Positional embeddings are replicated (not sharded).
  177. self.embed_positions = OPTLearnedPositionalEmbedding(
  178. config.max_position_embeddings, config.hidden_size)
  179. # Project out & in will be replicated if they exist.
  180. if config.word_embed_proj_dim != config.hidden_size:
  181. self.project_out = ReplicatedLinear(config.hidden_size,
  182. config.word_embed_proj_dim,
  183. bias=False,
  184. linear_method=linear_method)
  185. else:
  186. self.project_out = None
  187. if config.word_embed_proj_dim != config.hidden_size:
  188. self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
  189. config.hidden_size,
  190. bias=False,
  191. linear_method=linear_method)
  192. else:
  193. self.project_in = None
  194. # Note that the only purpose of `config._remove_final_layer_norm` is to
  195. # keep backward compatibility with checkpoints that have been fine-tuned
  196. # before transformers v4.20.1
  197. # see https://github.com/facebookresearch/metaseq/pull/164
  198. if config.do_layer_norm_before and not config._remove_final_layer_norm:
  199. self.final_layer_norm = nn.LayerNorm(
  200. config.hidden_size,
  201. elementwise_affine=config.layer_norm_elementwise_affine)
  202. else:
  203. self.final_layer_norm = None
  204. self.layers = nn.ModuleList([
  205. OPTDecoderLayer(config, linear_method)
  206. for _ in range(config.num_hidden_layers)
  207. ])
  208. def forward(
  209. self,
  210. input_ids: torch.Tensor,
  211. positions: torch.Tensor,
  212. kv_caches: List[torch.Tensor],
  213. attn_metadata: AttentionMetadata,
  214. ) -> torch.Tensor:
  215. inputs_embeds = self.embed_tokens(input_ids)
  216. pos_embeds = self.embed_positions(positions)
  217. if self.project_in is not None:
  218. inputs_embeds, _ = self.project_in(inputs_embeds)
  219. hidden_states = inputs_embeds + pos_embeds
  220. for i in range(len(self.layers)):
  221. layer = self.layers[i]
  222. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  223. if self.final_layer_norm is not None:
  224. hidden_states = self.final_layer_norm(hidden_states)
  225. if self.project_out is not None:
  226. hidden_states, _ = self.project_out(hidden_states)
  227. return hidden_states
  228. class OPTModel(nn.Module):
  229. def __init__(
  230. self,
  231. config: OPTConfig,
  232. linear_method: Optional[LinearMethodBase] = None,
  233. ):
  234. super().__init__()
  235. self.decoder = OPTDecoder(config, linear_method)
  236. def forward(
  237. self,
  238. input_ids: torch.Tensor,
  239. positions: torch.Tensor,
  240. kv_caches: List[torch.Tensor],
  241. attn_metadata: AttentionMetadata,
  242. ) -> torch.Tensor:
  243. return self.decoder(input_ids, positions, kv_caches, attn_metadata)
  244. class OPTForCausalLM(nn.Module):
  245. def __init__(
  246. self,
  247. config,
  248. linear_method: Optional[LinearMethodBase] = None,
  249. ):
  250. super().__init__()
  251. self.config = config
  252. self.linear_method = linear_method
  253. self.model = OPTModel(config, linear_method)
  254. self.lm_head_weight = self.model.decoder.embed_tokens.weight
  255. self.logits_processor = LogitsProcessor(config.vocab_size)
  256. self.sampler = Sampler()
  257. def forward(
  258. self,
  259. input_ids: torch.Tensor,
  260. positions: torch.Tensor,
  261. kv_caches: List[torch.Tensor],
  262. attn_metadata: AttentionMetadata,
  263. ) -> torch.Tensor:
  264. hidden_states = self.model(input_ids, positions, kv_caches,
  265. attn_metadata)
  266. return hidden_states
  267. def compute_logits(self, hidden_states: torch.Tensor,
  268. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  269. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  270. sampling_metadata)
  271. return logits
  272. def sample(
  273. self,
  274. logits: torch.Tensor,
  275. sampling_metadata: SamplingMetadata,
  276. ) -> Optional[SamplerOutput]:
  277. next_tokens = self.sampler(logits, sampling_metadata)
  278. return next_tokens
  279. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  280. stacked_params_mapping = [
  281. # (param_name, shard_name, shard_id)
  282. ("qkv_proj", "q_proj", "q"),
  283. ("qkv_proj", "k_proj", "k"),
  284. ("qkv_proj", "v_proj", "v"),
  285. ]
  286. params_dict = dict(self.named_parameters(remove_duplicate=False))
  287. for name, loaded_weight in weights:
  288. if "lm_head.weight" in name:
  289. continue
  290. if name.startswith("decoder."):
  291. name = "model." + name
  292. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  293. if weight_name not in name:
  294. continue
  295. name = name.replace(weight_name, param_name)
  296. # Skip loading extra bias for GPTQ models.
  297. if name.endswith(".bias") and name not in params_dict:
  298. continue
  299. param = params_dict[name]
  300. weight_loader = param.weight_loader
  301. weight_loader(param, loaded_weight, shard_id)
  302. break
  303. else:
  304. # Skip loading extra bias for GPTQ models.
  305. if name.endswith(".bias") and name not in params_dict:
  306. continue
  307. param = params_dict[name]
  308. weight_loader = getattr(param, "weight_loader",
  309. default_weight_loader)
  310. weight_loader(param, loaded_weight)