1
0

xformers.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. """Attention layer with xFormers and PagedAttention."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (AttentionBias,
  7. BlockDiagonalCausalMask,
  8. BlockDiagonalMask,
  9. LowerTriangularMaskWithTensorBias)
  10. from aphrodite.attention.backends.abstract import (AttentionBackend,
  11. AttentionImpl,
  12. AttentionMetadata,
  13. AttentionType)
  14. from aphrodite.attention.backends.utils import (CommonAttentionState,
  15. CommonMetadataBuilder)
  16. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  17. PagedAttentionMetadata)
  18. class XFormersBackend(AttentionBackend):
  19. @staticmethod
  20. def get_name() -> str:
  21. return "xformers"
  22. @staticmethod
  23. def get_impl_cls() -> Type["XFormersImpl"]:
  24. return XFormersImpl
  25. @staticmethod
  26. def get_metadata_cls() -> Type["AttentionMetadata"]:
  27. return XFormersMetadata
  28. @staticmethod
  29. def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
  30. return XFormersMetadataBuilder
  31. @staticmethod
  32. def get_state_cls() -> Type["CommonAttentionState"]:
  33. return CommonAttentionState
  34. @staticmethod
  35. def get_kv_cache_shape(
  36. num_blocks: int,
  37. block_size: int,
  38. num_kv_heads: int,
  39. head_size: int,
  40. ) -> Tuple[int, ...]:
  41. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  42. num_kv_heads, head_size)
  43. @staticmethod
  44. def swap_blocks(
  45. src_kv_cache: torch.Tensor,
  46. dst_kv_cache: torch.Tensor,
  47. src_to_dst: Dict[int, int],
  48. ) -> None:
  49. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  50. @staticmethod
  51. def copy_blocks(
  52. kv_caches: List[torch.Tensor],
  53. src_to_dists: torch.Tensor,
  54. ) -> None:
  55. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  56. @dataclass
  57. class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
  58. """Metadata for XFormersbackend.
  59. NOTE: Any python object stored here is not updated when it is
  60. cuda-graph replayed. If you have values that need to be changed
  61. dynamically, it should be stored in tensor. The tensor has to be
  62. updated from `CUDAGraphRunner.forward` API.
  63. """
  64. # |---------- N-1 iteration --------|
  65. # |---------------- N iteration ---------------------|
  66. # |- tokenA -|......................|-- newTokens ---|
  67. # |---------- context_len ----------|
  68. # |-------------------- seq_len ----------------------|
  69. # |-- query_len ---|
  70. # seq_lens stored as a tensor.
  71. seq_lens_tensor: Optional[torch.Tensor]
  72. # FIXME: It is for flash attn.
  73. # Maximum sequence length among prefill batch. 0 if there are decoding
  74. # requests only.
  75. max_prefill_seq_len: int
  76. # Maximum sequence length among decode batch. 0 if there are prefill
  77. # requests only.
  78. max_decode_seq_len: int
  79. # Whether or not if cuda graph is enabled.
  80. # Cuda-graph is currently enabled for decoding only.
  81. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  82. use_cuda_graph: bool
  83. # (batch_size,). The sequence length per sequence. Sequence length means
  84. # the computed tokens + new tokens None if it is a decoding.
  85. seq_lens: Optional[List[int]] = None
  86. # FIXME: It is for flash attn.
  87. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  88. # the batch, used to index into sequence. E.g., if the sequence length is
  89. # [4, 6], it is [0, 4, 10].
  90. seq_start_loc: Optional[torch.Tensor] = None
  91. # (batch_size,) A tensor of context lengths (tokens that are computed
  92. # so far).
  93. context_lens_tensor: Optional[torch.Tensor] = None
  94. # Maximum query length in the batch. None for decoding.
  95. max_query_len: Optional[int] = None
  96. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  97. # the batch, used to index into subquery. E.g., if the subquery length
  98. # is [4, 6], it is [0, 4, 10].
  99. query_start_loc: Optional[torch.Tensor] = None
  100. # Self-attention prefill/decode metadata cache
  101. _cached_prefill_metadata: Optional["XFormersMetadata"] = None
  102. _cached_decode_metadata: Optional["XFormersMetadata"] = None
  103. # Begin encoder attn & enc/dec cross-attn fields...
  104. # Encoder sequence lengths representation
  105. encoder_seq_lens: Optional[List[int]] = None
  106. encoder_seq_lens_tensor: Optional[torch.Tensor] = None
  107. # Maximum sequence length among encoder sequences
  108. max_encoder_seq_len: Optional[int] = None
  109. # Number of tokens input to encoder
  110. num_encoder_tokens: Optional[int] = None
  111. # Cross-attention memory-mapping data structures: slot mapping
  112. # and block tables
  113. cross_slot_mapping: Optional[torch.Tensor] = None
  114. cross_block_tables: Optional[torch.Tensor] = None
  115. def __post_init__(self):
  116. # Set during the execution of the first attention op.
  117. # It is a list because it is needed to set per prompt
  118. # when alibi slopes is used. It is because of the limitation
  119. # from xformer API.
  120. # will not appear in the __repr__ and __init__
  121. self.attn_bias: Optional[List[AttentionBias]] = None
  122. self.encoder_attn_bias: Optional[List[AttentionBias]] = None
  123. self.cross_attn_bias: Optional[List[AttentionBias]] = None
  124. @property
  125. def is_all_encoder_attn_metadata_set(self):
  126. '''
  127. All attention metadata required for encoder attention is set.
  128. '''
  129. return ((self.encoder_seq_lens is not None)
  130. and (self.encoder_seq_lens_tensor is not None)
  131. and (self.max_encoder_seq_len is not None))
  132. @property
  133. def is_all_cross_attn_metadata_set(self):
  134. '''
  135. All attention metadata required for enc/dec cross-attention is set.
  136. Superset of encoder attention required metadata.
  137. '''
  138. return (self.is_all_encoder_attn_metadata_set
  139. and (self.cross_slot_mapping is not None)
  140. and (self.cross_block_tables is not None))
  141. @property
  142. def prefill_metadata(self) -> Optional["XFormersMetadata"]:
  143. if self.num_prefills == 0:
  144. return None
  145. if self._cached_prefill_metadata is not None:
  146. # Recover cached prefill-phase attention
  147. # metadata structure
  148. return self._cached_prefill_metadata
  149. assert ((self.seq_lens is not None)
  150. or (self.encoder_seq_lens is not None))
  151. assert ((self.seq_lens_tensor is not None)
  152. or (self.encoder_seq_lens_tensor is not None))
  153. # Compute some attn_metadata fields which default to None
  154. query_start_loc = (None if self.query_start_loc is None else
  155. self.query_start_loc[:self.num_prefills + 1])
  156. slot_mapping = (None if self.slot_mapping is None else
  157. self.slot_mapping[:self.num_prefill_tokens])
  158. seq_lens = (None if self.seq_lens is None else
  159. self.seq_lens[:self.num_prefills])
  160. seq_lens_tensor = (None if self.seq_lens_tensor is None else
  161. self.seq_lens_tensor[:self.num_prefills])
  162. context_lens_tensor = (None if self.context_lens_tensor is None else
  163. self.context_lens_tensor[:self.num_prefills])
  164. block_tables = (None if self.block_tables is None else
  165. self.block_tables[:self.num_prefills])
  166. # Construct & cache prefill-phase attention metadata structure
  167. self._cached_prefill_metadata = XFormersMetadata(
  168. num_prefills=self.num_prefills,
  169. num_prefill_tokens=self.num_prefill_tokens,
  170. num_decode_tokens=0,
  171. slot_mapping=slot_mapping,
  172. seq_lens=seq_lens,
  173. seq_lens_tensor=seq_lens_tensor,
  174. max_query_len=self.max_query_len,
  175. max_prefill_seq_len=self.max_prefill_seq_len,
  176. max_decode_seq_len=0,
  177. query_start_loc=query_start_loc,
  178. context_lens_tensor=context_lens_tensor,
  179. block_tables=block_tables,
  180. use_cuda_graph=False,
  181. # Begin encoder & cross attn fields below...
  182. encoder_seq_lens=self.encoder_seq_lens,
  183. encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
  184. max_encoder_seq_len=self.max_encoder_seq_len,
  185. cross_slot_mapping=self.cross_slot_mapping,
  186. cross_block_tables=self.cross_block_tables)
  187. return self._cached_prefill_metadata
  188. @property
  189. def decode_metadata(self) -> Optional["XFormersMetadata"]:
  190. if self.num_decode_tokens == 0:
  191. return None
  192. if self._cached_decode_metadata is not None:
  193. # Recover cached decode-phase attention
  194. # metadata structure
  195. return self._cached_decode_metadata
  196. assert ((self.seq_lens_tensor is not None)
  197. or (self.encoder_seq_lens_tensor is not None))
  198. # Compute some attn_metadata fields which default to None
  199. slot_mapping = (None if self.slot_mapping is None else
  200. self.slot_mapping[self.num_prefill_tokens:])
  201. seq_lens_tensor = (None if self.seq_lens_tensor is None else
  202. self.seq_lens_tensor[self.num_prefills:])
  203. block_tables = (None if self.block_tables is None else
  204. self.block_tables[self.num_prefills:])
  205. # Construct & cache decode-phase attention metadata structure
  206. self._cached_decode_metadata = XFormersMetadata(
  207. num_prefills=0,
  208. num_prefill_tokens=0,
  209. num_decode_tokens=self.num_decode_tokens,
  210. slot_mapping=slot_mapping,
  211. seq_lens_tensor=seq_lens_tensor,
  212. max_prefill_seq_len=0,
  213. max_decode_seq_len=self.max_decode_seq_len,
  214. block_tables=block_tables,
  215. use_cuda_graph=self.use_cuda_graph,
  216. # Begin encoder & cross attn fields below...
  217. encoder_seq_lens=self.encoder_seq_lens,
  218. encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
  219. max_encoder_seq_len=self.max_encoder_seq_len,
  220. cross_slot_mapping=self.cross_slot_mapping,
  221. cross_block_tables=self.cross_block_tables)
  222. return self._cached_decode_metadata
  223. def _get_attn_bias(
  224. attn_metadata: XFormersMetadata,
  225. attn_type: AttentionType,
  226. ) -> Optional[AttentionBias]:
  227. '''
  228. Extract appropriate attention bias from attention metadata
  229. according to attention type.
  230. Arguments:
  231. * attn_metadata: Attention metadata structure associated with attention
  232. * attn_type: encoder attention, decoder self-attention,
  233. encoder/decoder cross-attention
  234. Returns:
  235. * Appropriate attention bias value given the attention type
  236. '''
  237. if attn_type == AttentionType.DECODER:
  238. return attn_metadata.attn_bias
  239. elif attn_type == AttentionType.ENCODER:
  240. return attn_metadata.encoder_attn_bias
  241. else:
  242. # attn_type == AttentionType.ENCODER_DECODER
  243. return attn_metadata.cross_attn_bias
  244. def _set_attn_bias(
  245. attn_metadata: XFormersMetadata,
  246. attn_bias: List[Optional[AttentionBias]],
  247. attn_type: AttentionType,
  248. ) -> None:
  249. '''
  250. Update appropriate attention bias field of attention metadata,
  251. according to attention type.
  252. Arguments:
  253. * attn_metadata: Attention metadata structure associated with attention
  254. * attn_bias: The desired attention bias value
  255. * attn_type: encoder attention, decoder self-attention,
  256. encoder/decoder cross-attention
  257. '''
  258. if attn_type == AttentionType.DECODER:
  259. attn_metadata.attn_bias = attn_bias
  260. elif attn_type == AttentionType.ENCODER:
  261. attn_metadata.encoder_attn_bias = attn_bias
  262. elif attn_type == AttentionType.ENCODER_DECODER:
  263. attn_metadata.cross_attn_bias = attn_bias
  264. else:
  265. raise AttributeError(f"Invalid attention type {str(attn_type)}")
  266. def _get_seq_len_block_table_args(
  267. attn_metadata: XFormersMetadata,
  268. is_prompt: bool,
  269. attn_type: AttentionType,
  270. ) -> tuple:
  271. '''
  272. The particular choice of sequence-length- and block-table-related
  273. attributes which should be extracted from attn_metadata is dependent
  274. on the type of attention operation.
  275. Decoder attn -> select entirely decoder self-attention-related fields
  276. Encoder/decoder cross-attn -> select encoder sequence lengths &
  277. cross-attn block-tables fields
  278. Encoder attn -> select encoder sequence lengths fields & no block tables
  279. Arguments:
  280. * attn_metadata: Attention metadata structure associated with attention op
  281. * is_prompt: True if prefill, False otherwise
  282. * attn_type: encoder attention, decoder self-attention,
  283. encoder/decoder cross-attention
  284. Returns:
  285. * Appropriate sequence-lengths tensor
  286. * Appropriate max sequence-length scalar
  287. * Appropriate block tables (or None)
  288. '''
  289. if attn_type == AttentionType.DECODER:
  290. # Decoder self-attention
  291. # Choose max_seq_len based on whether we are in prompt_run
  292. if is_prompt:
  293. max_seq_len = attn_metadata.max_prefill_seq_len
  294. else:
  295. max_seq_len = attn_metadata.max_decode_seq_len
  296. return (attn_metadata.seq_lens_tensor, max_seq_len,
  297. attn_metadata.block_tables)
  298. elif attn_type == AttentionType.ENCODER_DECODER:
  299. # Enc/dec cross-attention KVs match encoder sequence length;
  300. # cross-attention utilizes special "cross" block tables
  301. return (attn_metadata.encoder_seq_lens_tensor,
  302. attn_metadata.max_encoder_seq_len,
  303. attn_metadata.cross_block_tables)
  304. elif attn_type == AttentionType.ENCODER:
  305. # No block tables associated with encoder attention
  306. return (attn_metadata.encoder_seq_lens_tensor,
  307. attn_metadata.max_encoder_seq_len, None)
  308. else:
  309. raise AttributeError(f"Invalid attention type {str(attn_type)}")
  310. class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
  311. _metadata_cls = XFormersMetadata
  312. class XFormersImpl(AttentionImpl[XFormersMetadata]):
  313. """
  314. If the input tensors contain prompt tokens, the layout is as follows:
  315. |<--------------- num_prefill_tokens ----------------->|
  316. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  317. Otherwise, the layout is as follows:
  318. |<----------------- num_decode_tokens ------------------>|
  319. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  320. Generation tokens can contain padding when cuda-graph is used.
  321. Currently, prompt tokens don't contain any padding.
  322. The prompts might have different lengths, while the generation tokens
  323. always have length 1.
  324. If chunked prefill is enabled, prefill tokens and decode tokens can be
  325. batched together in a flattened 1D query.
  326. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  327. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  328. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  329. padding between prefill and decode tokens.
  330. """
  331. def __init__(
  332. self,
  333. num_heads: int,
  334. head_size: int,
  335. scale: float,
  336. num_kv_heads: int,
  337. alibi_slopes: Optional[List[float]],
  338. sliding_window: Optional[int],
  339. kv_cache_dtype: str,
  340. blocksparse_params: Optional[Dict[str, Any]] = None,
  341. logits_soft_cap: Optional[float] = None,
  342. ) -> None:
  343. if blocksparse_params is not None:
  344. raise ValueError(
  345. "XFormers does not support block-sparse attention.")
  346. if logits_soft_cap is not None:
  347. raise ValueError(
  348. "XFormers does not support attention logits soft capping.")
  349. self.num_heads = num_heads
  350. self.head_size = head_size
  351. self.scale = float(scale)
  352. self.num_kv_heads = num_kv_heads
  353. if alibi_slopes is not None:
  354. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  355. self.alibi_slopes = alibi_slopes
  356. self.sliding_window = sliding_window
  357. self.kv_cache_dtype = kv_cache_dtype
  358. assert self.num_heads % self.num_kv_heads == 0
  359. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  360. suppored_head_sizes = PagedAttention.get_supported_head_sizes()
  361. if head_size not in suppored_head_sizes:
  362. raise ValueError(
  363. f"Head size {head_size} is not supported by PagedAttention. "
  364. f"Supported head sizes are: {suppored_head_sizes}.")
  365. def forward(
  366. self,
  367. query: torch.Tensor,
  368. key: Optional[torch.Tensor],
  369. value: Optional[torch.Tensor],
  370. kv_cache: Optional[torch.Tensor],
  371. attn_metadata: "XFormersMetadata",
  372. k_scale: float = 1.0,
  373. v_scale: float = 1.0,
  374. attn_type: AttentionType = AttentionType.DECODER,
  375. ) -> torch.Tensor:
  376. """Forward pass with xFormers and PagedAttention.
  377. For decoder-only models: query, key and value must be non-None.
  378. For encoder/decoder models:
  379. * XFormersImpl.forward() may be invoked for both self- and cross-
  380. attention layers.
  381. * For self-attention: query, key and value must be non-None.
  382. * For cross-attention:
  383. * Query must be non-None
  384. * During prefill, key and value must be non-None; key and value
  385. get cached for use during decode.
  386. * During decode, key and value may be None, since:
  387. (1) key and value tensors were cached during prefill, and
  388. (2) cross-attention key and value tensors do not grow during
  389. decode
  390. A note on how the attn_type (attention type enum) argument impacts
  391. attention forward() behavior:
  392. * DECODER: normal decoder-only behavior;
  393. use decoder self-attention block table
  394. * ENCODER: no KV caching; pass encoder sequence
  395. attributes (encoder_seq_lens/encoder_seq_lens_tensor/
  396. max_encoder_seq_len) to kernel, in lieu of decoder
  397. sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
  398. * ENCODER_DECODER: cross-attention behavior;
  399. use cross-attention block table for caching KVs derived
  400. from encoder hidden states; since KV sequence lengths
  401. will match encoder sequence lengths, pass encoder sequence
  402. attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
  403. max_encoder_seq_len)
  404. Args:
  405. query: shape = [num_tokens, num_heads * head_size]
  406. key: shape = [num_tokens, num_kv_heads * head_size]
  407. value: shape = [num_tokens, num_kv_heads * head_size]
  408. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  409. attn_metadata: Metadata for attention.
  410. attn_type: Select attention type, between encoder attention,
  411. decoder self-attention, or encoder/decoder cross-
  412. attention. Defaults to decoder self-attention,
  413. which is the Aphrodite default generally
  414. Returns:
  415. shape = [num_tokens, num_heads * head_size]
  416. """
  417. # Check that appropriate attention metadata attributes are
  418. # selected for the desired attention type
  419. if (attn_type == AttentionType.ENCODER
  420. and (not attn_metadata.is_all_encoder_attn_metadata_set)):
  421. raise AttributeError("Encoder attention requires setting "
  422. "encoder metadata attributes.")
  423. elif (attn_type == AttentionType.ENCODER_DECODER
  424. and (not attn_metadata.is_all_cross_attn_metadata_set)):
  425. raise AttributeError("Encoder/decoder cross-attention "
  426. "requires setting cross-attention "
  427. "metadata attributes.")
  428. query = query.view(-1, self.num_heads, self.head_size)
  429. if key is not None:
  430. assert value is not None
  431. key = key.view(-1, self.num_kv_heads, self.head_size)
  432. value = value.view(-1, self.num_kv_heads, self.head_size)
  433. else:
  434. assert value is None
  435. # Self-attention vs. cross-attention will impact
  436. # which KV cache memory-mapping & which
  437. # seqlen datastructures we utilize
  438. if (attn_type != AttentionType.ENCODER and kv_cache is not None):
  439. # KV-cache during decoder-self- or
  440. # encoder-decoder-cross-attention, but not
  441. # during encoder attention.
  442. #
  443. # Even if there are no new key/value pairs to cache,
  444. # we still need to break out key_cache and value_cache
  445. # i.e. for later use by paged attention
  446. key_cache, value_cache = PagedAttention.split_kv_cache(
  447. kv_cache, self.num_kv_heads, self.head_size)
  448. if (key is not None) and (value is not None):
  449. if attn_type == AttentionType.ENCODER_DECODER:
  450. # Update cross-attention KV cache (prefill-only)
  451. # During cross-attention decode, key & value will be None,
  452. # preventing this IF-statement branch from running
  453. updated_slot_mapping = attn_metadata.cross_slot_mapping
  454. else:
  455. # Update self-attention KV cache (prefill/decode)
  456. updated_slot_mapping = attn_metadata.slot_mapping
  457. # Reshape the input keys and values and store them in the cache.
  458. # If kv_cache is not provided, the new key and value tensors are
  459. # not cached. This happens during the initial memory
  460. # profiling run.
  461. PagedAttention.write_to_paged_cache(key, value, key_cache,
  462. value_cache,
  463. updated_slot_mapping,
  464. self.kv_cache_dtype,
  465. k_scale, v_scale)
  466. if attn_type != AttentionType.ENCODER:
  467. # Decoder self-attention supports chunked prefill.
  468. # Encoder/decoder cross-attention requires no chunked
  469. # prefill (100% prefill or 100% decode tokens, no mix)
  470. num_prefill_tokens = attn_metadata.num_prefill_tokens
  471. num_decode_tokens = attn_metadata.num_decode_tokens
  472. else:
  473. # Encoder attention - chunked prefill is not applicable;
  474. # derive token-count from query shape & and treat them
  475. # as 100% prefill tokens
  476. assert attn_metadata.num_encoder_tokens is not None
  477. num_prefill_tokens = attn_metadata.num_encoder_tokens
  478. num_decode_tokens = 0
  479. if attn_type == AttentionType.DECODER:
  480. # Only enforce this shape-constraint for decoder
  481. # self-attention
  482. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  483. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  484. output = torch.empty_like(query)
  485. # Query for decode. KV is not needed because it is already cached.
  486. decode_query = query[num_prefill_tokens:]
  487. # QKV for prefill.
  488. query = query[:num_prefill_tokens]
  489. if key is not None and value is not None:
  490. key = key[:num_prefill_tokens]
  491. value = value[:num_prefill_tokens]
  492. assert query.shape[0] == num_prefill_tokens
  493. assert decode_query.shape[0] == num_decode_tokens
  494. if prefill_meta := attn_metadata.prefill_metadata:
  495. # Prompt run.
  496. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  497. # normal attention.
  498. # block tables are empty if the prompt does not have a cached
  499. # prefix.
  500. out = self._run_memory_efficient_xformers_forward(
  501. query, key, value, prefill_meta, attn_type=attn_type)
  502. assert out.shape == output[:num_prefill_tokens].shape
  503. output[:num_prefill_tokens] = out
  504. else:
  505. assert prefill_meta.query_start_loc is not None
  506. assert prefill_meta.max_query_len is not None
  507. # prefix-enabled attention
  508. # TODO: this triton kernel has regression issue (broke) to
  509. # deal with different data types between KV and FP8 KV cache,
  510. # to be addressed separately.
  511. out = PagedAttention.forward_prefix(
  512. query,
  513. key,
  514. value,
  515. self.kv_cache_dtype,
  516. key_cache,
  517. value_cache,
  518. prefill_meta.block_tables,
  519. prefill_meta.query_start_loc,
  520. prefill_meta.seq_lens_tensor,
  521. prefill_meta.context_lens_tensor,
  522. prefill_meta.max_query_len,
  523. self.alibi_slopes,
  524. self.sliding_window,
  525. k_scale,
  526. v_scale,
  527. )
  528. assert output[:num_prefill_tokens].shape == out.shape
  529. output[:num_prefill_tokens] = out
  530. if decode_meta := attn_metadata.decode_metadata:
  531. (
  532. seq_lens_arg,
  533. max_seq_len_arg,
  534. block_tables_arg,
  535. ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
  536. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  537. decode_query,
  538. key_cache,
  539. value_cache,
  540. block_tables_arg,
  541. seq_lens_arg,
  542. max_seq_len_arg,
  543. self.kv_cache_dtype,
  544. self.num_kv_heads,
  545. self.scale,
  546. self.alibi_slopes,
  547. k_scale,
  548. v_scale,
  549. )
  550. # Reshape the output tensor.
  551. return output.view(-1, self.num_heads * self.head_size)
  552. def _run_memory_efficient_xformers_forward(
  553. self,
  554. query: torch.Tensor,
  555. key: torch.Tensor,
  556. value: torch.Tensor,
  557. attn_metadata: XFormersMetadata,
  558. attn_type: AttentionType = AttentionType.DECODER,
  559. ) -> torch.Tensor:
  560. """Attention for 1D query of multiple prompts. Multiple prompt
  561. tokens are flattened in to `query` input.
  562. See https://facebookresearch.github.io/xformers/components/ops.html
  563. for API spec.
  564. Args:
  565. output: shape = [num_prefill_tokens, num_heads, head_size]
  566. query: shape = [num_prefill_tokens, num_heads, head_size]
  567. key: shape = [num_prefill_tokens, num_kv_heads, head_size]
  568. value: shape = [num_prefill_tokens, num_kv_heads, head_size]
  569. attn_metadata: Metadata for attention.
  570. attn_type: Select attention type, between encoder attention,
  571. decoder self-attention, or encoder/decoder cross-
  572. attention. Defaults to decoder self-attention,
  573. which is the Aphrodite default generally
  574. """
  575. original_query = query
  576. if self.num_kv_heads != self.num_heads:
  577. # GQA/MQA requires the shape [B, M, G, H, K].
  578. # Note that the output also has the same shape (which is different
  579. # from a spec from the doc).
  580. query = query.view(query.shape[0], self.num_kv_heads,
  581. self.num_queries_per_kv, query.shape[-1])
  582. key = key[:, :,
  583. None, :].expand(key.shape[0], self.num_kv_heads,
  584. self.num_queries_per_kv, key.shape[-1])
  585. value = value[:, :,
  586. None, :].expand(value.shape[0], self.num_kv_heads,
  587. self.num_queries_per_kv,
  588. value.shape[-1])
  589. # Set attention bias if not provided. This typically happens at
  590. # the very attention layer of every iteration.
  591. # FIXME: This is a hack.
  592. attn_bias = _get_attn_bias(attn_metadata, attn_type)
  593. if attn_bias is None:
  594. if self.alibi_slopes is None:
  595. if (attn_type == AttentionType.ENCODER_DECODER):
  596. assert attn_metadata.seq_lens is not None
  597. assert attn_metadata.encoder_seq_lens is not None
  598. # Default enc/dec cross-attention mask is non-causal
  599. attn_bias = BlockDiagonalMask.from_seqlens(
  600. attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
  601. elif attn_type == AttentionType.ENCODER:
  602. assert attn_metadata.encoder_seq_lens is not None
  603. # Default encoder self-attention mask is non-causal
  604. attn_bias = BlockDiagonalMask.from_seqlens(
  605. attn_metadata.encoder_seq_lens)
  606. else:
  607. assert attn_metadata.seq_lens is not None
  608. # Default decoder self-attention mask is causal
  609. attn_bias = BlockDiagonalCausalMask.from_seqlens(
  610. attn_metadata.seq_lens)
  611. if self.sliding_window is not None:
  612. attn_bias = attn_bias.make_local_attention(
  613. self.sliding_window)
  614. attn_bias = [attn_bias]
  615. else:
  616. assert attn_metadata.seq_lens is not None
  617. attn_bias = _make_alibi_bias(self.alibi_slopes,
  618. self.num_kv_heads, query.dtype,
  619. attn_metadata.seq_lens)
  620. _set_attn_bias(attn_metadata, attn_bias, attn_type)
  621. # No alibi slopes.
  622. # TODO: Too many view operations. Let's try to reduce
  623. # them in the future for code readability.
  624. if self.alibi_slopes is None:
  625. # Add the batch dimension.
  626. query = query.unsqueeze(0)
  627. key = key.unsqueeze(0)
  628. value = value.unsqueeze(0)
  629. out = xops.memory_efficient_attention_forward(
  630. query,
  631. key,
  632. value,
  633. attn_bias=attn_bias[0],
  634. p=0.0,
  635. scale=self.scale)
  636. return out.view_as(original_query)
  637. # Attention with alibi slopes.
  638. # FIXME: Because xformers does not support dynamic sequence
  639. # lengths with custom attention bias, we process each prompt one by
  640. # one. This is inefficient, especially when we have many short prompts.
  641. assert attn_metadata.seq_lens is not None
  642. output = torch.empty_like(original_query)
  643. start = 0
  644. for i, seq_len in enumerate(attn_metadata.seq_lens):
  645. end = start + seq_len
  646. out = xops.memory_efficient_attention_forward(
  647. query[None, start:end],
  648. key[None, start:end],
  649. value[None, start:end],
  650. attn_bias=attn_bias[i],
  651. p=0.0,
  652. scale=self.scale)
  653. # TODO: Unnecessary copy. Optimize.
  654. output[start:end].copy_(out.view_as(original_query[start:end]))
  655. start += seq_len
  656. return output
  657. def _make_alibi_bias(
  658. alibi_slopes: torch.Tensor,
  659. num_kv_heads: int,
  660. dtype: torch.dtype,
  661. seq_lens: List[int],
  662. ) -> List[AttentionBias]:
  663. attn_biases: List[AttentionBias] = []
  664. for seq_len in seq_lens:
  665. bias = torch.arange(seq_len, dtype=dtype)
  666. # NOTE: HF uses
  667. # `bias = bias[None, :].repeat(seq_len, 1)`
  668. # here. We find that both biases give the same results, but
  669. # the bias below more accurately follows the original ALiBi
  670. # paper.
  671. # Calculate a matrix where each element represents ith element- jth
  672. # element.
  673. bias = bias[None, :] - bias[:, None]
  674. padded_len = (seq_len + 7) // 8 * 8
  675. num_heads = alibi_slopes.shape[0]
  676. bias = torch.empty(
  677. 1, # batch size
  678. num_heads,
  679. seq_len,
  680. padded_len,
  681. device=alibi_slopes.device,
  682. dtype=dtype,
  683. )[:, :, :, :seq_len].copy_(bias)
  684. bias.mul_(alibi_slopes[:, None, None])
  685. if num_heads != num_kv_heads:
  686. bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
  687. attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
  688. return attn_biases