# test_spec_decode_worker.py

import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock

import pytest
import torch

from aphrodite.common.sequence import ExecuteModelRequest, SequenceOutput
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.utils import set_random_seed
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.metrics import (AsyncMetricsCollector,
                                           SpecDecodeWorkerMetrics)
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, split_num_cache_blocks_evenly)

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
                                     acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)

    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        actual_execute_model_data = args[0]
        assert actual_execute_model_data == execute_model_req
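
# Note on the pattern used throughout this file: the mocked collaborator that
# sits immediately after the code under test raises a sentinel ValueError
# ('artificial stop'), so execution halts right after the call being inspected
# and only the inputs to that call need to be valid.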


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
                                      acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
    worker.init_device()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artificial stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    seen_contexts: List[List[int]] = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1

    for _, kwargs in call_args_list:
        seq_group_metadata_list = kwargs[
            "execute_model_req"].seq_group_metadata_list

        assert len(seq_group_metadata_list) == (k + 1) * batch_size
        for seq_group_metadata in seq_group_metadata_list:
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

    expected_seen_contexts: List[List[int]] = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):
        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts
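
# The assertions above encode the batch-expansion scoring contract: for each
# sequence the target worker scores (k + 1) contexts -- the original context
# plus one copy extended by each non-empty prefix of the k draft tokens -- so
# the flattened request it receives holds (k + 1) * batch_size sequence groups.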


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
                                             acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the spec decode sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    exception_secret = 'artificial stop'
    spec_decode_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    assert len(spec_decode_sampler.call_args_list) == 1
    _, kwargs = spec_decode_sampler.call_args_list[0]
    actual = SimpleNamespace(**kwargs)

    assert torch.equal(actual.bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(
        actual.target_probs,
        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
    assert torch.equal(actual.draft_token_ids, proposal_token_ids)
    assert torch.equal(actual.draft_probs, proposal_probs)
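
# The target worker returns one flattened step of shape
# [1, batch_size * (k + 1), ...]; before acceptance sampling it is viewed as
# [batch_size, k + 1, ...], with the first k positions scoring the draft
# tokens and the final position providing the bonus token, as checked above.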


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
                                  acceptance_sampler_method: str):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))

    expected_output = create_sampler_output_list(
        token_ids=spec_decode_sampler_output.transpose(0, 1),
        probs=[None for _ in range(k + 1)],
        logprobs=[None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in seq_group_metadata_list
    ]
    actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }
    expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token
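
# In the mocked acceptance-sampler output above, a trailing run of -1s marks
# positions that were rejected. The worker emits fewer than k + 1 samples for
# such sequences, which is why the comparison tolerates a shorter actual
# output whenever the expected token at that position is -1.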


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
                          acceptance_sampler_method: str):
    """Verify SpecDecodeWorker collects metrics."""
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1
    spec_decode_sampler.return_value = spec_decode_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
                       acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
                           acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)
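
# In both of the tests above the worker takes its non-speculative path: with
# k == 0 (prefill) or an empty batch there is nothing to propose or score, so
# the original ExecuteModelRequest is forwarded unchanged to both the draft
# and target workers, and a single sampler output is returned with its GPU
# tensor references (sampled_token_probs, sampled_token_ids) set to None.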


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=spec_decode_sampler,
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )
    worker.init_device()

    draft_worker.init_device.assert_called_once()
    target_worker.init_device.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    spec_decode_sampler.init_gpu_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              metrics_collector=metrics_collector)

    kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
    worker.initialize_cache(**kwargs)

    draft_worker.initialize_cache.assert_called_once_with(**kwargs)
    target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
                                        available_cpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int,
                                        acceptance_sampler_method: str):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.determine_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    worker = SpecDecodeWorker(
        draft_worker, target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

    target_worker.determine_num_available_blocks.assert_called_once()
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)
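
# Only the scorer (target) worker is profiled; the free GPU blocks it reports
# are then re-budgeted across both KV caches via split_num_cache_blocks_evenly
# using each worker's per-block byte size, so the combined caches fit within
# the memory the scorer alone measured.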


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)
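
# A minimal sketch of a proportional split that satisfies the invariant
# asserted above (an assumption for illustration, not necessarily the
# library's implementation):
#
#     num_blocks = int(available_gpu_blocks * target_cache_block_size_bytes /
#                      (target_cache_block_size_bytes + draft_kv_size_bytes))
#
# e.g. with 1024 free blocks, a 16384-byte target block and a 3072-byte draft
# block, num_blocks == int(1024 * 16384 / 19456) == 862, and
# 862 * (16384 + 3072) <= 1024 * 16384 holds.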


@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5

    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'

    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)

    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)

    # Create a mask that is True for indices in seq_indexes_with_bonus_tokens
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False

    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)

    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))

    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k,
        stage_times=(0, 0, 0))

    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    # 1. Sequence IDs that were already present in
    #    _seq_with_bonus_token_in_last_step but were not part of the current
    #    batch are retained.
    # 2. Of the sequence IDs present in the current batch, only those with a
    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
    #    Sequence IDs that are present in the current batch but do not have
    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = \
        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
    additional_sequence_ids = \
        set(range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping


@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector)

    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = \
        {'request-1': {1, 2, 3}, 'request-2': {4, 5, 6, 7},
         'request-3': {8, 9}, 'request-4': {10, 11}}

    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}

    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping.
    assert worker._request_id_seq_id_mapping == \
        {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4, 5, 10}