opt.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
  6. # reserved.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only OPT model compatible with HuggingFace weights."""
  20. from typing import Iterable, List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import OPTConfig
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.config import CacheConfig
  26. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. QKVParallelLinear,
  31. ReplicatedLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  40. class OPTLearnedPositionalEmbedding(nn.Embedding):
  41. def __init__(self, num_embeddings: int, embedding_dim: int):
  42. # OPT is set up so that if padding_idx is specified then offset the
  43. # embedding ids by 2 and adjust num_embeddings appropriately. Other
  44. # models don't have this hack
  45. self.offset = 2
  46. super().__init__(num_embeddings + self.offset, embedding_dim)
  47. def forward(self, positions: torch.Tensor):
  48. return super().forward(positions + self.offset)
  49. class OPTAttention(nn.Module):
  50. def __init__(
  51. self,
  52. embed_dim: int,
  53. num_heads: int,
  54. bias: bool = True,
  55. cache_config: Optional[CacheConfig] = None,
  56. quant_config: Optional[QuantizationConfig] = None,
  57. ) -> None:
  58. super().__init__()
  59. self.embed_dim = embed_dim
  60. tensor_model_parallel_world_size = (
  61. get_tensor_model_parallel_world_size())
  62. total_num_heads = num_heads
  63. assert num_heads % tensor_model_parallel_world_size == 0
  64. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  65. self.head_dim = embed_dim // total_num_heads
  66. self.scaling = self.head_dim**-0.5
  67. self.qkv_proj = QKVParallelLinear(
  68. embed_dim,
  69. self.head_dim,
  70. total_num_heads,
  71. bias=bias,
  72. quant_config=quant_config,
  73. )
  74. self.out_proj = RowParallelLinear(
  75. embed_dim,
  76. embed_dim,
  77. bias=bias,
  78. quant_config=quant_config,
  79. )
  80. self.attn = Attention(self.num_heads,
  81. self.head_dim,
  82. scale=self.scaling,
  83. cache_config=cache_config,
  84. quant_config=quant_config)
  85. def forward(
  86. self,
  87. hidden_states: torch.Tensor,
  88. kv_cache: torch.Tensor,
  89. attn_metadata: AttentionMetadata,
  90. ) -> torch.Tensor:
  91. qkv, _ = self.qkv_proj(hidden_states)
  92. q, k, v = qkv.chunk(chunks=3, dim=-1)
  93. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  94. output, _ = self.out_proj(attn_output)
  95. return output
  96. class OPTDecoderLayer(nn.Module):
  97. def __init__(
  98. self,
  99. config: OPTConfig,
  100. cache_config: Optional[CacheConfig] = None,
  101. quant_config: Optional[QuantizationConfig] = None,
  102. ):
  103. super().__init__()
  104. self.config = config
  105. self.embed_dim = config.hidden_size
  106. self.self_attn = OPTAttention(
  107. embed_dim=self.embed_dim,
  108. num_heads=config.num_attention_heads,
  109. bias=config.enable_bias,
  110. cache_config=cache_config,
  111. quant_config=quant_config,
  112. )
  113. self.do_layer_norm_before = config.do_layer_norm_before
  114. self.self_attn_layer_norm = nn.LayerNorm(
  115. self.embed_dim,
  116. elementwise_affine=config.layer_norm_elementwise_affine)
  117. self.fc1 = ColumnParallelLinear(
  118. self.embed_dim,
  119. config.ffn_dim,
  120. bias=config.enable_bias,
  121. quant_config=quant_config,
  122. )
  123. self.activation_fn = get_act_fn(config.activation_function,
  124. quant_config, config.ffn_dim)
  125. self.fc2 = RowParallelLinear(
  126. config.ffn_dim,
  127. self.embed_dim,
  128. bias=config.enable_bias,
  129. quant_config=quant_config,
  130. )
  131. self.final_layer_norm = nn.LayerNorm(
  132. self.embed_dim,
  133. elementwise_affine=config.layer_norm_elementwise_affine)
  134. def forward(
  135. self,
  136. hidden_states: torch.Tensor,
  137. kv_cache: torch.Tensor,
  138. attn_metadata: AttentionMetadata,
  139. ) -> torch.Tensor:
  140. # Self Attention
  141. residual = hidden_states
  142. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  143. if self.do_layer_norm_before:
  144. hidden_states = self.self_attn_layer_norm(hidden_states)
  145. hidden_states = self.self_attn(hidden_states=hidden_states,
  146. kv_cache=kv_cache,
  147. attn_metadata=attn_metadata)
  148. hidden_states = residual + hidden_states
  149. # 350m applies layer norm AFTER attention
  150. if not self.do_layer_norm_before:
  151. hidden_states = self.self_attn_layer_norm(hidden_states)
  152. # Fully Connected
  153. residual = hidden_states
  154. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  155. if self.do_layer_norm_before:
  156. hidden_states = self.final_layer_norm(hidden_states)
  157. hidden_states, _ = self.fc1(hidden_states)
  158. hidden_states = self.activation_fn(hidden_states)
  159. hidden_states, _ = self.fc2(hidden_states)
  160. hidden_states = residual + hidden_states
  161. # 350m applies layer norm AFTER attention
  162. if not self.do_layer_norm_before:
  163. hidden_states = self.final_layer_norm(hidden_states)
  164. return hidden_states
  165. class OPTDecoder(nn.Module):
  166. def __init__(
  167. self,
  168. config: OPTConfig,
  169. cache_config: Optional[CacheConfig] = None,
  170. quant_config: Optional[QuantizationConfig] = None,
  171. ):
  172. super().__init__()
  173. self.config = config
  174. self.padding_idx = config.pad_token_id
  175. self.max_target_positions = config.max_position_embeddings
  176. self.vocab_size = config.vocab_size
  177. self.embed_tokens = VocabParallelEmbedding(
  178. config.vocab_size,
  179. config.word_embed_proj_dim,
  180. )
  181. # Positional embeddings are replicated (not sharded).
  182. self.embed_positions = OPTLearnedPositionalEmbedding(
  183. config.max_position_embeddings, config.hidden_size)
  184. # Project out & in will be replicated if they exist.
  185. if config.word_embed_proj_dim != config.hidden_size:
  186. self.project_out = ReplicatedLinear(config.hidden_size,
  187. config.word_embed_proj_dim,
  188. bias=False,
  189. quant_config=quant_config)
  190. else:
  191. self.project_out = None
  192. if config.word_embed_proj_dim != config.hidden_size:
  193. self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
  194. config.hidden_size,
  195. bias=False,
  196. quant_config=quant_config)
  197. else:
  198. self.project_in = None
  199. # Note that the only purpose of `config._remove_final_layer_norm` is to
  200. # keep backward compatibility with checkpoints that have been fine-tuned
  201. # before transformers v4.20.1
  202. # see https://github.com/facebookresearch/metaseq/pull/164
  203. if config.do_layer_norm_before and not config._remove_final_layer_norm:
  204. self.final_layer_norm = nn.LayerNorm(
  205. config.hidden_size,
  206. elementwise_affine=config.layer_norm_elementwise_affine)
  207. else:
  208. self.final_layer_norm = None
  209. self.layers = nn.ModuleList([
  210. OPTDecoderLayer(config, cache_config, quant_config)
  211. for _ in range(config.num_hidden_layers)
  212. ])
  213. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  214. return self.embed_tokens(input_ids)
  215. def forward(
  216. self,
  217. input_ids: torch.Tensor,
  218. positions: torch.Tensor,
  219. kv_caches: List[torch.Tensor],
  220. attn_metadata: AttentionMetadata,
  221. inputs_embeds: Optional[torch.Tensor] = None,
  222. ) -> torch.Tensor:
  223. if inputs_embeds is None:
  224. inputs_embeds = self.get_input_embeddings(input_ids)
  225. pos_embeds = self.embed_positions(positions)
  226. if self.project_in is not None:
  227. inputs_embeds, _ = self.project_in(inputs_embeds)
  228. hidden_states = inputs_embeds + pos_embeds
  229. for i in range(len(self.layers)):
  230. layer = self.layers[i]
  231. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  232. if self.final_layer_norm is not None:
  233. hidden_states = self.final_layer_norm(hidden_states)
  234. if self.project_out is not None:
  235. hidden_states, _ = self.project_out(hidden_states)
  236. return hidden_states
  237. class OPTModel(nn.Module):
  238. def __init__(
  239. self,
  240. config: OPTConfig,
  241. cache_config: Optional[CacheConfig] = None,
  242. quant_config: Optional[QuantizationConfig] = None,
  243. ):
  244. super().__init__()
  245. self.decoder = OPTDecoder(config, cache_config, quant_config)
  246. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  247. return self.decoder.get_input_embeddings(input_ids)
  248. def forward(
  249. self,
  250. input_ids: torch.Tensor,
  251. positions: torch.Tensor,
  252. kv_caches: List[torch.Tensor],
  253. attn_metadata: AttentionMetadata,
  254. inputs_embeds: Optional[torch.Tensor] = None,
  255. ) -> torch.Tensor:
  256. return self.decoder(input_ids,
  257. positions,
  258. kv_caches,
  259. attn_metadata,
  260. inputs_embeds=inputs_embeds)
  261. class OPTForCausalLM(nn.Module):
  262. def __init__(
  263. self,
  264. config,
  265. cache_config: Optional[CacheConfig] = None,
  266. quant_config: Optional[QuantizationConfig] = None,
  267. ):
  268. super().__init__()
  269. self.config = config
  270. self.quant_config = quant_config
  271. self.model = OPTModel(config, cache_config, quant_config)
  272. self.lm_head = self.model.decoder.embed_tokens
  273. self.logits_processor = LogitsProcessor(config.vocab_size)
  274. self.sampler = Sampler()
  275. def forward(
  276. self,
  277. input_ids: torch.Tensor,
  278. positions: torch.Tensor,
  279. kv_caches: List[torch.Tensor],
  280. attn_metadata: AttentionMetadata,
  281. intermediate_tensors: Optional[IntermediateTensors] = None,
  282. ) -> torch.Tensor:
  283. hidden_states = self.model(input_ids, positions, kv_caches,
  284. attn_metadata)
  285. return hidden_states
  286. def compute_logits(
  287. self,
  288. hidden_states: torch.Tensor,
  289. sampling_metadata: SamplingMetadata,
  290. ) -> Optional[torch.Tensor]:
  291. logits = self.logits_processor(self.lm_head, hidden_states,
  292. sampling_metadata)
  293. return logits
  294. def sample(
  295. self,
  296. logits: torch.Tensor,
  297. sampling_metadata: SamplingMetadata,
  298. ) -> Optional[SamplerOutput]:
  299. next_tokens = self.sampler(logits, sampling_metadata)
  300. return next_tokens
  301. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  302. stacked_params_mapping = [
  303. # (param_name, shard_name, shard_id)
  304. ("qkv_proj", "q_proj", "q"),
  305. ("qkv_proj", "k_proj", "k"),
  306. ("qkv_proj", "v_proj", "v"),
  307. ]
  308. params_dict = dict(self.named_parameters(remove_duplicate=False))
  309. for name, loaded_weight in weights:
  310. if "lm_head.weight" in name:
  311. continue
  312. if name.startswith("decoder."):
  313. name = "model." + name
  314. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  315. if weight_name not in name:
  316. continue
  317. name = name.replace(weight_name, param_name)
  318. # Skip loading extra bias for GPTQ models.
  319. if name.endswith(".bias") and name not in params_dict:
  320. continue
  321. param = params_dict[name]
  322. weight_loader = param.weight_loader
  323. weight_loader(param, loaded_weight, shard_id)
  324. break
  325. else:
  326. # Skip loading extra bias for GPTQ models.
  327. if name.endswith(".bias") and name not in params_dict:
  328. continue
  329. param = params_dict[name]
  330. weight_loader = getattr(param, "weight_loader",
  331. default_weight_loader)
  332. weight_loader(param, loaded_weight)