# test_utils.py
  1. from unittest.mock import MagicMock
  2. import pytest
  3. import torch
  4. from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
  5. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  6. from aphrodite.modeling.layers.sampler import _get_ranks
  7. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  8. TypicalAcceptanceSampler)
  9. from aphrodite.spec_decode.util import (get_sampled_token_logprobs,
  10. split_batch_by_proposal_len)
  11. def test_get_all_seq_ids():
  12. """Verify get_all_seq_ids extracts all seq ids.
  13. """
  14. expected_seq_ids = list(range(10)) + list(range(100, 110))
  15. seq_group_metadata_list = [
  16. SequenceGroupMetadata(
  17. request_id=str(seq_id),
  18. is_prompt=True,
  19. seq_data={
  20. seq_id: MagicMock(),
  21. },
  22. sampling_params=MagicMock(),
  23. block_tables={
  24. seq_id: MagicMock(),
  25. },
  26. lora_request=None,
  27. ) for seq_id in expected_seq_ids
  28. ]
  29. actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
  30. assert actual_seq_ids == expected_seq_ids
  31. @pytest.fixture
  32. def fake_sequence_group_metadata():
  33. seq_ids = list(range(3))
  34. return [
  35. SequenceGroupMetadata(
  36. request_id=str(i),
  37. is_prompt=True,
  38. seq_data={
  39. i: MagicMock(),
  40. },
  41. sampling_params=MagicMock(),
  42. block_tables={
  43. i: MagicMock(),
  44. },
  45. lora_request=None,
  46. ) for i in seq_ids
  47. ]
  48. def test_filter_zero_length_proposals(fake_sequence_group_metadata):
  49. proposal_lens = [0, 1, 0]
  50. _, (filtered_groups,
  51. indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
  52. proposal_lens)
  53. expected_groups = [
  54. fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
  55. ]
  56. expected_indices = [0, 2]
  57. assert filtered_groups == expected_groups
  58. assert indices == expected_indices
  59. def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
  60. proposal_lens = [0, 1, 2]
  61. (filtered_groups,
  62. indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
  63. proposal_lens)
  64. expected_groups = [
  65. fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
  66. ]
  67. expected_indices = [1, 2]
  68. assert filtered_groups == expected_groups
  69. assert indices == expected_indices
  70. def test_empty_inputs():
  71. _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
  72. assert filtered_groups == []
  73. assert indices == []
  74. def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
  75. proposal_lens = [0, 0, 0]
  76. (filtered_groups,
  77. indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
  78. proposal_lens)
  79. assert filtered_groups == []
  80. assert indices == []
  81. def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
  82. proposal_lens = [1, 1, 1]
  83. _, (filtered_groups,
  84. indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
  85. proposal_lens)
  86. assert filtered_groups == []
  87. assert indices == []
  88. def mock_spec_decode_sampler(acceptance_sampler_method):
  89. """
  90. Returns either a RejectionSampler or TypicalAcceptanceSampler
  91. object depending on whether acceptance_sampler_method is
  92. 'rejection_sampler' or 'typical_acceptance_sampler' respectively.
  93. """
  94. if acceptance_sampler_method == "rejection_sampler":
  95. sampler = MagicMock(spec=RejectionSampler)
  96. sampler.token_id_dtype = torch.int64
  97. return sampler
  98. elif acceptance_sampler_method == "typical_acceptance_sampler":
  99. sampler = MagicMock(spec=TypicalAcceptanceSampler)
  100. sampler.token_id_dtype = torch.int64
  101. return sampler
  102. else:
  103. raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
  104. def test_get_sampled_token_logprobs():
  105. """Verify get_sampled_token_logprobs returns consistent rankings
  106. with regular get_ranks when probabilities match exactly.
  107. """
  108. logprob_tensor = torch.tensor(
  109. [[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
  110. sampled_token_tensor = torch.tensor([[1,
  111. 0]]) # shape (num_steps, batch_size)
  112. ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
  113. sampled_token_tensor)
  114. ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
  115. sampled_token_tensor.reshape(-1))
  116. assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)