# test_utils.py (3.9 KB)
  1. from unittest.mock import MagicMock
  2. import pytest
  3. import torch
  4. from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
  5. from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
  6. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  7. TypicalAcceptanceSampler)
  8. from aphrodite.spec_decode.util import split_batch_by_proposal_len
  9. def test_get_all_seq_ids():
  10. """Verify get_all_seq_ids extracts all seq ids.
  11. """
  12. expected_seq_ids = list(range(10)) + list(range(100, 110))
  13. seq_group_metadata_list = [
  14. SequenceGroupMetadata(
  15. request_id=str(seq_id),
  16. is_prompt=True,
  17. seq_data={
  18. seq_id: MagicMock(),
  19. },
  20. sampling_params=MagicMock(),
  21. block_tables={
  22. seq_id: MagicMock(),
  23. },
  24. lora_request=None,
  25. ) for seq_id in expected_seq_ids
  26. ]
  27. actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
  28. assert actual_seq_ids == expected_seq_ids
  29. @pytest.fixture
  30. def fake_sequence_group_metadata():
  31. seq_ids = list(range(3))
  32. return [
  33. SequenceGroupMetadata(
  34. request_id=str(i),
  35. is_prompt=True,
  36. seq_data={
  37. i: MagicMock(),
  38. },
  39. sampling_params=MagicMock(),
  40. block_tables={
  41. i: MagicMock(),
  42. },
  43. lora_request=None,
  44. ) for i in seq_ids
  45. ]
  46. def test_filter_zero_length_proposals(fake_sequence_group_metadata):
  47. proposal_lens = [0, 1, 0]
  48. _, (filtered_groups,
  49. indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
  50. proposal_lens)
  51. expected_groups = [
  52. fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
  53. ]
  54. expected_indices = [0, 2]
  55. assert filtered_groups == expected_groups
  56. assert indices == expected_indices
  57. def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
  58. proposal_lens = [0, 1, 2]
  59. (filtered_groups,
  60. indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
  61. proposal_lens)
  62. expected_groups = [
  63. fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
  64. ]
  65. expected_indices = [1, 2]
  66. assert filtered_groups == expected_groups
  67. assert indices == expected_indices
  68. def test_empty_inputs():
  69. _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
  70. assert filtered_groups == []
  71. assert indices == []
  72. def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
  73. proposal_lens = [0, 0, 0]
  74. (filtered_groups,
  75. indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
  76. proposal_lens)
  77. assert filtered_groups == []
  78. assert indices == []
  79. def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
  80. proposal_lens = [1, 1, 1]
  81. _, (filtered_groups,
  82. indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
  83. proposal_lens)
  84. assert filtered_groups == []
  85. assert indices == []
  86. def mock_spec_decode_sampler(acceptance_sampler_method):
  87. """
  88. Returns either a RejectionSampler or TypicalAcceptanceSampler
  89. object depending on whether acceptance_sampler_method is
  90. 'rejection_sampler' or 'typical_acceptance_sampler' respectively.
  91. """
  92. if acceptance_sampler_method == "rejection_sampler":
  93. sampler = MagicMock(spec=RejectionSampler)
  94. sampler.token_id_dtype = torch.int64
  95. return sampler
  96. elif acceptance_sampler_method == "typical_acceptance_sampler":
  97. sampler = MagicMock(spec=TypicalAcceptanceSampler)
  98. sampler.token_id_dtype = torch.int64
  99. return sampler
  100. else:
  101. raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")