# test_encoder_decoder_model_runner.py

import itertools
from array import array
from typing import List

import pytest
import torch

from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import is_cpu, make_tensor_with_pad
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from aphrodite.worker.model_runner import _get_graph_batch_size

BATCH_SIZES = [1, 4, 16, 64, 256]


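# Helper used by every test below: build an EncoderDecoderModelRunner from
# the config objects derived from EngineArgs.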
def _create_model_runner(
    model: str, *args, **kwargs
) -> EncoderDecoderModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = EncoderDecoderModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        prompt_adapter_config=engine_config.prompt_adapter_config,
        is_driver_worker=True,
    )
    return model_runner


@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/decoder models",
)
def test_empty_seq_group():
    """Verify that prepare prompt and decode return empty output
    for an empty seq group list."""
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list
    )
    (
        input_tokens,
        input_positions,
        encoder_input_tokens,
        encoder_input_positions,
        attn_metadata,
        return_seq_lens,
    ) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.encoder_input_tokens,
        model_input.encoder_input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert encoder_input_tokens is None
    assert encoder_input_positions is None
    assert attn_metadata is None
    assert return_seq_lens is None


@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/decoder models",
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_prompt(batch_size):
    """
    Test the ability of the encoder/decoder model runner subclass to
    produce prefill-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of sequence groups in the batch
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
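    # One physical KV-cache block per dummy sequence: block 1 backs decoder
    # self-attention and block 2 backs encoder/decoder cross-attention.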
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(seq_len))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)
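    # slot_mapping assigns each decoder token a slot in the decoder
    # self-attention KV cache, while cross_slot_mapping does the same for
    # the encoder tokens consumed by cross-attention, so their lengths must
    # match the decoder and encoder token counts respectively.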
    # Verify input metadata is correct for prompts.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    # Test decoder seq start locs & context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.zeros(
            attn_metadata.context_lens_tensor.shape[0],
            dtype=torch.int,
            device=device,
        ),
    )
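    # The prefill path does not read from the paged KV cache, so the
    # prepared block tables (decoder self-attention and cross-attention
    # alike) are expected to come back empty here.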
    # Verify block tables are correct for prompts
    # - Decoder self-attention
    expected = torch.tensor(
        [[] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )
    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False
    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == sum(encoder_seq_lens)
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )
    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the prefill phase
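    # For example, with seq_lens == [1, 2, 3] the flattened prompt tokens
    # occupy indices [0], [1, 2] and [3, 4, 5], so sampling is expected at
    # indices [0, 2, 5].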
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        # Compute the index offset of the final token in each
        # prompt (recall that the prompts are concatenated)
        expected_selected_token_indices.append(
            selected_token_start_idx + seq_len - 1
        )
        selected_token_start_idx += seq_len
    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.skipif(
    condition=is_cpu(),
    reason="CPU backend is currently "
    "unsupported for encoder/decoder models",
)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_decode(batch_size):
    """
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of sequence groups in the batch
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)
    # Verify input metadata is correct for decode phase.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
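    # The encoder sequence lengths stay populated in the decode-phase
    # metadata (presumably so cross-attention can keep addressing the cached
    # encoder states), even though, as checked further below, no encoder
    # tokens are actually fed during decode.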
    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += 1
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
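    # In the decode phase each sequence contributes exactly one new query
    # token, which is why query_start_loc above advances by 1 per sequence.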
    # Test decoder seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    # Test seq_start_loc and context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.tensor(
            [seq_len - 1 for seq_len in seq_lens],
            dtype=torch.int,
            device=device,
        ),
    )
    # Verify block tables are correct for the decode phase
    # - Decoder self-attention
    expected = torch.tensor(
        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    expected = torch.tensor(
        [cross_block_table for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )
    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is False
    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(seq_lens)
    assert len(input_positions) == len(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )
    # Test that Aphrodite sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the decode phase
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        # Compute the index offset of the final token in each
        # sequence's decoded outputs; since a single token is
        # decoded per iteration per sequence, then the length
        # of the decoded tokens for a given sequence is 1 and
        # the final index offset into a given sequence's
        # generated tokens is 0 (i.e. the expected sampling index
        # for a given sequence is just `selected_token_start_idx`)
        expected_selected_token_indices.append(selected_token_start_idx)
        selected_token_start_idx += 1
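    # e.g. for a batch of 4 decode sequences the expected sampling indices
    # are simply [0, 1, 2, 3].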
    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    """
    Tests that for encoder-decoder models with CUDA Graph capture and replay
    enabled, the tensors used during the decode phase are correctly padded
    for varying input batch sizes.
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=False,
    )
    seq_lens: List[int] = []
    encoder_seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))
        )
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))
        )
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly.
    graph_batch_size = _get_graph_batch_size(batch_size)
    cuda_graph_pad_size = graph_batch_size - batch_size
    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
    padded_encoder_seq_lens = encoder_seq_lens + list(
        itertools.repeat(1, cuda_graph_pad_size)
    )
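    # For example, a batch of 5 would typically be padded up to the nearest
    # captured graph batch size (e.g. 8), giving cuda_graph_pad_size == 3;
    # each pad slot is treated as a dummy sequence of length 1.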
    assert return_seq_lens == padded_seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)
    # Verify attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(padded_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.seq_lens == padded_seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int),
    )
    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)
    # Verify block tables are correct for the padded decode batch
    # - Decoder self-attention. Pad the block tables as expected.
    expected = [block_tables[0] for _ in range(batch_size)]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
    # as expected.
    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )
    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is True
    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(padded_seq_lens)
    assert len(input_positions) == len(padded_seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )