opt.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
  7. # reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only OPT model compatible with HuggingFace weights."""
  21. from typing import List, Optional
  22. import torch
  23. from torch import nn
  24. from transformers import OPTConfig
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.modeling.layers.activation import get_act_fn
  27. from aphrodite.modeling.layers.linear import (
  28. ColumnParallelLinear,
  29. LinearMethodBase,
  30. QKVParallelLinear,
  31. ReplicatedLinear,
  32. RowParallelLinear,
  33. )
  34. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  35. from aphrodite.modeling.layers.sampler import Sampler
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  37. VocabParallelEmbedding,
  38. ParallelLMHead,
  39. )
  40. from aphrodite.distributed import (
  41. get_tensor_model_parallel_world_size, )
  42. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  43. from aphrodite.modeling.hf_downloader import (
  44. default_weight_loader,
  45. hf_model_weights_iterator,
  46. )
  47. from aphrodite.common.sequence import SamplerOutput
  48. class OPTLearnedPositionalEmbedding(nn.Embedding):
  49. def __init__(self, num_embeddings: int, embedding_dim: int):
  50. # OPT is set up so that if padding_idx is specified then offset the
  51. # embedding ids by 2 and adjust num_embeddings appropriately. Other
  52. # models don't have this hack
  53. self.offset = 2
  54. super().__init__(num_embeddings + self.offset, embedding_dim)
  55. def forward(self, positions: torch.Tensor):
  56. return super().forward(positions + self.offset)
  57. class OPTAttention(nn.Module):
  58. def __init__(
  59. self,
  60. embed_dim: int,
  61. num_heads: int,
  62. bias: bool = True,
  63. linear_method: Optional[LinearMethodBase] = None,
  64. ) -> None:
  65. super().__init__()
  66. self.embed_dim = embed_dim
  67. tensor_model_parallel_world_size = (
  68. get_tensor_model_parallel_world_size())
  69. total_num_heads = num_heads
  70. assert num_heads % tensor_model_parallel_world_size == 0
  71. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  72. self.head_dim = embed_dim // total_num_heads
  73. self.scaling = self.head_dim**-0.5
  74. if (linear_method is not None
  75. and not linear_method.quant_config.merge_weight()):
  76. self.merge_weight = False
  77. self.q_proj = ColumnParallelLinear(embed_dim,
  78. embed_dim,
  79. bias=bias,
  80. linear_method=linear_method)
  81. self.k_proj = ColumnParallelLinear(embed_dim,
  82. embed_dim,
  83. bias=bias,
  84. linear_method=linear_method)
  85. self.v_proj = ColumnParallelLinear(embed_dim,
  86. embed_dim,
  87. bias=bias,
  88. linear_method=linear_method)
  89. else:
  90. self.merge_weight = True
  91. self.qkv_proj = QKVParallelLinear(
  92. embed_dim,
  93. self.head_dim,
  94. total_num_heads,
  95. bias=bias,
  96. linear_method=linear_method,
  97. )
  98. self.out_proj = RowParallelLinear(
  99. embed_dim,
  100. embed_dim,
  101. bias=bias,
  102. linear_method=linear_method,
  103. )
  104. self.attn = Attention(self.num_heads,
  105. self.head_dim,
  106. scale=self.scaling)
  107. def forward(
  108. self,
  109. hidden_states: torch.Tensor,
  110. kv_cache: torch.Tensor,
  111. attn_metadata: AttentionMetadata,
  112. ) -> torch.Tensor:
  113. if self.merge_weight:
  114. qkv, _ = self.qkv_proj(hidden_states)
  115. q, k, v = qkv.chunk(chunks=3, dim=-1)
  116. else:
  117. q, _ = self.q_proj(hidden_states)
  118. k, _ = self.k_proj(hidden_states)
  119. v, _ = self.v_proj(hidden_states)
  120. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  121. output, _ = self.out_proj(attn_output)
  122. return output
  123. class OPTDecoderLayer(nn.Module):
  124. def __init__(
  125. self,
  126. config: OPTConfig,
  127. linear_method: Optional[LinearMethodBase] = None,
  128. ):
  129. super().__init__()
  130. self.config = config
  131. self.embed_dim = config.hidden_size
  132. self.self_attn = OPTAttention(
  133. embed_dim=self.embed_dim,
  134. num_heads=config.num_attention_heads,
  135. bias=config.enable_bias,
  136. linear_method=linear_method,
  137. )
  138. self.do_layer_norm_before = config.do_layer_norm_before
  139. self.self_attn_layer_norm = nn.LayerNorm(
  140. self.embed_dim,
  141. elementwise_affine=config.layer_norm_elementwise_affine,
  142. )
  143. self.fc1 = ColumnParallelLinear(
  144. self.embed_dim,
  145. config.ffn_dim,
  146. bias=config.enable_bias,
  147. linear_method=linear_method,
  148. )
  149. quant_config = getattr(linear_method, "quant_config", None)
  150. self.activation_fn = get_act_fn(config.activation_function,
  151. quant_config, config.ffn_dim)
  152. self.fc2 = RowParallelLinear(
  153. config.ffn_dim,
  154. self.embed_dim,
  155. bias=config.enable_bias,
  156. linear_method=linear_method,
  157. )
  158. self.final_layer_norm = nn.LayerNorm(
  159. self.embed_dim,
  160. elementwise_affine=config.layer_norm_elementwise_affine,
  161. )
  162. def forward(
  163. self,
  164. hidden_states: torch.Tensor,
  165. kv_cache: torch.Tensor,
  166. attn_metadata: AttentionMetadata,
  167. ) -> torch.Tensor:
  168. # Self Attention
  169. residual = hidden_states
  170. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  171. if self.do_layer_norm_before:
  172. hidden_states = self.self_attn_layer_norm(hidden_states)
  173. hidden_states = self.self_attn(
  174. hidden_states=hidden_states,
  175. kv_cache=kv_cache,
  176. attn_metadata=attn_metadata,
  177. )
  178. hidden_states = residual + hidden_states
  179. # 350m applies layer norm AFTER attention
  180. if not self.do_layer_norm_before:
  181. hidden_states = self.self_attn_layer_norm(hidden_states)
  182. # Fully Connected
  183. residual = hidden_states
  184. # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
  185. if self.do_layer_norm_before:
  186. hidden_states = self.final_layer_norm(hidden_states)
  187. hidden_states, _ = self.fc1(hidden_states)
  188. hidden_states = self.activation_fn(hidden_states)
  189. hidden_states, _ = self.fc2(hidden_states)
  190. hidden_states = residual + hidden_states
  191. # 350m applies layer norm AFTER attention
  192. if not self.do_layer_norm_before:
  193. hidden_states = self.final_layer_norm(hidden_states)
  194. return hidden_states
  195. class OPTDecoder(nn.Module):
  196. def __init__(
  197. self,
  198. config: OPTConfig,
  199. linear_method: Optional[LinearMethodBase] = None,
  200. ):
  201. super().__init__()
  202. self.config = config
  203. self.padding_idx = config.pad_token_id
  204. self.max_target_positions = config.max_position_embeddings
  205. self.vocab_size = config.vocab_size
  206. self.embed_tokens = VocabParallelEmbedding(
  207. config.vocab_size,
  208. config.word_embed_proj_dim,
  209. linear_method=linear_method,
  210. )
  211. # Positional embeddings are replicated (not sharded).
  212. self.embed_positions = OPTLearnedPositionalEmbedding(
  213. config.max_position_embeddings, config.hidden_size)
  214. # Project out & in will be replicated if they exist.
  215. if config.word_embed_proj_dim != config.hidden_size:
  216. self.project_out = ReplicatedLinear(
  217. config.hidden_size,
  218. config.word_embed_proj_dim,
  219. bias=False,
  220. linear_method=linear_method,
  221. )
  222. else:
  223. self.project_out = None
  224. if config.word_embed_proj_dim != config.hidden_size:
  225. self.project_in = ReplicatedLinear(
  226. config.word_embed_proj_dim,
  227. config.hidden_size,
  228. bias=False,
  229. linear_method=linear_method,
  230. )
  231. else:
  232. self.project_in = None
  233. # Note that the only purpose of `config._remove_final_layer_norm` is to
  234. # keep backward compatibility with checkpoints that have been fine-tuned
  235. # before transformers v4.20.1
  236. # see https://github.com/facebookresearch/metaseq/pull/164
  237. if config.do_layer_norm_before and not config._remove_final_layer_norm:
  238. self.final_layer_norm = nn.LayerNorm(
  239. config.hidden_size,
  240. elementwise_affine=config.layer_norm_elementwise_affine,
  241. )
  242. else:
  243. self.final_layer_norm = None
  244. self.layers = nn.ModuleList([
  245. OPTDecoderLayer(config, linear_method)
  246. for _ in range(config.num_hidden_layers)
  247. ])
  248. def forward(
  249. self,
  250. input_ids: torch.Tensor,
  251. positions: torch.Tensor,
  252. kv_caches: List[torch.Tensor],
  253. attn_metadata: AttentionMetadata,
  254. ) -> torch.Tensor:
  255. inputs_embeds = self.embed_tokens(input_ids)
  256. pos_embeds = self.embed_positions(positions)
  257. if self.project_in is not None:
  258. inputs_embeds, _ = self.project_in(inputs_embeds)
  259. hidden_states = inputs_embeds + pos_embeds
  260. for i in range(len(self.layers)):
  261. layer = self.layers[i]
  262. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  263. if self.final_layer_norm is not None:
  264. hidden_states = self.final_layer_norm(hidden_states)
  265. if self.project_out is not None:
  266. hidden_states, _ = self.project_out(hidden_states)
  267. return hidden_states
  268. class OPTModel(nn.Module):
  269. def __init__(
  270. self,
  271. config: OPTConfig,
  272. linear_method: Optional[LinearMethodBase] = None,
  273. ):
  274. super().__init__()
  275. self.decoder = OPTDecoder(config, linear_method)
  276. def forward(
  277. self,
  278. input_ids: torch.Tensor,
  279. positions: torch.Tensor,
  280. kv_caches: List[torch.Tensor],
  281. attn_metadata: AttentionMetadata,
  282. ) -> torch.Tensor:
  283. return self.decoder(input_ids, positions, kv_caches, attn_metadata)
  284. class OPTForCausalLM(nn.Module):
  285. def __init__(
  286. self,
  287. config,
  288. linear_method: Optional[LinearMethodBase] = None,
  289. ):
  290. super().__init__()
  291. self.config = config
  292. self.linear_method = linear_method
  293. self.model = OPTModel(config, linear_method)
  294. self.lm_head = ParallelLMHead(config.vocab_size,
  295. config.hidden_size,
  296. linear_method=linear_method)
  297. self.logits_processor = LogitsProcessor(config.vocab_size)
  298. self.sampler = Sampler()
  299. def forward(
  300. self,
  301. input_ids: torch.Tensor,
  302. positions: torch.Tensor,
  303. kv_caches: List[torch.Tensor],
  304. attn_metadata: AttentionMetadata,
  305. ) -> torch.Tensor:
  306. hidden_states = self.model(input_ids, positions, kv_caches,
  307. attn_metadata)
  308. return hidden_states
  309. def compute_logits(self, hidden_states: torch.Tensor,
  310. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  311. logits = self.logits_processor(self.lm_head, hidden_states,
  312. sampling_metadata)
  313. return logits
  314. def sample(
  315. self,
  316. logits: torch.Tensor,
  317. sampling_metadata: SamplingMetadata,
  318. ) -> Optional[SamplerOutput]:
  319. next_tokens = self.sampler(logits, sampling_metadata)
  320. return next_tokens
  321. def load_weights(
  322. self,
  323. model_name_or_path: str,
  324. cache_dir: Optional[str] = None,
  325. load_format: str = "auto",
  326. revision: Optional[str] = None,
  327. ):
  328. stacked_params_mapping = [
  329. # (param_name, shard_name, shard_id)
  330. ("qkv_proj", "q_proj", "q"),
  331. ("qkv_proj", "k_proj", "k"),
  332. ("qkv_proj", "v_proj", "v"),
  333. ]
  334. if (self.linear_method is not None
  335. and not self.linear_method.quant_config.merge_weight()):
  336. stacked_params_mapping = []
  337. params_dict = dict(self.named_parameters(remove_duplicate=False))
  338. for name, loaded_weight in hf_model_weights_iterator(
  339. model_name_or_path, cache_dir, load_format, revision,
  340. self.config):
  341. if "lm_head" in name and name not in params_dict:
  342. continue
  343. if "embed_tokens" in name:
  344. # Copy word embedding to lm_head
  345. if name.startswith("decoder."):
  346. name = "model." + name
  347. head_name = name.replace("model.decoder.embed_tokens",
  348. "lm_head")
  349. if head_name in params_dict:
  350. lm_head_param = params_dict[head_name]
  351. weight_loader = getattr(lm_head_param, "weight_loader",
  352. default_weight_loader)
  353. weight_loader(lm_head_param, loaded_weight)
  354. if name.startswith("decoder."):
  355. name = "model." + name
  356. for param_name, weight_name, shard_id in stacked_params_mapping:
  357. if weight_name not in name:
  358. continue
  359. name = name.replace(weight_name, param_name)
  360. # Skip loading extra bias for GPTQ models.
  361. if name.endswith(".bias") and name not in params_dict:
  362. continue
  363. param = params_dict[name]
  364. weight_loader = param.weight_loader
  365. weight_loader(param, loaded_weight, shard_id)
  366. break
  367. else:
  368. # Skip loading extra bias for GPTQ models.
  369. if name.endswith(".bias") and name not in params_dict:
  370. continue
  371. param = params_dict[name]
  372. weight_loader = getattr(param, "weight_loader",
  373. default_weight_loader)
  374. weight_loader(param, loaded_weight)