# test_model_runner.py

from array import array
from typing import List

import pytest
import torch

from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
from aphrodite.common.utils import get_open_port
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.distributed.parallel_state import (
    ensure_model_parallel_initialized, init_distributed_environment)
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.task_handler.model_runner import (ModelRunner,
                                                 _get_graph_batch_size)


def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = ModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        prompt_adapter_config=engine_config.prompt_adapter_config,
        is_driver_worker=True,
    )
    return model_runner


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len
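    # Worked example of the indices computed above (illustrative values only):
    # for seq_lens = [3, 2, 4] the packed prompt starts are [0, 3, 5], so the
    # last-token indices are [0 + 3 - 1, 3 + 2 - 1, 5 + 4 - 1] = [2, 4, 8].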
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping

    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

    # Verify input metadata is correct for prompts.
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor,
        torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0

    # Test subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    # Test seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device))

    expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
                            dtype=torch.int32,
                            device=model_runner.device)
    torch.testing.assert_close(attn_metadata.block_tables, expected)
    # CUDA graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    torch.testing.assert_close(input_tokens, input_positions)

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)

    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
    torch.allclose(input_tokens, input_positions)

    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    context_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Assume each seq group finishes prefill.
    for i in range(batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        context_lens.append(context_len)
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(context_len)))
        seq_data.update_num_computed_tokens(context_len)
        # Append one token ID since prefill is finished.
        seq_data.append_token_id(1, 0)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata, slot_mapping = (
        model_input.input_tokens, model_input.input_positions,
        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
    # Verify input metadata is correct for decodes.
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_prefill_tokens == 0
    seq_lens = [context_len + 1 for context_len in context_lens]
    # seq_lens are padded to expected_bs.
    for _ in range(expected_bs - len(seq_lens)):
        seq_lens.append(1)
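    # The padding above relies on the CUDA-graph batch-size rounding done by
    # _get_graph_batch_size. Assumed behaviour for illustration (not the
    # authoritative implementation): batch sizes 1-2 are kept as-is, 3-4 pad
    # to 4, and anything larger rounds up to the next multiple of 8,
    # e.g. 5 -> 8, 9 -> 16, 100 -> 104.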
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.num_decode_tokens == len(seq_lens)
    start_idx = 0
    start_loc = [start_idx]
    for _ in context_lens:
        # decode has only 1 token for query.
        start_idx += 1
        start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device))

    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)
    torch.testing.assert_close(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

    torch.testing.assert_close(
        attn_metadata.context_lens_tensor,
        torch.tensor(context_lens, dtype=torch.int, device=device))
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    torch.testing.assert_close(
        attn_metadata.seq_lens_tensor[:len(seq_lens)],
        torch.tensor(seq_lens, dtype=torch.int, device=device))

    # Block table's first index corresponds to each batch entry; in decoding
    # that is one entry per token.
    assert attn_metadata.block_tables.shape[0] == len(input_tokens)
    # Block table's second dim corresponds to each token's block number.
    # It is padded up to the maximum number of blocks per batch.
    assert attn_metadata.block_tables.shape[1] == (
        model_runner.get_max_block_per_batch())
    assert attn_metadata.use_cuda_graph is True

    assert len(input_tokens) == expected_bs
    assert len(input_positions) == expected_bs
    torch.allclose(input_tokens, input_positions)

    # Verify sampling.
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for _ in context_lens:
        expected_selected_token_indices.append(selected_token_start_idx)
        selected_token_start_idx += 1

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query lens is all 1 for decode.
        query_lens=[1 for _ in range(len(context_lens))],
        device=model_runner.device,
        pin_memory=model_runner.pin_memory)
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)


def test_empty_seq_group():
    """Verify prepare prompt and decode return empty output."""
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=False,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
    )
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None

    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None
    assert return_seq_lens is None


@pytest.fixture
def distributed_init():
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
        local_rank=0)
    ensure_model_parallel_initialized(1, 1)


@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )

    # Add prefill requests.
    seq_lens: List[int] = []
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    prefill_metadata_list: List[SequenceGroupMetadata] = []
    decode_metadata_list: List[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    prefill_batch_size = batch_size // 2
    decode_batch_size = batch_size - prefill_batch_size
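    # For example (illustrative values only): batch_size = 5 gives
    # prefill_batch_size = 2 and decode_batch_size = 3, so the mixed batch
    # holds 2 prompt groups followed by 3 decode groups.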
    for i in range(prefill_batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                      range(seq_len)))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)
        prefill_metadata_list.append(seq_group_metadata)

    # Add decode requests.
    for i in range(prefill_batch_size, batch_size):
        # make sure all tokens fit into one block
        context_len = i % (model_runner.block_size - 1) + 1
        prompt_toks = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, range(context_len))
        seq_data = SequenceData(prompt_toks)
        seq_data.append_token_id(1, 0)
        seq_data.update_num_computed_tokens(context_len)
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        decode_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
    )

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(input_positions) == len(input_tokens)
    assert attn_metadata.num_prefills == prefill_batch_size
    assert attn_metadata.num_decode_tokens == decode_batch_size
    assert attn_metadata.num_prefill_tokens == sum(seq_lens)

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
                                          vars(prefill_meta_actual)):
        assert attr_expected[1] == attr_actual[1]
    for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
                                          vars(decode_meta_actual)):
        assert attr_expected[1] == attr_actual[1]