123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- from unittest.mock import MagicMock
- import pytest
- import torch
- from aphrodite.common.sequence import SequenceGroupMetadata, get_all_seq_ids
- from aphrodite.modeling.layers.rejection_sampler import RejectionSampler
- from aphrodite.modeling.layers.typical_acceptance_sampler import (
- TypicalAcceptanceSampler)
- from aphrodite.spec_decode.util import split_batch_by_proposal_len
def test_get_all_seq_ids():
    """Check that get_all_seq_ids returns every seq id across the groups.

    Twenty single-sequence prompt groups are built (ids 0-9 and 100-109),
    and the helper's output must equal that id list exactly, order included.
    """
    wanted_ids = [*range(10), *range(100, 110)]

    def _single_seq_group(sid):
        # One group per id; both seq_data and block_tables are keyed by it.
        return SequenceGroupMetadata(
            request_id=str(sid),
            is_prompt=True,
            seq_data={sid: MagicMock()},
            sampling_params=MagicMock(),
            block_tables={sid: MagicMock()},
            lora_request=None,
        )

    groups = [_single_seq_group(sid) for sid in wanted_ids]
    assert get_all_seq_ids(groups) == wanted_ids
@pytest.fixture
def fake_sequence_group_metadata():
    """Provide three single-sequence prompt groups with seq ids 0, 1 and 2.

    Each group's request_id is the string form of its seq id. All payloads
    (seq data, sampling params, block tables) are MagicMock placeholders.
    """
    groups = []
    for idx in range(3):
        groups.append(
            SequenceGroupMetadata(
                request_id=str(idx),
                is_prompt=True,
                seq_data={idx: MagicMock()},
                sampling_params=MagicMock(),
                block_tables={idx: MagicMock()},
                lora_request=None,
            ))
    return groups
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    """Groups whose proposal length is 0 form the second returned partition,
    together with their positions in the original batch."""
    batch = fake_sequence_group_metadata
    lens = [0, 1, 0]

    _nonzero, (zero_groups, zero_indices) = split_batch_by_proposal_len(
        batch, lens)

    assert zero_groups == [batch[0], batch[2]]
    assert zero_indices == [0, 2]
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    """Groups whose proposal length is non-zero form the first returned
    partition, together with their positions in the original batch."""
    batch = fake_sequence_group_metadata
    lens = [0, 1, 2]

    (kept_groups, kept_indices), _zero = split_batch_by_proposal_len(
        batch, lens)

    assert kept_groups == [batch[1], batch[2]]
    assert kept_indices == [1, 2]
def test_empty_inputs():
    """An empty batch must produce two empty partitions.

    Both halves are checked. Asserting only the zero-length half would not
    catch a split that fabricates entries in the non-zero half.
    """
    (nonzero_groups, nonzero_indices), (zero_groups, zero_indices) = \
        split_batch_by_proposal_len([], [])

    assert nonzero_groups == []
    assert nonzero_indices == []
    assert zero_groups == []
    assert zero_indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    """When every proposal is empty, nothing is non-zero.

    The complementary zero-length partition must then contain the whole
    batch, in order. Checking that half too guards against a split that
    silently drops groups.
    """
    batch = fake_sequence_group_metadata
    (nonzero_groups, nonzero_indices), (zero_groups, zero_indices) = \
        split_batch_by_proposal_len(batch, [0, 0, 0])

    assert nonzero_groups == []
    assert nonzero_indices == []
    assert zero_groups == batch
    assert zero_indices == [0, 1, 2]
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    """When every proposal is non-empty, nothing is zero-length.

    The complementary non-zero partition must then contain the whole batch,
    in order. Checking that half too guards against a split that silently
    drops groups.
    """
    batch = fake_sequence_group_metadata
    (nonzero_groups, nonzero_indices), (zero_groups, zero_indices) = \
        split_batch_by_proposal_len(batch, [1, 1, 1])

    assert zero_groups == []
    assert zero_indices == []
    assert nonzero_groups == batch
    assert nonzero_indices == [0, 1, 2]
def mock_spec_decode_sampler(acceptance_sampler_method):
    """Build a MagicMock standing in for a spec-decode acceptance sampler.

    Args:
        acceptance_sampler_method: ``"rejection_sampler"`` selects a mock
            specced against RejectionSampler; ``"typical_acceptance_sampler"``
            selects one specced against TypicalAcceptanceSampler.

    Returns:
        A ``MagicMock(spec=...)`` of the selected class whose
        ``token_id_dtype`` attribute is set to ``torch.int64``.

    Raises:
        ValueError: if ``acceptance_sampler_method`` is any other value.
    """
    if acceptance_sampler_method == "rejection_sampler":
        sampler_cls = RejectionSampler
    elif acceptance_sampler_method == "typical_acceptance_sampler":
        sampler_cls = TypicalAcceptanceSampler
    else:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")

    # Both sampler kinds get the same treatment; only the spec class differs.
    mocked = MagicMock(spec=sampler_cls)
    mocked.token_id_dtype = torch.int64
    return mocked
|