xformers.py

  1. """Attention layer with xFormers and PagedAttention."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (AttentionBias,
  7. BlockDiagonalCausalMask,
  8. BlockDiagonalMask,
  9. LowerTriangularMaskWithTensorBias)
  10. from aphrodite.attention.backends.abstract import (AttentionBackend,
  11. AttentionImpl,
  12. AttentionMetadata,
  13. AttentionType)
  14. from aphrodite.attention.backends.utils import CommonMetadataBuilder
  15. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  16. PagedAttentionMetadata)


class XFormersBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "xformers"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        return XFormersImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return XFormersMetadata

    @staticmethod
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        return XFormersMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
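        # NOTE: The shape returned by PagedAttention is the layout assumed by
        # XFormersImpl.forward() below, documented there as
        # [2, num_blocks, block_size * num_kv_heads * head_size]
        # (one plane for keys and one for values, later separated by
        # PagedAttention.split_kv_cache).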
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for XFormersBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
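    # Illustrative example (made-up sizes): if 26 tokens of a prompt were
    # computed in iteration N-1 and 16 new tokens are scheduled in iteration
    # N, then context_len = 26, query_len = 16 and seq_len = 42. For a pure
    # decode step, query_len is 1 and seq_len = context_len + 1.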

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # FIXME: It is for flash attn.
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]] = None

    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # Self-attention prefill/decode metadata cache
    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt
        # when alibi slopes are used, due to a limitation of the
        # xFormers API.
        # These will not appear in the __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
        self.cross_attn_bias: Optional[List[AttentionBias]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            # Recover cached prefill-phase attention
            # metadata structure
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        # Construct & cache prefill-phase attention metadata structure
        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            # Recover cached decode-phase attention
            # metadata structure
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])

        # Construct & cache decode-phase attention metadata structure
        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata
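
    # Illustrative example (made-up sizes) for the two properties above: in a
    # chunked-prefill batch with num_prefills=2, num_prefill_tokens=9 and
    # num_decode_tokens=3, prefill_metadata views the first 2 rows of the
    # per-sequence tensors and the first 9 entries of slot_mapping, while
    # decode_metadata views the remaining rows and entries.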


def _get_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_type: AttentionType,
) -> Optional[AttentionBias]:
    '''
    Extract the appropriate attention bias from the attention metadata,
    according to the attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * Appropriate attention bias value given the attention type
    '''

    if attn_type == AttentionType.DECODER:
        return attn_metadata.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return attn_metadata.encoder_attn_bias
    else:
        # attn_type == AttentionType.ENCODER_DECODER
        return attn_metadata.cross_attn_bias


def _set_attn_bias(
    attn_metadata: XFormersMetadata,
    attn_bias: List[Optional[AttentionBias]],
    attn_type: AttentionType,
) -> None:
    '''
    Update the appropriate attention bias field of the attention metadata,
    according to the attention type.

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention
    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention
    '''

    if attn_type == AttentionType.DECODER:
        attn_metadata.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        attn_metadata.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        attn_metadata.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):

    _metadata_cls = XFormersMetadata


class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "XFormers does not support attention logits soft capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
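        # For grouped-query attention this is the group size, e.g. (made-up
        # numbers) 32 query heads with 8 KV heads give num_queries_per_kv = 4;
        # for MQA it equals num_heads, and for plain MHA it is 1.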

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the Aphrodite default generally.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                    value_cache,
                                                    updated_slot_mapping,
                                                    self.kv_cache_dtype,
                                                    k_scale, v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive the token count from the encoder metadata and treat
            # all tokens as 100% prefill tokens
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO: this Triton kernel has a regression (it broke) when
                # handling different data types between the KV tensors and an
                # FP8 KV cache; to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    k_scale,
                    v_scale,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened into the `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the Aphrodite default generally.

        Returns:
            shape = [num_prefill_tokens, num_heads, head_size]
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is
            # different from the spec in the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])
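            # Illustrative shapes (made-up sizes): with 9 prefill tokens,
            # 8 KV heads, 4 queries per KV head and head_size 128, query
            # becomes [9, 8, 4, 128] and key/value are expanded views of the
            # same shape; adding the batch dim below yields [1, 9, 8, 4, 128],
            # the [B, M, G, H, K] layout noted above.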

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME: This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default enc/dec cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens,
                        attn_metadata.encoder_seq_lens)
                elif attn_type == AttentionType.ENCODER:
                    assert attn_metadata.encoder_seq_lens is not None

                    # Default encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens)
                else:
                    assert attn_metadata.seq_lens is not None

                    # Default decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens)
                    if self.sliding_window is not None:
                        attn_bias = attn_bias.make_local_attention(
                            self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO: Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME: Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short
        # prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO: Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE: HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element represents the jth element
        # minus the ith element.
        bias = bias[None, :] - bias[:, None]
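        # Illustrative example: for seq_len = 3 the matrix is
        #     [[ 0,  1,  2],
        #      [-1,  0,  1],
        #      [-2, -1,  0]]
        # i.e. entry (i, j) holds j - i, the relative distance that the
        # per-head ALiBi slope scales below.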

        padded_len = (seq_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
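

# Illustrative usage sketch (not part of this module's API; the names and
# sizes below are made up for the example):
#
#     impl_cls = XFormersBackend.get_impl_cls()      # -> XFormersImpl
#     impl = impl_cls(num_heads=32,
#                     head_size=128,
#                     scale=128**-0.5,
#                     num_kv_heads=8,
#                     alibi_slopes=None,
#                     sliding_window=None,
#                     kv_cache_dtype="auto")
#     kv_shape = XFormersBackend.get_kv_cache_shape(num_blocks=1024,
#                                                   block_size=16,
#                                                   num_kv_heads=8,
#                                                   head_size=128)
#
# The metadata passed to impl.forward() would be produced by
# XFormersMetadataBuilder (see get_builder_cls() above).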