""" Tests: * E2E test of Encoder attention + Decoder self-attention + Encoder/decoder cross-attention (collectively "encoder/decoder attention") """ from typing import NamedTuple, Optional import pytest import torch from aphrodite.attention import (Attention, AttentionBackend, AttentionMetadata, AttentionType) from aphrodite.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from aphrodite.attention.selector import ( _Backend, global_force_attn_backend_context_manager) from aphrodite.common.utils import is_hip from tests.kernels.utils import * # List of support backends for encoder/decoder models LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128] # Narrow teest-cases for unsupported-scenario # tests HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] class TestPoint(NamedTuple): """ Encapsulates the attributes which define a single invocation of the test_e2e_enc_dec_attn() test Attributes: num_heads: The number of heads in the model. head_size: Head dimension backend_name: Name of the backend framework used. batch_size: Number of samples per batch. block_size: Size of each block of data processed. max_dec_seq_len: Maximum sequence length for the decoder. max_enc_seq_len: Maximum sequence length for the encoder. num_blocks: Number of blocks in the model. """ num_heads: int head_size: int backend_name: str batch_size: int block_size: int max_dec_seq_len: int max_enc_seq_len: int num_blocks: int class TestResources(NamedTuple): ''' Encapsulates key components for performing an encoder/decoder attention test Note that (1) attn automatically selects an attention backend based on platform info & a set of canned heuristics (2) attn_backend is thus *not the same backend instance* used by attn, but rather it is intended to be a *different instance* of the *same backend class*; it is assumed that the user of TestResources will leverage attn_backend for the purpose of constructing backend-compatible attention metadata instances Attributes: * scale: 1/sqrt(d) scale factor for attn * attn_backend: implementatino of abstraction attention interface using a particular kernel library i.e. XFormers * attn: Attention layer instance * kv_cache: shared key/value cache for all attention ''' scale: float attn_backend: AttentionBackend attn: Attention kv_cache: torch.Tensor def _make_test_resources(test_pt: TestPoint, ) -> TestResources: ''' Build key components for performing encoder/decoder attention test. Note that (1) The Attention instance constructed here, automatically selects an attention backend class based on platform info & a set of canned heuristics, so (2) The attention backend instance constructed here is thus *not the same backend instance* used by attn, but rather it is intended to be a *different instance* of the *same backend class*; therefore, (3) This function requires that test_pt.backend_name matches the backend class that Attention will automatically select when it is constructed. Arguments: * test_pt: TestPoint data structure; this function relies on the following fields: num_heads, head_size, num_blocks, block_size, backend_name Returns: * TestResources data structure. 
def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
    '''
    Build key components for performing an encoder/decoder attention test.

    Note that
    (1) The Attention instance constructed here automatically selects
        an attention backend class based on platform info & a set of canned
        heuristics, so
    (2) The attention backend instance constructed here is thus *not
        the same backend instance* used by attn, but rather is intended
        to be a *different instance* of the *same backend class*; therefore,
    (3) This function requires that test_pt.backend_name matches the backend
        class that Attention will automatically select when it is
        constructed.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
      following fields: num_heads, head_size, num_blocks, block_size,
      backend_name

    Returns:

    * TestResources data structure.
    '''

    scale = float(1.0 / (test_pt.head_size**0.5))
    attn_backend = make_backend(test_pt.backend_name)
    attn = Attention(
        test_pt.num_heads,
        test_pt.head_size,
        scale=scale,
    )
    if test_pt.num_blocks is None or test_pt.num_heads is None:
        # Caller does not require a KV cache
        return TestResources(scale, attn_backend, attn, None)

    # Construct KV cache
    kv_cache = make_kv_cache(test_pt.num_blocks,
                             test_pt.num_heads,
                             test_pt.head_size,
                             test_pt.block_size,
                             device=CUDA_DEVICE)
    return TestResources(scale, attn_backend, attn, kv_cache)


def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Set up test vectors & data structures for encoder attention test.

    A triplet of synthetic query/key/value tensors are constructed.
    Given this is an encoder attention test, the key & value
    sequences will have the same length as the corresponding queries.

    The query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate an ideal output tensor.

    Encoder inference does not populate the KV cache, therefore
    no KV cache memory mapping is constructed.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
      following fields: batch_size, num_heads, head_size, max_enc_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
      scale field

    Returns:

    * PhaseTestParameters data structure comprising (1) packed
      query/key/value tensors, (2) the ideal output of attention computed
      using a naive implementation, and (3) KVCache field set to None
    '''

    (
        num_heads,
        head_size,
        _,
        batch_size,
        _,
        _,
        max_q_seq_len,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Make test tensors
    qkv_in, _, _ = make_qkv(batch_size,
                            max_q_seq_len,
                            max_kv_seq_len,
                            num_heads,
                            head_size,
                            attn_type=AttentionType.ENCODER,
                            device=CUDA_DEVICE)

    # Compute correct answer using naive non-causal attention
    # implementation
    ideal_output = ref_masked_attention(qkv_in.query,
                                        qkv_in.key,
                                        qkv_in.value,
                                        scale=scale,
                                        q_seq_lens=qkv_in.q_seq_lens,
                                        kv_seq_lens=qkv_in.kv_seq_lens)

    packed_ideal_output, _ = pack_tensor(ideal_output,
                                         qkv_in.q_seq_lens,
                                         device=CUDA_DEVICE)

    packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)

    return PhaseTestParameters(
        PackedQKVO(packed_qkv, packed_ideal_output),
        None  # No KV cache
    )
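
# The "packed" tensors referenced throughout these helpers collapse a padded
# (batch_size x padded_seq_len x num_heads x head_size) tensor into a
# (total_num_tokens x num_heads x head_size) tensor holding only valid
# tokens. A minimal sketch of that conversion, assuming the layout above
# (the real logic lives in pack_tensor/pack_qkv in tests.kernels.utils):
def _pack_padded_sketch(padded: torch.Tensor, seq_lens) -> torch.Tensor:
    # seq_lens: per-sequence valid lengths. Keep each sequence's valid
    # prefix, then concatenate along the token dimension.
    return torch.cat(
        [padded[bdx, :seq_len] for bdx, seq_len in enumerate(seq_lens)],
        dim=0)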
def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Set up test vectors & data structures for self-attention test.

    A triplet of synthetic query/key/value tensors are constructed
    ("baseline" query/key/value). Given this is a self-attention test,
    the key & value sequences will have the same length as the
    corresponding queries.

    "Prefill" query/key/value tensors are derived by masking out the last
    value in each baseline query/key/value. These tensors are used to test
    prefill & populate KV cache for a subsequent decode test.

    "Decode" query/key/value tensors are derived by extracting *only* the
    last value from each baseline query/key/value (i.e. the complement of
    the prefill tensors.) These tensors are used to test decode,
    conditional on the kv cache being populated during the prefill test.

    The baseline query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate a "baseline" ideal output
    tensor. This tensor is split into the "prefill" ideal output tensor
    (all but the last element of each output sequence) and the "decode"
    ideal output tensor (*only* the last element of each output sequence);
    the "prefill" and "decode" ideal output tensors can be used to validate
    the prefill and decode test results, respectively.

    This function also constructs the self-attention KV cache memory
    mapping (slot mapping and block table), ensuring that the block table
    starts at block_base_addr.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
      following fields: batch_size, num_heads, head_size, block_size,
      max_dec_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
      scale field
    * block_base_addr: decoder self-attention block-table base address

    Returns:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      query/key/value tensors
    * Prefill-phase decoder self-attention PhaseTestParameters data
      structure, including (1) packed (number_of_tokens x num_heads
      x head_size) query/key/value tensors along with (2) ideal attention
      output computed using a naive implementation, and (3) memory-mapping
      data structures appropriate for prefill phase.
    * Decode-phase decoder self-attention PhaseTestParameters data
      structure, including (1) packed (number_of_tokens x num_heads
      x head_size) query/key/value tensors along with (2) ideal attention
      output computed using a naive implementation, and (3) memory-mapping
      data structures appropriate for decode phase.
    * max_block_idx: max physical address in decoder self-attention
      block-table (intended to be used as the base address for the
      encoder/decoder cross-attention block-table, which is not constructed
      in this function)
    '''

    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_q_seq_len,
        _,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Build test tensors
    (
        qkv,
        prefill_qkv,
        decode_qkv,
    ) = make_qkv(batch_size,
                 max_q_seq_len,
                 max_kv_seq_len,
                 num_heads,
                 head_size,
                 attn_type=AttentionType.DECODER,
                 device=CUDA_DEVICE)

    # Compute correct answer using naive attention implementation
    # with causal attention mask
    causal_mask = make_causal_mask(max_q_seq_len,
                                   max_kv_seq_len).to(CUDA_DEVICE)

    ideal_output = ref_masked_attention(qkv.query,
                                        qkv.key,
                                        qkv.value,
                                        scale=scale,
                                        custom_mask=causal_mask,
                                        q_seq_lens=qkv.q_seq_lens,
                                        kv_seq_lens=qkv.kv_seq_lens)

    # Split out the prefill- & decode-phase ideal answers & pack them
    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_qkv.q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(
        decode_ideal_output,
        [1 for _ in range(batch_size)],
        device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for decoder self-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with entries for prompt tokens
    #
    # Decode-phase:
    # * Block-tables tensor with the minimum number of blocks
    #   required by the total num. tokens in the entirety of all sequences
    #   (including both prefill & decode)
    # * Slot-mapping with entries for tokens that will be decoded in the
    #   current decode iteration
    #
    # Note: the format described above simply mirrors what ModelRunner
    # produces

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        slot_mapping_list,
        max_block_idx,
    ) = make_block_tables_slot_mapping(block_size,
                                       qkv.q_seq_lens,
                                       device=CUDA_DEVICE,
                                       block_base_addr=block_base_addr)

    (
        prefill_slot_mapping,
        decode_slot_mapping,
    ) = split_slot_mapping(slot_mapping_list,
                           qkv.q_seq_lens,
                           device=CUDA_DEVICE)

    prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)

    decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    return (
        qkv,
        PhaseTestParameters(  # Prefill test params
            PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode test params
            PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)),
        max_block_idx)
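
# The block tables and slot mappings built above follow the usual paged-KV
# addressing convention: a block table maps a sequence's logical block
# numbers to physical block numbers, and each token's cache slot is derived
# from it. A minimal sketch of that address computation (an illustration of
# the convention only; the real construction lives in
# make_block_tables_slot_mapping/split_slot_mapping in tests.kernels.utils):
def _token_slot_sketch(block_table, token_idx: int, block_size: int) -> int:
    # Physical block holding this token, then the offset within that block
    physical_block = block_table[token_idx // block_size]
    return physical_block * block_size + token_idx % block_size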
def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Set up test vectors & data structures for cross-attention test.

    A triplet of synthetic cross-attention key/value tensors are
    constructed ("baseline" key/value). Given this is a cross-attention
    test, we assume query tensors were already synthesized for a prior
    self-attention test and will be reused for cross-attention. The key &
    value sequences generated here may have a different length than the
    corresponding queries (as is often the case for cross-attention between
    decoder and encoder sequences.)

    Cross-attention key & value tensors do not grow during autoregressive
    inference; thus this function obtains a single key/value pair suitable
    for both prefill and decode.

    The "baseline" query tensor is received as an argument. The
    "baseline" query/key/value tensors are passed to an ideal reference
    cross-attention implementation to generate a "baseline" ideal output
    tensor. This tensor is split into the "prefill" ideal output tensor
    (all but the last element of each output sequence) and the "decode"
    ideal output tensor (*only* the last element of each output sequence);
    the "prefill" and "decode" ideal output tensors can be used to validate
    the prefill and decode test results, respectively.

    This function also constructs the cross-attention KV cache memory
    mapping (slot mapping and block table), ensuring that the block table
    starts at block_base_addr.
    Arguments:

    * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
      num_heads x head_size) decoder self-attention inputs; this function
      relies on the query and q_seq_lens fields
    * encoder_test_params: PhaseTestParameters data structure which was
      used for encoder inference; the KV cache field is not used by this
      function
    * prefill_decoder_phase_test_params: PhaseTestParameters data structure
      used for prefill-phase decoder self-attention; all fields
      including KV cache are required
    * test_pt: TestPoint data structure; this function relies on the
      following fields: batch_size, num_heads, head_size, block_size,
      max_dec_seq_len, max_enc_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
      scale field
    * block_base_addr: cross-attention block-table base address

    Returns:

    * Prefill-phase encoder/decoder cross-attention PhaseTestParameters
      data structure, including (1) packed (number_of_tokens x num_heads
      x head_size) query/key/value tensors along with (2) ideal attention
      output computed using a naive implementation, and (3) memory-mapping
      data structures appropriate for prefill phase.
    * Decode-phase encoder/decoder cross-attention PhaseTestParameters data
      structure, including (1) packed (number_of_tokens x num_heads
      x head_size) query/key/value tensors along with (2) ideal attention
      output computed using a naive implementation, and (3) memory-mapping
      data structures appropriate for decode phase.
    '''

    assert encoder_test_params.packed_qkvo.packed_qkv is not None
    assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv \
        is not None

    (
        num_heads,
        head_size,
        _,
        batch_size,
        block_size,
        max_decoder_seq_len,
        max_encoder_seq_len,
        _,
    ) = test_pt

    scale = test_rsrcs.scale

    decoder_query = decoder_qkv.query
    decoder_seq_lens = decoder_qkv.q_seq_lens
    encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
    prefill_q_seq_lens = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)

    assert prefill_q_seq_lens is not None

    (
        cross_kv,
        _,
        _,
    ) = make_qkv(batch_size,
                 max_decoder_seq_len,
                 max_encoder_seq_len,
                 num_heads,
                 head_size,
                 force_kv_seq_lens=encoder_seq_lens,
                 attn_type=AttentionType.ENCODER_DECODER,
                 device=CUDA_DEVICE)

    ideal_output = ref_masked_attention(decoder_query,
                                        cross_kv.key,
                                        cross_kv.value,
                                        scale=scale,
                                        q_seq_lens=decoder_seq_lens,
                                        kv_seq_lens=cross_kv.kv_seq_lens)

    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_q_seq_lens,
                                                 device=CUDA_DEVICE)
    decode_packed_ideal_output, _ = pack_tensor(
        decode_ideal_output,
        [1 for _ in range(batch_size)],
        device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for encoder/decoder cross-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Whereas decoder self-attention extracts relationships between
    # equal-length Q/K/V sequences, which mutually grow in length
    # with each decoded token, cross-attention relates the Q sequence
    # - which grows with each new decoded token - to fixed-length
    # K and V sequences derived from the encoder hidden states.
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with as many entries as there are tokens in the
    #   encoder prompt.
    #
    # Decode-phase:
    # * Block-tables tensor with the minimum number of blocks to
    #   accommodate K & V tensors which are equal in length
    #   to the encoder prompt length
    # * Empty slot-mapping tensor (since K & V are fixed in size,
    #   new decoded tokens are not KV-cached and require no slot-
    #   mapping)
    #
    # Note: the format above is simply an extension of what ModelRunner
    # produces for decoder-only models

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        prefill_slot_mapping_list,
        _,
    ) = make_block_tables_slot_mapping(block_size,
                                       cross_kv.kv_seq_lens,
                                       block_base_addr=block_base_addr,
                                       device=CUDA_DEVICE)

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
                                                  device=CUDA_DEVICE)

    # Packed key/value (query is already provided)
    packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)

    return (
        PhaseTestParameters(  # Prefill-phase test params
            PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode-phase test params
            PackedQKVO(None, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)))
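
# The two setup paths above differ in masking as well as in memory mapping:
# decoder self-attention scores a causal (q_len x q_len) matrix, while
# cross-attention scores the full (q_len x encoder_len) matrix unmasked. A
# minimal sketch of the additive causal mask shape, assuming the same
# semantics as make_causal_mask from tests.kernels.utils:
def _causal_mask_sketch(q_len: int, kv_len: int) -> torch.Tensor:
    # -inf strictly above the diagonal, 0 on/below: position i may attend
    # to keys 0..i only
    return torch.triu(torch.full((q_len, kv_len), float("-inf")),
                      diagonal=1)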
def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder attention.

    attn.forward() is passed attn_type=AttentionType.ENCODER in order
    to configure the kernel invocation for encoder attention.

    Requires attn_metadata.num_decode_tokens == 0
    (There is no encoder execution in the decode-phase.)

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
      this function relies on the packed (number_of_tokens x num_heads
      x head_size) query/key/value fields
    * attn_metadata: attention metadata for encoder/decoder-self attention

    Returns:

    * Attention.forward() applied to packed {query,key,value} &
      attn_metadata
    '''
    assert attn_metadata.num_decode_tokens == 0
    attn_type = AttentionType.ENCODER
    packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    return attn.forward(packed_qkv.query,
                        packed_qkv.key,
                        packed_qkv.value,
                        None,
                        attn_metadata,
                        attn_type=attn_type)
def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run decoder self-attention test.

    attn.forward() is passed attn_type=AttentionType.DECODER in order
    to configure the kernel invocation for decoder self-attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
      kv_cache and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
      this function relies on the packed (number_of_tokens x num_heads
      x head_size) query/key/value fields
    * attn_metadata: attention metadata for decoder-self attention
      (contains KV cache memory-mapping)

    Returns:

    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    attn_type = AttentionType.DECODER
    attn = test_rsrcs.attn
    kv_cache = test_rsrcs.kv_cache
    packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    return attn.forward(packed_qkv.query,
                        packed_qkv.key,
                        packed_qkv.value,
                        kv_cache,
                        attn_metadata,
                        attn_type=attn_type)


def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    Run encoder/decoder cross-attention test.

    Via PhaseTestParameters data structures, consumes the same query
    utilized for decoder self-attention, plus a key/value specific to
    cross-attention.

    If cross_test_params is None or
    cross_test_params.packed_qkvo.packed_qkv is None, this reflects that
    in decode-phase cross-attention there is no growth in the key and
    value tensors.

    attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER
    in order to configure the kernel invocation for encoder/decoder cross-
    attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
      kv_cache and attn (Attention wrapper instance) fields
    * decoder_test_params: decoder PhaseTestParameters data structure;
      this function relies on the packed (number_of_tokens x num_heads
      x head_size) query field
    * cross_test_params: encoder/decoder PhaseTestParameters data
      structure; this function relies on the packed (number_of_tokens x
      num_heads x head_size) key/value fields
    * attn_metadata: attention metadata for encoder/decoder cross-attention

    Returns:

    * Attention.forward() applied to packed_{query,key,value}, kv_cache
      & attn_metadata
    '''
    assert decoder_test_params.packed_qkvo.packed_qkv is not None
    attn_type = AttentionType.ENCODER_DECODER
    attn = test_rsrcs.attn
    kv_cache = test_rsrcs.kv_cache
    if cross_test_params is None:
        key = None
        value = None
    else:
        cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
        key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
        value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
    return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
                        key,
                        value,
                        kv_cache,
                        attn_metadata,
                        attn_type=attn_type)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
):
    '''
    End-to-end encoder-only attention test:

    * Construct fake test vectors for encoder attention
    * Construct an attention metadata structure with prefill-phase
      encoder attention attributes
    * Test & validate encoder attention against the ideal output

    No KV cache is required for encoder-only attention.

    Note on ROCm/HIP: currently encoder/decoder models are not supported
    on AMD GPUs, therefore this test is simply skipped if is_hip().

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):

        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test
        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096)

        # Attention scale factor, attention backend instance, attention
        # wrapper instance, KV cache init
        test_rsrcs = _make_test_resources(test_pt)

        # Construct encoder attention test params (only used
        # during prefill)
        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

        # Shared prefill metadata structure
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            True,
            None,
            decoder_test_params=None,
            encoder_test_params=enc_test_params,
            cross_test_params=None,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention

        enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
            test_rsrcs.attn, enc_test_params, prephase_attn_metadata))

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
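
# Note on block-table address spaces in the e2e test below (a usage sketch
# using this file's own helpers): _decoder_attn_setup() returns
# max_block_idx, the highest physical block index in the self-attention
# block tables, and the test feeds it to
# _enc_dec_cross_attn_setup_reuses_query() as block_base_addr, e.g.
#
#   (dec_qkv, _, _, cross_block_base_addr) = _decoder_attn_setup(...)
#   _enc_dec_cross_attn_setup_reuses_query(
#       ..., block_base_addr=cross_block_base_addr)
#
# so the self- and cross-attention mappings occupy disjoint block ranges
# and an error in one cannot silently corrupt the other's cache region.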
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
) -> None:
    '''
    End-to-end encoder/decoder test:

    * Construct fake test vectors for (1) encoder attention,
      (2) decoder self-attention, and (3) encoder/decoder cross-attention
    * Construct (1) an attention metadata structure with self- and
      cross-attention attributes for prefill-phase, and (2) an analogous
      attention metadata structure for decode-phase
    * Test attention steps in the following order:

        * Encoder attention
        * Prefill self-attention
        * Prefill cross-attention
        * Decode self-attention
        * Decode cross-attention

      Besides being reflective of realistic use-cases, this order would
      exacerbate any accidental overlap in the self-/cross-attention block
      tables, which one hopes to avoid
    * Validate output correctness against an ideal reference attention
      implementation

    Block tables are constructed such that the cross-attention KV cache is
    in a higher, non-intersecting address-space than the self-attention
    KV cache.

    Self- and cross-attention share the same query tensor but not the K/V
    tensors. Self-attention K/Vs must have the same seq len as Q, while
    cross-attention K/Vs are allowed to differ in seq len, as is often the
    case for cross-attention.

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Note on ROCm/HIP: currently encoder/decoder models are not supported
    on AMD GPUs, therefore this test is simply skipped if is_hip().

    Note on metadata: there is a single attention metadata structure shared
    by all prefill-phase attention operations (encoder, decoder, enc/dec
    cross), and a single one shared by all decode-phase attention
    operations (decoder & enc/dec cross.) This is intended to reflect the
    behavior of EncoderDecoderModelRunner, which constructs a single
    attention metadata structure for each prefill or decode run. A
    realistic scenario would rely on the attention backend to utilize the
    appropriate attention metadata fields according to the value of
    attn_metadata.attention_type. Thus, this test is organized so as to
    confirm that the backend-under-test can handle a shared prefill
    attention metadata structure & a shared decode attention metadata
    structure.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):

        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test
        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096)

        # Attention scale factor, attention backend instance, attention
        # wrapper instance, KV cache init
        test_rsrcs = _make_test_resources(test_pt)

        # Construct encoder attention test params (only used
        # during prefill)
        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

        # Construct Decoder self-attention prefill-phase & decode-phase
        # test params, including query/key/value tensors and decoder
        # self-attention memory-mapping. cross_block_base_addr is the
        # uppermost address in the decoder self-attention block-table,
        # i.e. a base address which the encoder/decoder cross-attention
        # block-table may build downward toward.
        (
            dec_qkv,
            prephase_dec_test_params,
            decphase_dec_test_params,
            cross_block_base_addr,
        ) = _decoder_attn_setup(test_pt, test_rsrcs)

        # Construct encoder/decoder cross-attention prefill-phase
        # & decode-phase test params, including key/value tensors and
        # cross-attention memory-mapping
        (
            prephase_cross_test_params,
            decphase_cross_test_params,
        ) = _enc_dec_cross_attn_setup_reuses_query(
            dec_qkv,
            enc_test_params,
            prephase_dec_test_params,
            test_pt,
            test_rsrcs,
            block_base_addr=cross_block_base_addr)

        # Shared prefill metadata structure
        assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            True,
            prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
            decoder_test_params=prephase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=prephase_cross_test_params,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention

        enc_pckd_act_out = _run_encoder_attention_test(
            test_rsrcs.attn, enc_test_params, prephase_attn_metadata)

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
        # PREFILL: decoder self-attention test

        prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
            test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)

        # - Is prefill decoder self-attention correct?
        assert_actual_matches_ideal(prephase_dec_test_params,
                                    prephase_dec_pckd_act_out)

        # PREFILL: encoder/decoder cross-attention test

        prephase_cross_pckd_act_out = \
            _run_encoder_decoder_cross_attention_test(
                test_rsrcs, prephase_dec_test_params,
                prephase_cross_test_params, prephase_attn_metadata)

        # - Is prefill encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(prephase_cross_test_params,
                                    prephase_cross_pckd_act_out)

        # DECODE: build decode-phase attention metadata

        decphase_attn_metadata: AttentionMetadata = make_test_metadata(
            test_rsrcs.attn_backend,
            False,
            dec_qkv.q_seq_lens,
            decoder_test_params=decphase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=decphase_cross_test_params,
            device=CUDA_DEVICE)

        # DECODE: decoder self-attention test

        decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
            test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)

        # - Is decode-phase decoder self-attention correct?
        assert_actual_matches_ideal(decphase_dec_test_params,
                                    decphase_dec_pckd_act_out)

        # DECODE: encoder/decoder cross-attention test

        decphase_cross_pckd_act_out = \
            _run_encoder_decoder_cross_attention_test(
                test_rsrcs, decphase_dec_test_params, None,
                decphase_attn_metadata)

        # - Is decode-phase encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(decphase_cross_test_params,
                                    decphase_cross_pckd_act_out)