utils.py

  1. """Kernel test utils"""
  2. import itertools
  3. import random
  4. from numbers import Number
  5. from typing import Any, List, NamedTuple, Optional, Tuple, Union
  6. import pytest
  7. import torch
  8. from aphrodite.attention import (AttentionBackend, AttentionMetadata,
  9. AttentionType)
  10. from aphrodite.attention.backends.xformers import XFormersBackend
  11. from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
  12. make_tensor_with_pad)
class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/values and their sequence lengths.

    Attributes:

    * {query,key,value}: unpacked (batch_size x padded_seq_len x
                         num_heads x head_size) attention inputs
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output

    Attributes:

    * qkv: unpacked (batch_size x padded_seq_len x
           num_heads x head_size) attention inputs
    * ideal_output: unpacked (batch_size x padded_seq_len x
                    num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs

    Attributes:

    * {query,key,value}: packed (number_of_tokens x num_heads
                         x head_size) attention inputs
    * q_start_loc_list: list of query start locations within packed tensor
    * kv_start_loc_list: shared list of key/value start locations within
                         packed tensor
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output

    Attributes:

    * packed_qkv: packed (number_of_tokens x num_heads
                  x head_size) attention inputs
    * ideal_output: packed (number_of_tokens x num_heads
                    x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

    * block_tables: KV cache block tables
    * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross)

    Attributes:

    * packed_qkvo: packed (number_of_tokens x num_heads
                   x head_size) attention inputs & known-correct
                   output
    * kv_mmap: KV cache memory mapping, specific to this test phase &
               attention scenario
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]

def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    return None if _list is None else max(_list)

def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Create a matrix with 1s strictly above the main diagonal (j > i),
    # i.e. at the positions a causal mask must block
    mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
    # Replace 1 with float('-inf') and 0 with 0.0
    mask = mask.masked_fill(mask == 1,
                            float('-inf')).masked_fill(mask == 0, 0.0)
    return mask

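
# Illustrative example (not used by the tests): for q_max_seq_len=3 and
# kv_max_seq_len=4, make_causal_mask(3, 4) returns
#
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf]]
#
# i.e. query position i may only attend to key positions j <= i.
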
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Temporarily override the environment variable that selects the attention
    backend, using pytest monkeypatch to ensure that the env var gets
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)

def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out

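
# Illustrative usage sketch (toy shapes assumed; not part of the test suite):
#
#   query = torch.rand(2, 5, 4, 8)   # batch_size=2, q_padded_seq_len=5
#   key = torch.rand(2, 7, 4, 8)     # kv_padded_seq_len=7, num_heads=4, head_size=8
#   value = torch.rand(2, 7, 4, 8)
#   out = ref_masked_attention(query, key, value,
#                              scale=1.0 / (8 ** 0.5),
#                              custom_mask=make_causal_mask(5, 7),
#                              q_seq_lens=[5, 3],
#                              kv_seq_lens=[7, 4])
#
# out has shape (2, 5, 4, 8); entries at query positions beyond each sequence's
# unpadded length are padding and are not meaningful.
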
def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens.

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self-, or encoder/decoder cross-attention.
      For ENCODER_DECODER (cross-attention), key/value seqlens may differ from
      query seqlens; otherwise, query/key/value seqlens match at each batch
      index (and max_kv_seq_len is unused)
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless the kv seqlens are matched to the query
      seqlens or forced as described above

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]
    kv_seq_lens = None
    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    query = torch.rand(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    prefill_query = torch.zeros(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    prefill_key = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    prefill_value = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    decode_query = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)
    decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
    decode_value = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero out the padding region beyond each sequence's unpadded length
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill inputs: all but the last token of each sequence
        prefill_query[bdx,
                      0:(q_seq_len - 1), :, :] = query[bdx,
                                                       0:(q_seq_len - 1), :, :]
        prefill_key[bdx,
                    0:(kv_seq_len - 1), :, :] = key[bdx,
                                                    0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len -
                              1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]

        # Decode inputs: only the last token of each sequence
        decode_query[bdx, :, :, :] = query[bdx,
                                           (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of KV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))

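
# Illustrative example (not used by the tests): with batch_size=2 and sampled
# q_seq_lens=[5, 3] (self-attention, so kv_seq_lens=[5, 3] as well), make_qkv
# returns
#
#   * overall inputs with q_seq_lens=[5, 3]
#   * prefill inputs with q_seq_lens=[4, 2] (everything except the final token
#     of each sequence)
#   * decode inputs with q_seq_lens=[1, 1] (only the final token of each
#     sequence)
#
# so the prefill and decode subsets together cover each full sequence.
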
def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns:

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list

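
# Illustrative example: seq_lens=[2, 3] gives start_loc_list=[0, 2, 5]; rows 0:2
# of packed_tensor come from batch element 0 and rows 2:5 from batch element 1,
# so packed_tensor has shape (5, num_heads, head_size).
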
def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns:

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)

def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        return XFormersBackend()
    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")

def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence (currently always None)
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)

def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache

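
# Illustrative example: make_kv_cache(num_blocks=128, num_heads=4, head_size=64,
# block_size=16, device='cpu') returns a zero-filled tensor (default_val=0.0) of
# shape (2, 128, 16 * 4 * 64) = (2, 128, 4096).
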
def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size
    '''
    return (num_tokens + block_size) // block_size

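
# Worked example (illustrative): with block_size=16, num_tokens=20 yields
# (20 + 16) // 16 = 2 blocks; num_tokens=16 yields (16 + 16) // 16 = 2 blocks,
# i.e. a token count that is an exact multiple of block_size is given one
# spare block.
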
def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)

def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:

    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        # All but the last slot of each sequence belongs to the prefill phase
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        # The last slot of each sequence belongs to the decode phase
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))

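
# Worked example (illustrative): slot_mapping_list=[10, 11, 12, 20, 21] with
# post-decode seq_lens=[3, 2] splits into
#
#   * prefill_slot_mapping = [10, 11, 20]  (first K_i entries of each sequence)
#   * decode_slot_mapping  = [12, 21]      (last entry of each sequence)
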
def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the minimum number
    of required KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    Then the minimum KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the blocktable mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block table for sequence
    * slot_mapping_list: slot mapping for sequence
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)

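
# Worked example (illustrative): block_size=2, seq_lens=[3, 2], block_base_addr=0.
# Minimum blocks per sequence are [(3 + 2) // 2, (2 + 2) // 2] = [2, 2], so
# block_base_idx starts at 4 and max_block_idx=4. Sequence 0 is assigned block
# table [4, 3] with slot mapping [8, 9, 6]; sequence 1 is assigned block table
# [2, 1] with slot mapping [4, 5].
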
def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    The encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support a decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
                           this function requires the
                           kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires the encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires the kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping.
    # decoder_test_params is None signals an encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None
               if decoder_test_params is None else decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals an encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # For non-prefix-caching scenarios, context_lens
    # is never needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

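
# Illustrative usage sketch (hypothetical names: `prefill_qkv`,
# `decoder_phase_params`, and `CUDA_DEVICE` are assumptions, not defined in
# this module):
#
#   backend = make_backend(STR_XFORMERS_ATTN_VAL)
#   prefill_attn_metadata = make_test_metadata(
#       backend,
#       is_prompt=True,
#       seq_lens=prefill_qkv.q_seq_lens,
#       decoder_test_params=decoder_phase_params,
#       device=CUDA_DEVICE)
#
# where decoder_phase_params would carry the packed QKV/ideal output plus the
# kv_mmap produced by make_block_tables_slot_mapping and split_slot_mapping.
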
def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Assert that the observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    torch.testing.assert_close(ideal_output,
                               output_under_test.view_as(ideal_output))