bart.py

# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import BartConfig

from aphrodite.attention import Attention, AttentionMetadata, AttentionType
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


def get_bsz_seq_len(input_ids):
    shp = input_ids.shape
    ndim = len(shp)
    if ndim == 1:
        return 1, input_ids.numel()
    else:
        return shp[:2]
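
# For example (illustrative only; values follow directly from the shapes):
#   a flattened 1-D tensor of 7 token ids        -> (1, 7)
#   a batched [bsz=2, seqlen=5] tensor of ids    -> (2, 5)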


class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is
        # specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately.
        # Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
        attn_type: AttentionType,
    ) -> torch.Tensor:
        """`positions` shape is expected to be [bsz x seqlen]."""
        assert attn_type != AttentionType.ENCODER_DECODER
        return super().forward(positions + self.offset)


class BartScaledWordEmbedding(VocabParallelEmbedding):
    """
    This module overrides VocabParallelEmbedding's
    forward by multiplying with embeddings scale.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale


class BartParallelLMHead(ParallelLMHead):
    """
    This module overrides ParallelLMHead's
    forward by dividing by embeddings scale,
    yielding effectively the inverse of
    BartScaledWordEmbedding.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) / self.embed_scale
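
# A quick sanity check of the scaling round trip (assuming a BART config with
# scale_embedding=True and d_model=1024): embed_scale = sqrt(1024) = 32.0, so
# BartScaledWordEmbedding multiplies embeddings by 32 and BartParallelLMHead
# divides by 32, undoing the scale on the shared (tied) embedding weight.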


class BartEncoderAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q,
                                k,
                                v,
                                kv_cache,
                                attn_metadata,
                                attn_type=AttentionType.ENCODER)

        output, _ = self.out_proj(attn_output)
        return output
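
# Rough worked example of the fused projection sizes above (assuming
# facebook/bart-large: d_model=1024, 16 attention heads, TP world size 1):
#   head_dim = 1024 // 16 = 64
#   q_size = kv_size = 16 * 64 = 1024
# so qkv_proj produces a [..., 3072] tensor that forward() splits into Q/K/V.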


class BartDecoderSelfAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q,
                                k,
                                v,
                                kv_cache,
                                attn_metadata,
                                attn_type=AttentionType.DECODER)

        output, _ = self.out_proj(attn_output)
        return output


class BartCrossAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""

        # (afeldman-nm 2024/07/22) TODO:
        # Need a more efficient solution for q/k/v
        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        if encoder_hidden_states is None:
            k = None
            v = None
        else:
            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                                    dim=-1)

        attn_output = self.attn(q,
                                k,
                                v,
                                kv_cache,
                                attn_metadata,
                                attn_type=AttentionType.ENCODER_DECODER)

        output, _ = self.out_proj(attn_output)
        return output
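
# Note: when encoder_hidden_states is None (decode steps after the encoder
# output has already been processed), K and V are passed as None and the
# attention backend is presumably expected to serve cross-attention from the
# encoder K/V entries previously written to kv_cache.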


class BartEncoderLayer(nn.Module):

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config)

        ffn_hidden_size = self.embed_dim
        ffn_intermediate_size = config.encoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
            kv_cache:
                Layer-wise list of KV cache tensors
            attn_metadata:
                Aphrodite Attention metadata structure
        Returns:
            Encoder layer output torch.Tensor
        """
        residual = hidden_states
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        # Guard against float16 overflow: if any activation is inf/NaN,
        # clamp the tensor back into a finite range before returning.
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)

        return hidden_states


class BartDecoderLayer(nn.Module):

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config)
        self.activation_fn = get_act_fn(config.activation_function,
                                        quant_config)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        '''
        afeldman-nm: personally I would call this "cross-attention",
        however I left the name as "encoder_attn" to maintain consistency
        with the name of the pretrained weights.
        '''
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states
                torch.Tensor of *decoder* input embeddings.
            kv_cache:
                KV cache tensor
            attn_metadata:
                Aphrodite Attention metadata structure
            encoder_hidden_states
                torch.Tensor of *encoder* output embeddings.
        Returns:
            Decoder layer output torch.Tensor
        """
        residual = decoder_hidden_states

        # Self Attention
        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states


class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`BartEncoderLayer`].

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        embed_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    embed_dim,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList(
            [BartEncoderLayer(config, cache_config, quant_config)
             for _ in range(config.encoder_layers)])

        self.layernorm_embedding = nn.LayerNorm(embed_dim)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                Aphrodite Attention metadata structure
        Returns:
            Encoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds
        input_ids = input_ids.view(-1, input_ids.shape[-1])
        inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(
            positions,
            AttentionType.ENCODER,
        )
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
            )

        return hidden_states


class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`BartDecoderLayer`]

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList(
            [BartDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.decoder_layers)])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(self, decoder_input_ids: torch.Tensor,
                decoder_positions: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                Aphrodite Attention metadata structure
        Returns:
            Decoder output torch.Tensor
        """
        inputs_embeds = self.embed_tokens(decoder_input_ids)

        # embed positions
        embed_pos = self.embed_positions(
            decoder_positions,
            AttentionType.DECODER,
        )
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                decoder_hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states


class BartModel(nn.Module):
    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None):
        super().__init__()

        self.config = config

        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = BartEncoder(config,
                                   cache_config,
                                   quant_config=quant_config)
        self.decoder = BartDecoder(config,
                                   cache_config,
                                   quant_config=quant_config)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                Aphrodite Attention metadata structure
        Returns:
            Model output torch.Tensor
        """
        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 kv_caches=kv_caches,
                                                 attn_metadata=attn_metadata)

        # The decoder returns the final decoder hidden states, conditioned on
        # the encoder output via cross-attention.
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata)

        return decoder_outputs


class BartForConditionalGeneration(nn.Module):
    base_model_prefix = "model"

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None):

        super().__init__()
        self.config = config
        self.model = BartModel(config,
                               cache_config,
                               quant_config,
                               lora_config=lora_config)

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                Aphrodite Attention metadata structure
        Returns:
            Output torch.Tensor
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }

    def _rename_key(self, key: str):
        prefix = f"{self.base_model_prefix}."
        key = key[len(prefix):] if key.startswith(prefix) else key

        for src, dst in self.params_mapping.items():
            key = key.replace(src, dst)

        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> Tuple[str, Optional[str]]:
        for key, mapping in self.stacked_params_mapping.items():
            if key in name:
                name = name.replace(key, mapping["param_name"])
                return name, mapping["shard_id"]
        return name, None

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        model_params_dict = dict(self.model.named_parameters())
        top_params_dict = dict(self.named_parameters())

        shared_embedding_weight = None
        shared_embedding_shard_id = None

        for name, loaded_weight in weights:

            name = self._rename_key(name)
            name, shard_id = self._rename_stacked_param(name)

            if ('shared.weight' in name
                    or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
                    or 'lm_head.weight' in name):
                assert shared_embedding_weight is None, (
                    "Conflicting embedding weights.")
                shared_embedding_weight = loaded_weight
                shared_embedding_shard_id = shard_id
            else:
                # Skip the specific downstream task weight.
                if name.startswith('cls.'):
                    continue
                # use Pooler instead.
                if name.startswith('pooler.'):
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in model_params_dict:
                    continue

                param = model_params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if shard_id:
                    weight_loader(param, loaded_weight, shard_id)
                else:
                    weight_loader(param, loaded_weight)

        # Assign shared weight values
        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
                                           default_weight_loader)

        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
                                           default_weight_loader)

        lm_head_in_param = top_params_dict['lm_head.weight']
        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
                                           default_weight_loader)

        assert shared_embedding_weight is not None

        if shared_embedding_shard_id:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
        else:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
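
# Illustrative trace of the weight-renaming logic in load_weights, for a
# hypothetical checkpoint key:
#   "model.encoder.layers.0.self_attn.q_proj.weight"
#   -> _rename_key strips the "model." prefix:
#      "encoder.layers.0.self_attn.q_proj.weight"
#   -> _rename_stacked_param maps q_proj to the fused projection:
#      ("encoder.layers.0.self_attn.qkv_proj.weight", shard_id="q")
#   so the tensor is loaded into the "q" shard of that layer's qkv_proj.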