opt.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size, )
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
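
# Each KVCache is the (key_cache, value_cache) tensor pair for a single
# decoder layer; OPTDecoder.forward indexes the per-layer list and hands the
# pair to PagedAttention inside OPTAttention.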


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(embed_dim,
                                               embed_dim,
                                               bias=bias,
                                               linear_method=linear_method)
            self.k_proj = ColumnParallelLinear(embed_dim,
                                               embed_dim,
                                               bias=bias,
                                               linear_method=linear_method)
            self.v_proj = ColumnParallelLinear(embed_dim,
                                               embed_dim,
                                               bias=bias,
                                               linear_method=linear_method)
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                embed_dim,
                self.head_dim,
                total_num_heads,
                bias=bias,
                linear_method=linear_method,
            )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.chunk(chunks=3, dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        output, _ = self.out_proj(attn_output)
        return output
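
# A single OPTAttention.forward call: project hidden_states through the merged
# qkv_proj (or separate q/k/v projections when the quantization config cannot
# merge weights), split into q, k, v, then let PagedAttention update the paged
# key_cache/value_cache and attend using the layout in input_metadata; the
# result is combined across tensor-parallel ranks by the RowParallelLinear
# out_proj.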


class OPTDecoderLayer(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine,
        )
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config, config.ffn_dim)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            linear_method=linear_method,
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.3B, ..., 175B applies layer norm BEFORE the MLP
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the MLP
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
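
# OPTDecoderLayer is a pre-norm transformer block for most OPT sizes
# (LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual);
# the 350m checkpoint sets do_layer_norm_before=False and instead normalizes
# after each residual addition.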


class OPTDecoder(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            linear_method=linear_method,
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)
        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.project_out = None
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.project_in = None
        # Note that the only purpose of `config._remove_final_layer_norm` is
        # to keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1; see
        # https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm = None
        self.layers = nn.ModuleList([
            OPTDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds, _ = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
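
# project_in/project_out only exist when word_embed_proj_dim differs from
# hidden_size (as in OPT-350m): token embeddings are projected up to the
# transformer width on the way in and back down before the LM head.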


class OPTModel(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.decoder = OPTDecoder(config, linear_method)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata)


class OPTForCausalLM(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens
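
    # load_weights maps HuggingFace checkpoint names onto this module:
    # separate q_proj/k_proj/v_proj shards are packed into the merged qkv_proj
    # (skipped when the quantization config disables weight merging), and the
    # word embedding is also copied into lm_head because OPT ties the input
    # embedding to the output projection.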
    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                continue
            if "embed_tokens" in name:
                # Copy word embedding to lm_head
                if name.startswith("decoder."):
                    name = "model." + name
                head_name = name.replace("model.decoder.embed_tokens",
                                         "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            if name.startswith("decoder."):
                name = "model." + name
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
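
# A minimal construction sketch (assumes the tensor-parallel process group has
# already been initialized by the serving engine, since the attention layers
# query get_tensor_model_parallel_world_size at construction time):
#
#     config = OPTConfig.from_pretrained("facebook/opt-125m")
#     model = OPTForCausalLM(config)           # linear_method=None: unquantized
#     model.load_weights("facebook/opt-125m")  # streams HF weights via
#                                              # hf_model_weights_iterator
#
# forward() additionally needs per-layer KV caches and an InputMetadata built
# by the engine, so it is not shown here.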