- """Kernel test utils"""
- import itertools
- import random
- from numbers import Number
- from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
- Union)
- import pytest
- import torch
- from aphrodite.attention import (AttentionBackend, AttentionMetadata,
- AttentionType)
- from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
- make_tensor_with_pad)
- # For now, disable "test_aot_dispatch_dynamic" since there are some
- # bugs related to this test in PyTorch 2.4.
- DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
- "test_schema",
- "test_autograd_registration",
- "test_faketensor",
- )
- ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
- "test_schema",
- "test_autograd_registration",
- "test_faketensor",
- "test_aot_dispatch_dynamic",
- )
class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/value and their sequence lengths.

    Attributes:

    * {query,key,value}: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output

    Attributes:

    * qkv: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * ideal_output: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs

    Attributes:

    * {query,key,value}: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * q_start_loc_list: list of query start locations within packed tensor
    * kv_start_loc_list: shared list of key/value start locations within
      packed tensor
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output

    Attributes:

    * packed_qkv: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * ideal_output: packed (number_of_tokens x num_heads
      x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

    * block_tables: KV cache block tables
    * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross)

    Attributes:

    * packed_qkvo: packed (number_of_tokens x num_heads
      x head_size) attention inputs & known-correct output
    * kv_mmap: KV cache memory mapping, specific to this test phase &
      attention scenario
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]


def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Create a matrix with 1 at every (i, j) with j > i, i.e. at every
    # "future" position which causal attention must not attend to
    mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
    # Replace 1 with float('-inf') and 0 with 0.0 to obtain an additive mask
    mask = mask.masked_fill(mask == 1,
                            float('-inf')).masked_fill(mask == 0, 0.0)
    return mask


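# Illustrative sketch (not part of the original utils): for a 3 x 3 mask,
# make_causal_mask() zeroes the diagonal and lower triangle and places
# -inf strictly above the diagonal, so softmax assigns zero weight to
# future positions.
def _example_make_causal_mask() -> torch.Tensor:
    mask = make_causal_mask(3, 3)
    # mask == [[0., -inf, -inf],
    #          [0.,   0., -inf],
    #          [0.,   0.,   0.]]
    return mask

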
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Override the environment variable indicating the Aphrodite backend
    temporarily, using pytest monkeypatch to ensure that the env vars get
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out


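# Illustrative sketch (not part of the original utils): combining
# ref_masked_attention() with make_causal_mask() reproduces causal
# self-attention over a padded batch. The shapes and seq lens below are
# arbitrary; rows at or beyond each q_seq_len are padding, and would be
# dropped by pack_tensor() in an actual test.
def _example_ref_causal_attention() -> torch.Tensor:
    batch_size, max_seq_len, num_heads, head_size = 2, 4, 3, 8
    query = torch.rand(batch_size, max_seq_len, num_heads, head_size)
    key = torch.rand_like(query)
    value = torch.rand_like(query)
    return ref_masked_attention(query,
                                key,
                                value,
                                scale=head_size**-0.5,
                                custom_mask=make_causal_mask(
                                    max_seq_len, max_seq_len),
                                q_seq_lens=[4, 3],
                                kv_seq_lens=[4, 3])

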
def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of
    k/v seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self-, or encoder/decoder cross-attention.
      For cross-attention (AttentionType.ENCODER_DECODER), key/value seqlens
      are drawn independently of the query seqlens (as is often the case for
      cross-attention); otherwise query/key/value seqlens match at each batch
      index, and max_kv_seq_len is unused
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless key/value seqlens are tied to the query
      seqlens (see attn_type)

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence
      offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]
    kv_seq_lens = None
    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    query = torch.rand(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    prefill_query = torch.zeros(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    prefill_key = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    prefill_value = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    decode_query = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)
    decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
    decode_value = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero out the padding region of each baseline sequence
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill = all but the last token of each sequence
        prefill_query[bdx, 0:(q_seq_len - 1), :, :] = \
            query[bdx, 0:(q_seq_len - 1), :, :]
        prefill_key[bdx, 0:(kv_seq_len - 1), :, :] = \
            key[bdx, 0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len - 1), :, :] = \
            value[bdx, 0:(kv_seq_len - 1), :, :]

        # Decode = only the last token of each sequence
        decode_query[bdx, :, :, :] = query[bdx,
                                           (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of QKV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))


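# Illustrative sketch (not part of the original utils): the three
# QKVInputs triplets returned by make_qkv() satisfy the invariant that
# each decode query token is the final valid token of the corresponding
# baseline sequence; parameter values below are arbitrary.
def _example_make_qkv(device: Union[torch.device, str] = "cpu") -> None:
    qkv, prefill_qkv, decode_qkv = make_qkv(batch_size=2,
                                            max_q_seq_len=8,
                                            max_kv_seq_len=8,
                                            num_heads=2,
                                            head_size=4,
                                            device=device,
                                            attn_type=AttentionType.DECODER)
    for bdx, q_seq_len in enumerate(qkv.q_seq_lens):
        assert prefill_qkv.q_seq_lens[bdx] == q_seq_len - 1
        assert torch.equal(decode_qkv.query[bdx, 0],
                           qkv.query[bdx, q_seq_len - 1])

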
def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns:

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list


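# Illustrative sketch (not part of the original utils): packing a padded
# 2 x 3 x num_heads x head_size batch with seq_lens [2, 3] yields a
# 5 x num_heads x head_size tensor and start locations [0, 2, 5].
def _example_pack_tensor() -> None:
    unpacked = torch.rand(2, 3, 1, 4)
    packed, start_locs = pack_tensor(unpacked, seq_lens=[2, 3], device="cpu")
    assert packed.shape == (5, 1, 4)
    assert start_locs == [0, 2, 5]

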
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns:

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from aphrodite.attention.backends.xformers import XFormersBackend
        return XFormersBackend()
    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata
    structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence (currently always None)
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len,
            max_seq_len, seq_start_loc, encoder_seq_lens_tensor,
            max_encoder_seq_len)


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: Optional[float] = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements; if None, the
      cache is left with random values

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache


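# Illustrative sketch (not part of the original utils): the fake cache
# stacks the key cache (index 0) and value cache (index 1) along the
# leading dimension, flattening each block to
# block_size * num_heads * head_size elements.
def _example_make_kv_cache() -> None:
    kv_cache = make_kv_cache(num_blocks=4,
                             num_heads=2,
                             head_size=8,
                             block_size=16,
                             device="cpu")
    assert kv_cache.shape == (2, 4, 16 * 2 * 8)
    assert torch.all(kv_cache == 0.0)  # default_val fills the whole cache

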
def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the number of KV cache blocks to provision for num_tokens tokens,
    given block_size. Note that (num_tokens + block_size) // block_size rounds
    up, and allocates one block beyond the strict minimum whenever block_size
    divides num_tokens evenly.
    '''
    return (num_tokens + block_size) // block_size


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)


def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:

    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th
    entry from each of the N subsequences in the slot mapping (i.e. omitting
    the decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot
    mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: CPU or CUDA device

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))


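# Illustrative sketch (not part of the original utils): with post-decode
# seq_lens [3, 2], the final slot of each sequence (entries 12 and 21)
# moves to the decode-phase mapping, and the remainder stays in the
# prefill-phase mapping.
def _example_split_slot_mapping() -> None:
    prefill, decode = split_slot_mapping([10, 11, 12, 20, 21],
                                         seq_lens=[3, 2],
                                         device="cpu")
    assert prefill.tolist() == [10, 11, 20]
    assert decode.tolist() == [12, 21]

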
def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens, the number of provisioned KV cache
    blocks is

    num_blocks = (num_tokens + block_size) // block_size

    (one more than the strict minimum when block_size divides num_tokens.)
    Then the provisioned KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the blocktable mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block table for each sequence, padded
    * slot_mapping_list: slot mapping for all sequences, as a flat list
    * max_block_idx: the highest block address within this block table
    '''

    # Provision the KV cache blocks for each sequence
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)


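# Illustrative sketch (not part of the original utils): for block_size 2
# and a single 3-token sequence, two blocks are provisioned counting down
# from the top of the range ([2, 1]), and each slot address is
# block_number * block_size + offset-within-block.
def _example_make_block_tables_slot_mapping() -> None:
    block_tables, slot_mapping, max_block_idx = \
        make_block_tables_slot_mapping(block_size=2,
                                       seq_lens=[3],
                                       device="cpu")
    assert slot_mapping == [4, 5, 2]  # tokens 0,1 in block 2; token 2 in 1
    assert max_block_idx == 2
    # Block tables are padded out by block_table_pad_tokens (10) columns
    assert block_tables.shape == (1, 2 + 10)

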
def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support the decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never
      both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
      this function requires the kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
      this function requires the encoder query sequence lengths field.
      If None, encoder query sequence lengths are treated as None
    * cross_test_params: enc/dec cross-attention test params;
      this function requires the kv_mmap field. If None, KV cache memory
      map data structures are treated as None

    Return:

    * AttentionMetadata structure
    '''
    # Decoder self-attention memory mapping
    # decoder_test_params is None signals encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None if decoder_test_params is None else
               decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))
    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))


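# Illustrative sketch (not part of the original utils): minimal
# decoder-only prefill-phase metadata, wiring together the helpers above.
# Assumes the xFormers backend is importable on this platform; the
# PackedQKVO placeholder is acceptable here only because
# make_test_metadata() reads nothing but kv_mmap from the decoder params.
def _example_make_test_metadata(
        device: Union[torch.device, str] = "cpu") -> AttentionMetadata:
    seq_lens = [3, 5]
    block_tables, slot_mapping_list, _ = make_block_tables_slot_mapping(
        block_size=16, seq_lens=seq_lens, device=device)
    decoder_params = PhaseTestParameters(
        PackedQKVO(None, torch.empty(0)),
        KVMemoryMap(block_tables,
                    maybe_make_long_tensor(slot_mapping_list, device)))
    return make_test_metadata(make_backend(STR_XFORMERS_ATTN_VAL),
                              is_prompt=True,
                              seq_lens=seq_lens,
                              decoder_test_params=decoder_params,
                              device=device)

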
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Assert that observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    torch.testing.assert_close(ideal_output,
                               output_under_test.view_as(ideal_output))


def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    return torch.library.opcheck(
        op,
        args,
        kwargs,
        test_utils=test_utils,
        raise_exception=raise_exception) if cond else {}
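

# Illustrative sketch (not part of the original utils): opcheck() wraps
# torch.library.opcheck so tests can validate a custom-op registration.
# The op below is a hypothetical stand-in, registered only for this
# example (re-running this function would fail on duplicate registration).
def _example_opcheck() -> Dict[str, str]:
    @torch.library.custom_op("aphrodite_test::scale_double",
                             mutates_args=())
    def scale_double(x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    @scale_double.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)

    return opcheck(scale_double, (torch.rand(4), ),
                   test_utils=DEFAULT_OPCHECK_TEST_UTILS)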
|