test_utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from unittest.mock import MagicMock
  2. import pytest
  3. import torch
  4. from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
  5. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  6. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  7. TypicalAcceptanceSampler)
  8. from aphrodite.spec_decode.util import split_batch_by_proposal_len
  9. def test_get_all_seq_ids():
  10. """Verify get_all_seq_ids extracts all seq ids.
  11. """
  12. expected_seq_ids = list(range(10)) + list(range(100, 110))
  13. seq_group_metadata_list = [
  14. SequenceGroupMetadata(
  15. request_id=str(seq_id),
  16. is_prompt=True,
  17. seq_data={
  18. seq_id: MagicMock(),
  19. },
  20. sampling_params=MagicMock(),
  21. block_tables={
  22. seq_id: MagicMock(),
  23. },
  24. lora_request=None,
  25. ) for seq_id in expected_seq_ids
  26. ]
  27. actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
  28. assert actual_seq_ids == expected_seq_ids
  29. @pytest.fixture
  30. def fake_sequence_group_metadata():
  31. seq_ids = list(range(3))
  32. return [
  33. SequenceGroupMetadata(
  34. request_id=str(i),
  35. is_prompt=True,
  36. seq_data={
  37. i: MagicMock(),
  38. },
  39. sampling_params=MagicMock(),
  40. block_tables={
  41. i: MagicMock(),
  42. },
  43. lora_request=None,
  44. ) for i in seq_ids
  45. ]
  46. def test_filter_zero_length_proposals(fake_sequence_group_metadata):
  47. proposal_lens = [0, 1, 0]
  48. filtered_groups, indices = split_batch_by_proposal_len(
  49. fake_sequence_group_metadata,
  50. proposal_lens,
  51. select_proposal_len_zero=True)
  52. expected_groups = [
  53. fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
  54. ]
  55. expected_indices = [0, 2]
  56. assert filtered_groups == expected_groups
  57. assert indices == expected_indices
  58. def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
  59. proposal_lens = [0, 1, 2]
  60. filtered_groups, indices = split_batch_by_proposal_len(
  61. fake_sequence_group_metadata,
  62. proposal_lens,
  63. select_proposal_len_zero=False)
  64. expected_groups = [
  65. fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
  66. ]
  67. expected_indices = [1, 2]
  68. assert filtered_groups == expected_groups
  69. assert indices == expected_indices
  70. def test_empty_inputs():
  71. filtered_groups, indices = split_batch_by_proposal_len(
  72. [], [], select_proposal_len_zero=True)
  73. assert filtered_groups == []
  74. assert indices == []
  75. def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
  76. proposal_lens = [0, 0, 0]
  77. filtered_groups, indices = split_batch_by_proposal_len(
  78. fake_sequence_group_metadata,
  79. proposal_lens,
  80. select_proposal_len_zero=False)
  81. assert filtered_groups == []
  82. assert indices == []
  83. def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
  84. proposal_lens = [1, 1, 1]
  85. filtered_groups, indices = split_batch_by_proposal_len(
  86. fake_sequence_group_metadata,
  87. proposal_lens,
  88. select_proposal_len_zero=True)
  89. assert filtered_groups == []
  90. assert indices == []
  91. def mock_spec_decode_sampler(acceptance_sampler_method):
  92. """
  93. Returns either a RejectionSampler or TypicalAcceptanceSampler
  94. object depending on whether acceptance_sampler_method is
  95. 'rejection_sampler' or 'typical_acceptance_sampler' respectively.
  96. """
  97. if acceptance_sampler_method == "rejection_sampler":
  98. sampler = MagicMock(spec=RejectionSampler)
  99. sampler.token_id_dtype = torch.int64
  100. return sampler
  101. elif acceptance_sampler_method == "typical_acceptance_sampler":
  102. sampler = MagicMock(spec=TypicalAcceptanceSampler)
  103. sampler.token_id_dtype = torch.int64
  104. return sampler
  105. else:
  106. raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")