utils.py
"""Kernel test utils"""

import itertools
import random
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch

from aphrodite.attention import (AttentionBackend, AttentionMetadata,
                                 AttentionType)
from aphrodite.attention.backends.xformers import XFormersBackend
from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
                                    make_tensor_with_pad)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/values and their sequence lengths.

    Attributes:

    * {query,key,value}: unpacked (batch_size x padded_seq_len x
                         num_heads x head_size) attention inputs
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output

    Attributes:

    * qkv: unpacked (batch_size x padded_seq_len x
           num_heads x head_size) attention inputs
    * ideal_output: unpacked (batch_size x padded_seq_len x
                    num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs

    Attributes:

    * {query,key,value}: packed (number_of_tokens x num_heads
                         x head_size) attention inputs
    * q_start_loc_list: list of query start locations within packed tensor
    * kv_start_loc_list: shared list of key/value start locations within
                         packed tensor
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output

    Attributes:

    * packed_qkv: packed (number_of_tokens x num_heads
                  x head_size) attention inputs
    * ideal_output: packed (number_of_tokens x num_heads
                    x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

    * block_tables: KV cache block tables
    * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross)

    Attributes:

    * packed_qkvo: packed (number_of_tokens x num_heads
                   x head_size) attention inputs & known-correct
                   output
    * kv_mmap: KV cache memory mapping, specific to this test phase &
               attention scenario
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]
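

# A minimal sketch of how the structures above nest (shapes and lengths are
# hypothetical, chosen only for illustration): a packed QKV triplet and its
# ideal output are wrapped in PackedQKVO, which PhaseTestParameters pairs
# with an optional KV cache memory map.
def _example_phase_test_parameters() -> None:
    packed = PackedQKVInputs(query=torch.rand(6, 2, 8),
                             key=torch.rand(6, 2, 8),
                             value=torch.rand(6, 2, 8),
                             q_start_loc_list=[0, 4, 6],
                             kv_start_loc_list=[0, 4, 6],
                             q_seq_lens=[4, 2],
                             kv_seq_lens=[4, 2])
    packed_qkvo = PackedQKVO(packed_qkv=packed,
                             ideal_output=torch.rand(6, 2, 8))
    params = PhaseTestParameters(packed_qkvo=packed_qkvo, kv_mmap=None)
    assert params.packed_qkvo.packed_qkv.q_seq_lens == [4, 2]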


def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Create a matrix where entry (i, j) is 1 if j > i (strictly above the
    # diagonal), i.e. the key positions a causal query may not attend to
    mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
    # Replace 1 with float('-inf') and 0 with 0.0
    mask = mask.masked_fill(mask == 1,
                            float('-inf')).masked_fill(mask == 0, 0.0)
    return mask
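

# A minimal usage sketch: the mask for three query tokens attending to four
# key positions. Entries strictly above the diagonal are -inf, so softmax
# assigns them zero attention weight.
def _example_make_causal_mask() -> None:
    mask = make_causal_mask(3, 4)
    assert mask.shape == (3, 4)
    assert mask[0, 1] == float('-inf')  # future position is masked
    assert mask[1, 1] == 0.0  # present/past positions are unmasked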


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Temporarily override the environment variable that selects the attention
    backend, using pytest monkeypatch to ensure that the env var gets
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, e.g.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out
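

# A minimal usage sketch: reference attention over a small padded batch,
# masking out padding via the per-sequence lengths. Shapes here are
# hypothetical, chosen only for illustration.
def _example_ref_masked_attention() -> None:
    batch_size, padded_seq_len, num_heads, head_size = 2, 4, 2, 8
    query = torch.rand(batch_size, padded_seq_len, num_heads, head_size)
    key = torch.rand(batch_size, padded_seq_len, num_heads, head_size)
    value = torch.rand(batch_size, padded_seq_len, num_heads, head_size)
    out = ref_masked_attention(query,
                               key,
                               value,
                               scale=head_size**-0.5,
                               q_seq_lens=[4, 3],
                               kv_seq_lens=[4, 3])
    assert out.shape == query.shape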


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of
    k/v seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self-, or enc/dec cross-attention. For
      cross-attention (ENCODER_DECODER) the query seqlen may differ from the
      key/value seqlen at each batch index; otherwise query/key/value seqlens
      match at each batch index
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless a non-cross attn_type ties key/value seqlens
      to the query seqlens

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence
      offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]
    kv_seq_lens = None
    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    query = torch.rand(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    prefill_query = torch.zeros(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    prefill_key = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    prefill_value = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    decode_query = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)
    decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
    decode_value = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        # Zero out the padding region of the baseline tensors
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        # Prefill tensors hold all but the last token of each sequence
        prefill_query[bdx, 0:(q_seq_len - 1), :, :] = \
            query[bdx, 0:(q_seq_len - 1), :, :]
        prefill_key[bdx, 0:(kv_seq_len - 1), :, :] = \
            key[bdx, 0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len - 1), :, :] = \
            value[bdx, 0:(kv_seq_len - 1), :, :]

        # Decode tensors hold exactly the last token of each sequence
        decode_query[bdx, :, :, :] = \
            query[bdx, (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = \
            key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = \
            value[bdx, (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of QKV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))
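

# A minimal usage sketch (CPU device, hypothetical sizes): build the
# baseline/prefill/decode QKV triplets and check the seq-len bookkeeping
# relating the three views.
def _example_make_qkv() -> None:
    qkv, prefill_qkv, decode_qkv = make_qkv(batch_size=2,
                                            max_q_seq_len=8,
                                            max_kv_seq_len=8,
                                            num_heads=4,
                                            head_size=16,
                                            device="cpu")
    assert qkv.query.shape == (2, 8, 4, 16)
    # Prefill holds all but the last token of each sequence; decode holds
    # exactly the last token.
    assert prefill_qkv.q_seq_lens == [n - 1 for n in qkv.q_seq_lens]
    assert decode_qkv.q_seq_lens == [1, 1]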


def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns:

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size),
                                device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):
        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list
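

# A minimal usage sketch: pack a padded (batch_size x padded_seq_len x
# num_heads x head_size) tensor into a token-contiguous tensor plus the
# per-sequence start offsets.
def _example_pack_tensor() -> None:
    unpacked = torch.rand(2, 5, 3, 4)
    packed, start_locs = pack_tensor(unpacked, seq_lens=[5, 2], device="cpu")
    assert packed.shape == (7, 3, 4)
    assert start_locs == [0, 5, 7]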


def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns:

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        return XFormersBackend()
    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build an attention metadata
    structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
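

# A minimal usage sketch: a zero-initialized fake KV cache on CPU, with
# hypothetical sizes chosen only for illustration.
def _example_make_kv_cache() -> None:
    kv_cache = make_kv_cache(num_blocks=4,
                             num_heads=2,
                             head_size=8,
                             block_size=16,
                             device="cpu")
    assert kv_cache.shape == (2, 4, 16 * 2 * 8)
    assert torch.all(kv_cache == 0.0)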


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size
    '''
    return (num_tokens + block_size) // block_size
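

# For example, with block_size = 16:
#   _num_tokens_to_min_blocks(17, 16) == (17 + 16) // 16 == 2
#   _num_tokens_to_min_blocks(32, 16) == (32 + 16) // 16 == 3
# i.e. floor(num_tokens / block_size) + 1: one block is always provisioned
# past the last completely filled block.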


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)


def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:

    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th
    entry from each of the N subsequences in the slot mapping (i.e. omitting
    the decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))
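

# A minimal usage sketch: two post-decode sequences with lengths 3 and 2.
# The last slot of each sequence belongs to the decoded token; the rest
# belong to the prefill prompt.
def _example_split_slot_mapping() -> None:
    slots = [10, 11, 12, 20, 21]
    prefill_slots, decode_slots = split_slot_mapping(slots, [3, 2],
                                                     device="cpu")
    assert prefill_slots.tolist() == [10, 11, 20]
    assert decode_slots.tolist() == [12, 21]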


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the minimum number
    of required KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    Then the minimum KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the blocktable mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block table for each sequence
    * slot_mapping_list: slot mapping for each sequence
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
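

# A minimal usage sketch (CPU device, hypothetical sizes): fake block tables
# and a slot mapping for two sequences with a block size of 4.
def _example_make_block_tables_slot_mapping() -> None:
    block_tables, slot_mapping, max_block_idx = \
        make_block_tables_slot_mapping(block_size=4,
                                       seq_lens=[5, 3],
                                       device="cpu")
    # One slot-mapping entry per token, across all sequences
    assert len(slot_mapping) == 5 + 3
    # Block addresses count downward from block_base_addr + total blocks
    assert max_block_idx == (5 + 4) // 4 + (3 + 4) // 4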


def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    The encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params).

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support a decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never
      both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping.
    # decoder_test_params is None signals an encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None
               if decoder_test_params is None else decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals an encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = (
            encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens)
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))
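

# A minimal decoder-only prefill sketch (assumes the xformers backend is
# installed, since make_backend currently only knows "XFORMERS"): build
# metadata for a pure prefill batch with no KV-cache memory mapping.
def _example_make_test_metadata() -> None:
    attn_backend = make_backend(STR_XFORMERS_ATTN_VAL)
    attn_metadata = make_test_metadata(attn_backend,
                                       is_prompt=True,
                                       seq_lens=[4, 7],
                                       decoder_test_params=None,
                                       device="cpu")
    assert attn_metadata.num_prefills == 2
    assert attn_metadata.num_prefill_tokens == 11
    assert attn_metadata.num_decode_tokens == 0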


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Assert that observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    torch.testing.assert_close(ideal_output,
                               output_under_test.view_as(ideal_output))


def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    return torch.library.opcheck(
        op,
        args,
        kwargs,
        test_utils=test_utils,
        raise_exception=raise_exception) if cond else {}