"""Kernel test utils"""

import itertools
import random
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch

from aphrodite.attention import (AttentionBackend, AttentionMetadata,
                                 AttentionType)
from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
                                    make_tensor_with_pad)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/values and their sequence lengths.

    Attributes:

    * {query,key,value}: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output.

    Attributes:

    * qkv: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * ideal_output: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs.

    Attributes:

    * {query,key,value}: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * q_start_loc_list: list of query start locations within packed tensor
    * kv_start_loc_list: shared list of key/value start locations within
      packed tensor
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output.

    Attributes:

    * packed_qkv: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * ideal_output: packed (number_of_tokens x num_heads
      x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

    * block_tables: KV cache block tables
    * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross).

    Attributes:

    * packed_qkvo: packed (number_of_tokens x num_heads
      x head_size) attention inputs & known-correct
      output
    * kv_mmap: KV cache memory mapping, specific to this test phase &
      attention scenario
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]
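
# Illustrative sketch (kept as a comment so nothing runs at import time): the
# structures above nest as follows when building per-phase test inputs. All
# shapes and values here are made up for illustration only.
#
#     packed_q = torch.rand(7, 8, 64)    # 7 tokens, 8 heads, head_size 64
#     packed_qkv = PackedQKVInputs(query=packed_q,
#                                  key=packed_q.clone(),
#                                  value=packed_q.clone(),
#                                  q_start_loc_list=[0, 3, 7],
#                                  kv_start_loc_list=[0, 3, 7],
#                                  q_seq_lens=[3, 4],
#                                  kv_seq_lens=[3, 4])
#     qkvo = PackedQKVO(packed_qkv=packed_qkv,
#                       ideal_output=torch.zeros_like(packed_q))
#     phase_params = PhaseTestParameters(packed_qkvo=qkvo, kv_mmap=None)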


def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D int torch.Tensor on `device`.

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert a Python int list to a 1D long torch.Tensor on `device`.

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask.

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Create a matrix where entry (i, j) is 1 if j > i (strictly above the
    # diagonal), i.e. the positions a causal query is *not* allowed to attend to
    mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
    # Replace 1 with float('-inf') and 0 with 0.0
    mask = mask.masked_fill(mask == 1,
                            float('-inf')).masked_fill(mask == 0, 0.0)
    return mask
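
# Worked example (comment-only sketch): make_causal_mask(3, 3) returns the
# additive mask
#
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
#
# so adding it to pre-softmax attention scores zeroes out the probability of
# attending to future positions, e.g.
#
#     scores = torch.randn(3, 3) + make_causal_mask(3, 3)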


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Override the environment variable indicating the Aphrodite backend
    temporarily, using pytest monkeypatch to ensure that the env vars get
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''

    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, i.e.
      causal

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out
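
# Illustrative sketch: computing a causal "golden" reference output for a small
# self-attention problem (shapes and seq lens chosen arbitrarily for the
# example; the custom causal mask broadcasts over batch and heads).
#
#     batch_size, seq_len, num_heads, head_size = 2, 5, 4, 16
#     q = torch.rand(batch_size, seq_len, num_heads, head_size)
#     k = torch.rand(batch_size, seq_len, num_heads, head_size)
#     v = torch.rand(batch_size, seq_len, num_heads, head_size)
#     ideal = ref_masked_attention(q, k, v,
#                                  scale=head_size**-0.5,
#                                  custom_mask=make_causal_mask(seq_len, seq_len),
#                                  q_seq_lens=[5, 3],
#                                  kv_seq_lens=[5, 3])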


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of
    k/v seqlens.

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len
    * num_heads
    * head_size
    * device: CPU or CUDA device
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention; for
      ENCODER_DECODER (cross-attention), key/value seqlens are drawn
      independently of the query seqlens; o/w, query/key/value seqlens match
      at each batch index (and max_kv_seq_len only determines the padded
      length of the key/value tensors)
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless kv seqlens are tied to the query seqlens by
      attn_type or overridden by force_kv_seq_lens

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]
    kv_seq_lens = None
    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    query = torch.rand(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    prefill_query = torch.zeros(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    prefill_key = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    prefill_value = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    decode_query = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)
    decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
    decode_value = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        prefill_query[bdx,
                      0:(q_seq_len - 1), :, :] = query[bdx,
                                                       0:(q_seq_len - 1), :, :]
        prefill_key[bdx,
                    0:(kv_seq_len - 1), :, :] = key[bdx,
                                                    0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len -
                              1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]

        decode_query[bdx, :, :, :] = query[bdx,
                                           (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of KV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))
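
# Illustrative sketch: building baseline/prefill/decode QKV triplets for a
# decoder self-attention test (the argument values are example choices, not
# requirements of this helper).
#
#     qkv, prefill_qkv, decode_qkv = make_qkv(batch_size=4,
#                                             max_q_seq_len=64,
#                                             max_kv_seq_len=64,
#                                             num_heads=8,
#                                             head_size=64,
#                                             device="cuda:0",
#                                             attn_type=AttentionType.DECODER)
#     # decode_qkv.query has shape (4, 1, 8, 64): one token per sequence.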


def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns:

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):

        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list
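
# Worked example (comment-only sketch): packing two padded sequences of lengths
# 3 and 2 yields a (5 x num_heads x head_size) tensor plus the start locations
# [0, 3, 5].
#
#     unpacked = torch.rand(2, 4, 8, 64)  # batch 2, padded seq len 4
#     packed, start_locs = pack_tensor(unpacked, seq_lens=[3, 2], device="cpu")
#     # packed.shape == (5, 8, 64); start_locs == [0, 3, 5]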


def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                            str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns:

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()

    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from aphrodite.attention.backends.xformers import XFormersBackend
        return XFormersBackend()
    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")
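
# Illustrative sketch of how these helpers might be combined inside a test
# function (the monkeypatch fixture and the elided make_test_metadata
# arguments belong to the calling test, not to this module):
#
#     def test_something(monkeypatch: pytest.MonkeyPatch):
#         override_backend_env_variable(monkeypatch, STR_XFORMERS_ATTN_VAL)
#         attn_backend = make_backend(STR_XFORMERS_ATTN_VAL)
#         attn_metadata = make_test_metadata(attn_backend, ...)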


def _make_metadata_tensors(
    seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
           torch.Tensor, Optional[int]]:
    '''
    Build scalar & tensor values required to build attention metadata structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
    '''

    kv_cache = torch.rand(
        (2, num_blocks, block_size * num_heads * head_size)).to(device)
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
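
# Illustrative sketch: a fake cache with 128 blocks of 16 slots for 8 heads of
# size 64 (example values), filled with the default value of 0.0.
#
#     kv_cache = make_kv_cache(num_blocks=128,
#                              num_heads=8,
#                              head_size=64,
#                              block_size=16,
#                              device="cuda:0")
#     # kv_cache.shape == (2, 128, 16 * 8 * 64)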


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the number of KV cache blocks to provision for num_tokens tokens,
    given block_size. Note that (num_tokens + block_size) // block_size always
    covers num_tokens, but over-provisions by one block when num_tokens is an
    exact multiple of block_size.
    '''
    return (num_tokens + block_size) // block_size
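
# Worked example: with block_size = 16,
#     _num_tokens_to_min_blocks(15, 16) == (15 + 16) // 16 == 1
#     _num_tokens_to_min_blocks(17, 16) == (17 + 16) // 16 == 2
#     _num_tokens_to_min_blocks(16, 16) == (16 + 16) // 16 == 2  (one spare block)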


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)


def split_slot_mapping(slot_mapping_list: List[int], seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:

    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping.

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))
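
# Worked example: two post-decode sequences of lengths [3, 2] with slot mapping
# [10, 11, 12, 40, 41]. The last slot of each sequence belongs to the decoded
# token, so:
#
#     prefill, decode = split_slot_mapping([10, 11, 12, 40, 41], [3, 2], "cpu")
#     # prefill -> tensor([10, 11, 40]); decode -> tensor([12, 41])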


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the minimum number
    of required KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    Then the minimum KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the blocktable mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr

    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * device: CPU or CUDA device
    * block_base_addr: the block table base address

    Return:

    * block_tables_tensor: block table for sequence
    * slot_mapping_list: slot mapping for sequence
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []

    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx

    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
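
# Illustrative sketch: provisioning block tables and slot mappings for two
# sequences whose post-decode lengths are 5 and 9 tokens, then splitting the
# slot mapping into prefill and decode pieces (example values only).
#
#     seq_lens = [5, 9]
#     block_tables, slot_mapping, max_block_idx = \
#         make_block_tables_slot_mapping(block_size=16,
#                                        seq_lens=seq_lens,
#                                        device="cpu")
#     prefill_slots, decode_slots = split_slot_mapping(slot_mapping,
#                                                      seq_lens,
#                                                      device="cpu")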


def make_test_metadata(
    attn_backend: AttentionBackend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support a decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
                           this function requires
                           kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
                           this function requires encoder query
                           sequence lengths field. If None,
                           encoder query sequence lengths are
                           treated as None
    * cross_test_params: enc/dec cross-attention test params;
                         this function requires kv_mmap field.
                         If None, KV cache memory map data
                         structures are treated as None

    Return:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping
    # decoder_test_params is None signals encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None
               if decoder_test_params is None else decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # Seems for non-prefix-caching scenarios context_lens
    # is never needed
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            _,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))
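
# Illustrative sketch: assembling prefill-phase metadata for a decoder-only
# test. The names prefill_qkv and decoder_prefill_phase_params are hypothetical
# values assumed to have been built elsewhere from make_qkv()/pack_qkv() output
# and a KVMemoryMap; this is a usage sketch, not a complete test.
#
#     prefill_attn_metadata = make_test_metadata(
#         attn_backend=make_backend(STR_XFORMERS_ATTN_VAL),
#         is_prompt=True,
#         seq_lens=prefill_qkv.q_seq_lens,
#         decoder_test_params=decoder_prefill_phase_params,
#         device="cuda:0")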


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor) -> None:
    '''
    Assert that observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    torch.testing.assert_close(ideal_output,
                               output_under_test.view_as(ideal_output))


def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    '''
    Thin wrapper around torch.library.opcheck: runs the requested opcheck
    test utils against `op`, unless `cond` is False, in which case it is a
    no-op returning an empty dict.
    '''
    return torch.library.opcheck(
        op,
        args,
        kwargs,
        test_utils=test_utils,
        raise_exception=raise_exception) if cond else {}
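
# Illustrative sketch: checking a registered custom op against the default
# test-util subset. The op name below is hypothetical; substitute a real
# registered op (e.g. one exposed under torch.ops by the project).
#
#     opcheck(torch.ops.my_namespace.my_custom_op,
#             args=(torch.rand(16, 16),),
#             test_utils=DEFAULT_OPCHECK_TEST_UTILS)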