# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned positional embedding with OPT's fixed offset of 2."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5
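        # Attention heads are split evenly across tensor-parallel ranks, so
        # each rank only materializes its slice of the projections below.
        # Illustrative numbers (assumed from the OPT-1.3B config): with
        # embed_dim=2048 and 32 heads, head_dim is 64, and a TP world size
        # of 2 leaves 16 heads per rank.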
        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
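        # qkv holds this rank's query, key, and value projections concatenated
        # along the last dimension; because OPT uses the same number of query
        # and key/value heads, three equal chunks recover q, k, and v.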
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
        )
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the FFN
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the FFN
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(config.hidden_size,
                                                config.word_embed_proj_dim,
                                                bias=False,
                                                quant_config=quant_config)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                               config.hidden_size,
                                               bias=False,
                                               quant_config=quant_config)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1; see
        # https://github.com/facebookresearch/metaseq/pull/164
        if (config.do_layer_norm_before
                and not config._remove_final_layer_norm):
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, cache_config, quant_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.decoder.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.decoder(input_ids,
                            positions,
                            kv_caches,
                            attn_metadata,
                            inputs_embeds=inputs_embeds)


class OPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(config, cache_config, quant_config)
        self.lm_head = self.model.decoder.embed_tokens
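        # OPT ties the output projection to the input embedding table, so the
        # LM head simply reuses embed_tokens (this is also why load_weights
        # skips checkpoint entries named "lm_head.weight").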
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
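        # Checkpoints store separate q_proj/k_proj/v_proj tensors, while this
        # model fuses them into qkv_proj. For example, the checkpoint entry
        # "model.decoder.layers.0.self_attn.k_proj.weight" is renamed to
        # "...qkv_proj.weight" and copied into the "k" shard of the fused
        # parameter by its weight_loader.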
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "lm_head.weight" in name:
                continue
            if name.startswith("decoder."):
                name = "model." + name
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)