  1. """
  2. Tests:
  3. * E2E test of Encoder attention + Decoder self-attention +
  4. Encoder/decoder cross-attention (collectively
  5. "encoder/decoder attention")
  6. """
  7. from typing import NamedTuple, Optional
  8. import pytest
  9. import torch
  10. from aphrodite.attention import (Attention, AttentionBackend,
  11. AttentionMetadata, AttentionType)
  12. from aphrodite.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
  13. from aphrodite.attention.selector import (
  14. _Backend, global_force_attn_backend_context_manager)
  15. from aphrodite.common.utils import is_hip
  16. from tests.kernels.utils import *
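
# Illustrative invocation (a sketch; assumes this file lives at
# tests/kernels/test_encoder_decoder_attn.py, alongside tests.kernels.utils):
#
#   pytest tests/kernels/test_encoder_decoder_attn.py
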
# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]

HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]

# Narrow test-cases for unsupported-scenario tests
HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]]


class TestPoint(NamedTuple):
    """
    Encapsulates the attributes which define a single invocation
    of the test_e2e_enc_dec_attn() test

    Attributes:
        num_heads: The number of heads in the model.
        head_size: Head dimension
        backend_name: Name of the backend framework used.
        batch_size: Number of samples per batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Maximum sequence length for the decoder.
        max_enc_seq_len: Maximum sequence length for the encoder.
        num_blocks: Number of blocks in the model.
    """

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: int
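
# A minimal illustrative sketch (not used by the tests below): one TestPoint
# matching a parametrization of test_e2e_enc_dec_attn(), with num_blocks
# fixed at 4096 as in the tests themselves.
#
#   example_pt = TestPoint(num_heads=16, head_size=64,
#                          backend_name=_Backend.XFORMERS.name,
#                          batch_size=16, block_size=16,
#                          max_dec_seq_len=128, max_enc_seq_len=128,
#                          num_blocks=4096)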


class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that
    (1) attn automatically selects an attention backend
        based on platform info & a set of canned heuristics
    (2) attn_backend is thus *not the same backend instance* used by attn,
        but rather it is intended to be a *different instance* of the
        *same backend class*; it is assumed that the user of TestResources
        will leverage attn_backend for the purpose of constructing
        backend-compatible attention metadata instances

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn_backend: implementation of the abstract attention interface
                    using a particular kernel library, i.e. XFormers
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention
    '''

    scale: float
    attn_backend: AttentionBackend
    attn: Attention
    kv_cache: torch.Tensor


def _make_test_resources(test_pt: TestPoint) -> TestResources:
    '''
    Build key components for performing encoder/decoder attention test.

    Note that
    (1) The Attention instance constructed here automatically selects
        an attention backend class based on platform info & a set of canned
        heuristics, so
    (2) The attention backend instance constructed here is thus *not
        the same backend instance* used by attn, but rather it is
        intended to be a *different instance* of the *same backend class*;
        therefore,
    (3) This function requires that test_pt.backend_name matches the backend
        class that Attention will automatically select when it is constructed.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: num_heads, head_size, num_blocks,
               block_size, backend_name

    Returns:

    * TestResources data structure.
    '''
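    # Attention softmax scale factor: 1/sqrt(head_size), matching the
    # TestResources.scale attribute documented above.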
    scale = float(1.0 / (test_pt.head_size**0.5))
    attn_backend = make_backend(test_pt.backend_name)
    attn = Attention(
        test_pt.num_heads,
        test_pt.head_size,
        scale=scale,
    )
    if test_pt.num_blocks is None or test_pt.num_heads is None:
        # Caller does not require a KV cache
        return TestResources(scale, attn_backend, attn, None)

    # Construct KV cache
    kv_cache = make_kv_cache(test_pt.num_blocks,
                             test_pt.num_heads,
                             test_pt.head_size,
                             test_pt.block_size,
                             device=CUDA_DEVICE)
    return TestResources(scale, attn_backend, attn, kv_cache)


def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Set up test vectors & data structures for encoder attention test.

    A triplet of synthetic query/key/value tensors is constructed.
    Given this is an encoder attention test, the key & value
    sequences will have the same length as the corresponding queries.

    The query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate an ideal output tensor.

    Encoder inference does not populate the KV cache, therefore
    no KV cache memory mapping is constructed.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_q_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field

    Returns:

    * PhaseTestParameters data structure comprising (1) packed query/key/value
      tensors, (2) the ideal output of attention computed using a naive
      implementation, and (3) KVCache field set to None
    '''

    (
        num_heads,
        head_size,
        _,
        batch_size,
        _,
        _,
        max_q_seq_len,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Make test tensors
    qkv_in, _, _ = make_qkv(batch_size,
                            max_q_seq_len,
                            max_kv_seq_len,
                            num_heads,
                            head_size,
                            attn_type=AttentionType.ENCODER,
                            device=CUDA_DEVICE)

    # Compute correct answer using naive non-causal attention
    # implementation
    ideal_output = ref_masked_attention(qkv_in.query,
                                        qkv_in.key,
                                        qkv_in.value,
                                        scale=scale,
                                        q_seq_lens=qkv_in.q_seq_lens,
                                        kv_seq_lens=qkv_in.kv_seq_lens)

    packed_ideal_output, _ = pack_tensor(ideal_output,
                                         qkv_in.q_seq_lens,
                                         device=CUDA_DEVICE)

    packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)

    return PhaseTestParameters(
        PackedQKVO(packed_qkv, packed_ideal_output),
        None  # No KV cache
    )


def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Set up test vectors & data structures for self-attention test.

    A triplet of synthetic query/key/value tensors is constructed ("baseline"
    query/key/value). Given this is a self-attention test, the key & value
    sequences will have the same length as the corresponding queries.

    "Prefill" query/key/value tensors are derived by masking out the last
    value in each baseline query/key/value. These tensors are used to test
    prefill & populate the KV cache for a subsequent decode test.

    "Decode" query/key/value tensors are derived by extracting *only* the
    last value from each baseline query/key/value (i.e. the complement of the
    prefill tensors.) These tensors are used to test decode, conditional on
    the KV cache being populated during the prefill test.

    The baseline query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate a "baseline" ideal output
    tensor. This tensor is split into the "prefill" ideal output tensor (all
    but the last element of each output sequence) and the "decode" ideal
    output tensor (*only* the last element of each output sequence); the
    "prefill" and "decode" ideal output tensors can be used to validate the
    prefill and decode test results, respectively.

    This function also constructs the self-attention KV cache memory mapping
    (slot mapping and block table), ensuring that the block table starts at
    block_base_addr.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_q_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: decoder self-attention block-table base address

    Returns:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
           query/key/value tensors
    * Prefill-phase decoder self-attention PhaseTestParameters data structure,
      including (1) packed (number_of_tokens x num_heads x head_size)
      query/key/value tensors along with (2) ideal attention output
      computed using a naive implementation, and (3) memory-mapping data
      structures appropriate for prefill phase.
    * Decode-phase decoder self-attention PhaseTestParameters data structure,
      including (1) packed (number_of_tokens x num_heads x head_size)
      query/key/value tensors along with (2) ideal attention output
      computed using a naive implementation, and (3) memory-mapping data
      structures appropriate for decode phase.
    * max_block_idx: max physical address in decoder self-attention block-table
      (intended to be used as the base address for the encoder/decoder
      cross-attention block-table, which is not constructed in this function)
    '''

    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_q_seq_len,
        _,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Build test tensors
    (
        qkv,
        prefill_qkv,
        decode_qkv,
    ) = make_qkv(batch_size,
                 max_q_seq_len,
                 max_kv_seq_len,
                 num_heads,
                 head_size,
                 attn_type=AttentionType.DECODER,
                 device=CUDA_DEVICE)

    # Compute correct answer using naive attention implementation
    # with causal attention mask
    causal_mask = make_causal_mask(max_q_seq_len,
                                   max_kv_seq_len).to(CUDA_DEVICE)
    ideal_output = ref_masked_attention(qkv.query,
                                        qkv.key,
                                        qkv.value,
                                        scale=scale,
                                        custom_mask=causal_mask,
                                        q_seq_lens=qkv.q_seq_lens,
                                        kv_seq_lens=qkv.kv_seq_lens)

    # Split out the prefill- & decode-phase ideal answers & pack them
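    # For each batch element, the prefill ideal output covers query positions
    # [0, prefill_q_seq_len) while the decode ideal output is the single
    # position at index prefill_q_seq_len (the final token of the sequence).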
    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_qkv.q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1 for _ in range(batch_size)],
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for decoder self-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with entries for prompt tokens
    #
    # Decode-phase:
    #
    # * Block-tables tensor with minimum number of blocks
    #   required by total num. tokens in the entirety of all sequences
    #   (including both prefill & decode)
    # * Slot-mapping with entries for tokens that will be decoded in the
    #   current decode iteration
    #
    # Note: the format described above is simply mirroring what ModelRunner
    # produces

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        slot_mapping_list,
        max_block_idx,
    ) = make_block_tables_slot_mapping(block_size,
                                       qkv.q_seq_lens,
                                       device=CUDA_DEVICE,
                                       block_base_addr=block_base_addr)

    (
        prefill_slot_mapping,
        decode_slot_mapping,
    ) = split_slot_mapping(slot_mapping_list,
                           qkv.q_seq_lens,
                           device=CUDA_DEVICE)

    prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)
    decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    return (
        qkv,
        PhaseTestParameters(  # Prefill test params
            PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode test params
            PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)),
        max_block_idx)


def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Set up test vectors & data structures for cross-attention test.

    Synthetic cross-attention key & value tensors are constructed ("baseline"
    key/value). Given this is a cross-attention test, we assume query tensors
    were already synthesized for a prior self-attention test and will be
    reused for cross-attention. The key & value sequences generated here may
    have a different length than the corresponding queries (as is often the
    case for cross-attention between decoder and encoder sequences.)

    Cross-attention key & value tensors do not grow during autoregressive
    inference; thus this function obtains a single key/value pair suitable for
    both prefill and decode.

    The "baseline" query tensor is received as an argument. The "baseline"
    query/key/value tensors are passed to an ideal reference cross-attention
    implementation to generate a "baseline" ideal output tensor. This tensor
    is split into the "prefill" ideal output tensor (all but the last element
    of each output sequence) and the "decode" ideal output tensor (*only* the
    last element of each output sequence); the "prefill" and "decode" ideal
    output tensors can be used to validate the prefill and decode test
    results, respectively.

    This function also constructs the cross-attention KV cache memory mapping
    (slot mapping and block table), ensuring that the block table starts at
    block_base_addr.

    Arguments:

    * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
                   num_heads x head_size) decoder self-attention inputs;
                   this function relies on the query and q_seq_lens fields
    * encoder_test_params: PhaseTestParameters data structure which was
                           used for encoder inference; KV cache field
                           is not used by this function
    * prefill_decoder_phase_test_params: PhaseTestParameters data structure
                                         used for prefill-phase decoder
                                         self-attention; all fields
                                         including KV cache required
    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_q_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: encoder/decoder cross-attention block-table base
                       address

    Returns:

    * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
      structure, including (1) packed
      (number_of_tokens x num_heads x head_size) query/key/value tensors
      along with (2) ideal attention output computed using a naive
      implementation, and (3) memory-mapping data structures appropriate
      for prefill phase.
    * Decode-phase encoder/decoder cross-attention PhaseTestParameters data
      structure, including (1) packed
      (number_of_tokens x num_heads x head_size) query/key/value tensors
      along with (2) ideal attention output computed using a naive
      implementation, and (3) memory-mapping data structures appropriate
      for decode phase.
    '''

    assert encoder_test_params.packed_qkvo.packed_qkv is not None
    assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None

    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_decoder_seq_len,
        max_encoder_seq_len,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    decoder_query = decoder_qkv.query
    decoder_seq_lens = decoder_qkv.q_seq_lens
    encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
    prefill_q_seq_lens = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)
    assert prefill_q_seq_lens is not None

    (
        cross_kv,
        _,
        _,
    ) = make_qkv(batch_size,
                 max_decoder_seq_len,
                 max_encoder_seq_len,
                 num_heads,
                 head_size,
                 force_kv_seq_lens=encoder_seq_lens,
                 attn_type=AttentionType.ENCODER_DECODER,
                 device=CUDA_DEVICE)

    ideal_output = ref_masked_attention(decoder_query,
                                        cross_kv.key,
                                        cross_kv.value,
                                        scale=scale,
                                        q_seq_lens=decoder_seq_lens,
                                        kv_seq_lens=cross_kv.kv_seq_lens)

    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1 for _ in range(batch_size)],
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for encoder/decoder cross-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Whereas decoder self-attention extracts relationships between
    # equal-length Q/K/V sequences, which mutually grow in length
    # with each decoded token, cross-attention relates the Q sequence
    # - which grows with each new decoded token - to fixed-length
    # K and V sequences derived from the encoder hidden states.
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with as many entries as there are tokens in the encoder
    #   prompt.
    #
    # Decode-phase:
    #
    # * Block-tables tensor with minimum number of blocks to
    #   accommodate K & V tensors which are equal in length
    #   to the encoder prompt length
    # * Empty slot-mapping tensor (since K & V are fixed in size,
    #   new decoded tokens are not KV-cached and require no slot-
    #   mapping)
    #
    # Note: the format above is simply an extension of what ModelRunner
    # produces for decoder-only models

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        prefill_slot_mapping_list,
        _,
    ) = make_block_tables_slot_mapping(block_size,
                                       cross_kv.kv_seq_lens,
                                       block_base_addr=block_base_addr,
                                       device=CUDA_DEVICE)

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
                                                  device=CUDA_DEVICE)

    # Packed key/value (query is already provided)
    packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)

    return (
        PhaseTestParameters(  # Prefill-phase test params
            PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode-phase test params
            PackedQKVO(None, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)))


def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder attention.

    attn.forward() is passed attn_type=AttentionType.ENCODER in order
    to configure the kernel invocation for encoder attention.

    Requires attn_metadata.num_decode_tokens == 0
    (there is no encoder execution in the decode-phase).

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for encoder/decoder self-attention

    Returns:

    * Attention.forward() applied to packed {query,key,value} and
      attn_metadata
    '''
    assert attn_metadata.num_decode_tokens == 0
    attn_type = AttentionType.ENCODER
    packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    return attn.forward(packed_qkv.query,
                        packed_qkv.key,
                        packed_qkv.value,
                        None,
                        attn_metadata,
                        attn_type=attn_type)


def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run decoder self-attention test.

    attn.forward() is passed attn_type=AttentionType.DECODER
    in order to configure the kernel invocation for decoder self-attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the kv_cache
                  and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for decoder self-attention
                     (contains KV cache memory-mapping)

    Returns:

    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    attn_type = AttentionType.DECODER
    attn = test_rsrcs.attn
    kv_cache = test_rsrcs.kv_cache
    packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    return attn.forward(packed_qkv.query,
                        packed_qkv.key,
                        packed_qkv.value,
                        kv_cache,
                        attn_metadata,
                        attn_type=attn_type)


def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder/decoder cross-attention test.

    Via PhaseTestParameters data structures, consumes the same query utilized
    for decoder self-attention, plus a key/value specific to cross-attention.

    If cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
    is None, this reflects that in decode-phase cross-attention there
    is no growth in the key and value tensors.

    attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
    in order to configure the kernel invocation for encoder/decoder cross-
    attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the kv_cache
                  and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query field
    * cross_test_params: encoder/decoder PhaseTestParameters data structure;
                         this function relies on the packed
                         (number_of_tokens x num_heads x head_size)
                         key/value fields
    * attn_metadata: attention metadata for encoder/decoder cross-attention

    Returns:

    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    assert decoder_test_params.packed_qkvo.packed_qkv is not None
    attn_type = AttentionType.ENCODER_DECODER
    attn = test_rsrcs.attn
    kv_cache = test_rsrcs.kv_cache
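    # Decode-phase invocations pass no new cross-attention key/value tensors
    # (cross-attention K/V do not grow during decode); key and value become
    # None and the kernel reads K/V from the KV cache instead.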
    if cross_test_params is None:
        key = None
        value = None
    else:
        cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
        key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
        value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
    return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
                        key,
                        value,
                        kv_cache,
                        attn_metadata,
                        attn_type=attn_type)


@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
):
    '''
    End-to-end encoder-only attention test:

    * Construct fake test vectors for encoder attention
    * Construct (1) attention metadata structure with prefill-phase
      encoder attention, and (2) an analogous attention metadata
      structure but for decode-phase
    * Test & validate encoder attention against ideal output

    No KV cache is required for encoder-only attention.

    Note on ROCm/HIP: currently encoder/decoder models are not supported on
    AMD GPUs, therefore this test is simply skipped if is_hip().

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):

        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test
        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096)

        # Attention scale factor, attention backend instance, attention
        # wrapper instance, KV cache init
        test_rsrcs = _make_test_resources(test_pt)

        # Construct encoder attention test params (only used
        # during prefill)
        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

        # Shared prefill metadata structure
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            True,
            None,
            decoder_test_params=None,
            encoder_test_params=enc_test_params,
            cross_test_params=None,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention
        enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
            test_rsrcs.attn, enc_test_params, prephase_attn_metadata))

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)


@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
) -> None:
    '''
    End-to-end encoder/decoder test:

    * Construct fake test vectors for (1) encoder attention,
      (2) decoder self-attention, and (3) encoder/decoder cross-attention
    * Construct (1) attention metadata structure with self- and cross-attention
      attributes for prefill-phase, and (2) an analogous attention metadata
      structure but for decode-phase
    * Test attention steps in the following order

        * Encoder attention
        * Prefill self-attention
        * Prefill cross-attention
        * Decode self-attention
        * Decode cross-attention
        * Besides being reflective of realistic use-cases, this order would
          exacerbate any accidental overlap in the self-/cross-attention
          block tables, which one hopes to avoid

    * Validate output correctness against ideal reference attention
      implementation

    Block tables are constructed such that cross-attention KV cache is in a
    higher, non-intersecting address-space than self-attention KV cache.

    Self- and cross-attention share the same query tensor but not the K/V
    tensors. Self-attention K/Vs must have the same seq len as Q while
    cross-attention K/Vs are allowed to differ in seq len, as is often the
    case for cross-attention.

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Note on ROCm/HIP: currently encoder/decoder models are not supported on
    AMD GPUs, therefore this test is simply skipped if is_hip().

    Note on metadata: there is a single attention metadata structure shared by
    all prefill-phase attention operations (encoder, decoder, enc/dec cross),
    and a single one shared by all decode-phase attention operations
    (decoder & enc/dec cross.) This is intended to reflect the behavior
    of EncoderDecoderModelRunner, which constructs a single attention metadata
    structure for each prefill or decode run. A realistic scenario would rely
    on the attention backend to utilize the appropriate attention metadata
    fields according to the value of attn_metadata.attention_type. Thus,
    this test is organized so as to confirm that the backend-under-test can
    handle a shared prefill attention metadata structure & a shared decode
    attention metadata structure.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):

        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test
        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096)

        # Attention scale factor, attention backend instance, attention
        # wrapper instance, KV cache init
        test_rsrcs = _make_test_resources(test_pt)

        # Construct encoder attention test params (only used
        # during prefill)
        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

        # Construct Decoder self-attention prefill-phase & decode-phase
        # test params, including query/key/value tensors, decoder
        # self-attention memory-mapping. cross_block_base_addr is the
        # uppermost address in the decoder self-attention block-table,
        # i.e. a base address which the encoder/decoder cross-attention
        # block-table may build downward toward.
        (
            dec_qkv,
            prephase_dec_test_params,
            decphase_dec_test_params,
            cross_block_base_addr,
        ) = _decoder_attn_setup(test_pt, test_rsrcs)

        # Construct encoder/decoder cross-attention prefill-phase
        # & decode-phase test params, including key/value tensors,
        # cross-attention memory-mapping
        (
            prephase_cross_test_params,
            decphase_cross_test_params,
        ) = _enc_dec_cross_attn_setup_reuses_query(
            dec_qkv,
            enc_test_params,
            prephase_dec_test_params,
            test_pt,
            test_rsrcs,
            block_base_addr=cross_block_base_addr)

        # Shared prefill metadata structure
        assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            True,
            prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
            decoder_test_params=prephase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=prephase_cross_test_params,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention
        enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
                                                       enc_test_params,
                                                       prephase_attn_metadata)

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)

        # PREFILL: decoder self-attention test
        prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
            test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)

        # - Is prefill decoder self-attention correct?
        assert_actual_matches_ideal(prephase_dec_test_params,
                                    prephase_dec_pckd_act_out)

        # PREFILL: encoder/decoder cross-attention test
        prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
            test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
            prephase_attn_metadata)

        # - Is prefill encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(prephase_cross_test_params,
                                    prephase_cross_pckd_act_out)

        # DECODE: build decode-phase attention metadata
        decphase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            False,
            dec_qkv.q_seq_lens,
            decoder_test_params=decphase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=decphase_cross_test_params,
            device=CUDA_DEVICE)

        # DECODE: decoder self-attention test
        decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
            test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)

        # - Is decode-phase decoder self-attention correct?
        assert_actual_matches_ideal(decphase_dec_test_params,
                                    decphase_dec_pckd_act_out)

        # DECODE: encoder/decoder cross-attention test
        decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
            test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)

        # - Is decode-phase encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(decphase_cross_test_params,
                                    decphase_cross_pckd_act_out)