test_dynamic_spec_decode.py

from unittest.mock import MagicMock, patch

import pytest
import torch

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.spec_decode.metrics import AsyncMetricsCollector
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
from aphrodite.spec_decode.top1_proposer import Top1Proposer

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
                             acceptance_sampler_method: str):
    """Verify that speculative tokens are disabled when the running queue
    size exceeds the threshold.
    """
    disable_by_batch_size = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector,
                              disable_by_batch_size=disable_by_batch_size)
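    # Rig the proposal path to raise: if speculation were still enabled,
    # execute_model would reach get_spec_proposals and surface this error.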
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        running_queue_size=queue_size)
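    # With the queue above the threshold, the worker should take the
    # non-speculative path; patching _run_no_spec to raise proves that
    # this is the path actually executed.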
    if queue_size > disable_by_batch_size:
        with patch.object(worker,
                          '_run_no_spec',
                          side_effect=ValueError(exception_secret)), \
                pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)
    # When the queue size is larger than the threshold,
    # we expect the number of speculative tokens to be set to 0.
    expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
    assert seq_group_metadata_list[
        0].num_speculative_tokens == expected_num_spec_tokens
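    # Now exercise the Top1Proposer directly: the draft model's
    # sampler_output is rigged to raise, so whether it runs tells us
    # whether the proposer actually invoked the draft model.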
    draft_worker.sampler_output.side_effect = ValueError(exception_secret)

    proposer = Top1Proposer(
        worker=draft_worker,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )

    if queue_size < disable_by_batch_size:
        # Should raise exception when executing the mocked draft model.
        with pytest.raises(ValueError, match=exception_secret):
            proposer.get_spec_proposals(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    num_lookahead_slots=k),
                seq_ids_with_bonus_token_in_last_step=set())
    else:
        # Should not execute the draft model because spec decode is disabled
        # for all requests. Accordingly, the proposal length should be 0.
        proposals = proposer.get_spec_proposals(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                num_lookahead_slots=k),
            seq_ids_with_bonus_token_in_last_step=set())
        assert proposals.proposal_lens.tolist() == [0] * batch_size