import itertools
from typing import List

import pytest
import torch

from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
from aphrodite.common.utils import is_cpu, make_tensor_with_pad
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from aphrodite.worker.model_runner import _get_graph_batch_size

BATCH_SIZES = [1, 4, 16, 64, 256]


def _create_model_runner(model: str, *args,
                         **kwargs) -> EncoderDecoderModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = EncoderDecoderModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        prompt_adapter_config=engine_config.prompt_adapter_config,
        is_driver_worker=True,
    )
    return model_runner
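
# NOTE: Every test below follows the same pattern: build a list of
# SequenceGroupMetadata, hand it to the runner's prepare_model_input()
# (or _prepare_model_input_tensors()), and then inspect the fields of the
# returned model_input (input_tokens, input_positions, encoder_input_tokens,
# encoder_input_positions, attn_metadata, seq_lens, sampling_metadata).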


@pytest.mark.skipif(condition=is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
def test_empty_seq_group():
    """Verify prepare prompt and decode return empty output
    for an empty seq group list."""
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    (
        input_tokens,
        input_positions,
        encoder_input_tokens,
        encoder_input_positions,
        attn_metadata,
        return_seq_lens,
    ) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.encoder_input_tokens,
        model_input.encoder_input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert encoder_input_tokens is None
    assert encoder_input_positions is None
    assert attn_metadata is None
    assert return_seq_lens is None


@pytest.mark.skipif(condition=is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_prompt(batch_size):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce prefill-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size
    '''
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
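    # Sequence lengths are chosen so every sequence fits in a single block:
    # `i % (block_size - 1) + 1` cycles through 1..block_size-1 (e.g. with a
    # block size of 16, lengths cycle 1..15). Token IDs are built with
    # range(seq_len), so each token ID equals its position; the assertions
    # further down rely on this to compare input_tokens against
    # input_positions.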
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for prompts.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.equal(attn_metadata.seq_lens_tensor,
                       torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )

    # Test decoder seq start locs & context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device),
    )
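    # Worked example: with seq_lens == [1, 2, 3], the prefix-sum start
    # locations are [0, 1, 3, 6]. In the prefill phase every prompt token is
    # a query token, so query_start_loc and seq_start_loc coincide and every
    # context length is zero.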

    # Verify block tables are correct for prompts
    # - Decoder self-attention
    expected = torch.tensor(
        [[] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # -- An indirect check that model_input.input_tokens
    #    and model_input.input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == sum(encoder_seq_lens)
    # -- An indirect check that model_input.encoder_input_tokens
    #    and model_input.encoder_input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the prefill phase
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        # Compute the index offset of the final token in each
        # prompt (recall that the prompts are concatenated)
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len
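    # Worked example: for concatenated prompts with seq_lens == [1, 2, 3],
    # the last-token offsets (and hence the expected sampling indices) are
    # [0, 2, 5].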
    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.skipif(condition=is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size
    * multiple_seqs_per_seq_group
    '''
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    cross_block_table = [2]
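    # When multiple_seqs_per_seq_group is True, each sequence group contains
    # two sequences that share the same SequenceData and cross block table,
    # so seq_lens and encoder_seq_lens are extended once per sequence in the
    # group rather than once per group.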
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        seq_lens.extend(
            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
        encoder_seq_lens.extend(
            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for decode phase.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(attn_metadata.seq_lens_tensor,
                       torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)

    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += 1
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )

    # Test decoder seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)

    # Test seq_start_loc and context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.tensor([seq_len - 1 for seq_len in seq_lens],
                     dtype=torch.int,
                     device=device))
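    # Worked example: in the decode phase each sequence contributes exactly
    # one query token, so with seq_lens == [2, 3] query_start_loc is
    # [0, 1, 2] while seq_start_loc is the prefix sum [0, 2, 5], and the
    # context length of each sequence is seq_len - 1 (i.e. [1, 2]).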

    # Verify block tables are correct for decode
    # - Decoder self-attention
    flattened_block_tables = [
        block_table for block_table in block_tables.values()
    ]
    expected = torch.tensor(flattened_block_tables *
                            len(seq_group_metadata_list),
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    expected = torch.tensor([
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(seq_lens)
    assert len(input_positions) == len(seq_lens)
    # -- An indirect check that model_input.input_tokens
    #    and model_input.input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    #    and model_input.encoder_input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the decode phase
    expected_selected_token_indices = []
    for selected_token_start_idx, seq_len in enumerate(seq_lens):
        # Compute the index offset of the final token in each
        # sequence's decoded outputs; since a single token is
        # decoded per iteration per sequence, then the length
        # of the decoded tokens for a given sequence is 1 and
        # the final index offset into a given sequence's
        # generated tokens is 0 (i.e. the expected sampling index
        # for a given sequence is just `selected_token_start_idx`)
        expected_selected_token_indices.append(selected_token_start_idx)

    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
    """
    Tests that for encoder-decoder models with CUDA Graph capture and replay
    enabled, the tensors used during the decode phase are correctly padded
    for varying input batch sizes.
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=False,
    )

    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    cross_block_table = [2]
    expanded_batch_size = 0
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_lens.extend(
            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
        encoder_seq_lens.extend(
            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
        expanded_batch_size = expanded_batch_size + len(
            seq_group_metadata.seq_data)
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping

    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly.
    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
    padded_encoder_seq_lens = encoder_seq_lens + list(
        itertools.repeat(1, cuda_graph_pad_size))
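    # Worked example (exact rounding is an implementation detail of
    # _get_graph_batch_size, assumed here): the batch is rounded up to the
    # nearest captured CUDA graph batch size, so an expanded batch of, say,
    # 5 sequences might be padded to 8, appending 3 dummy entries of length 1
    # to both padded_seq_lens and padded_encoder_seq_lens.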

    assert return_seq_lens == padded_seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == padded_seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)

    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)

    # Verify block tables are correct for decode
    # - Decoder self-attention. Pad the block tables as expected.
    flattened_block_tables = [
        block_table for _ in range(len(seq_group_metadata_list))
        for block_table in block_tables.values()
    ]
    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        flattened_block_tables,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
    #   as expected.
    expected = [
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is True

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(padded_seq_lens)
    assert len(input_positions) == len(padded_seq_lens)
    # -- An indirect check that model_input.input_tokens
    #    and model_input.input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    #    and model_input.encoder_input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )