import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock

import pytest
import torch

from aphrodite.common.sequence import ExecuteModelRequest, SequenceOutput
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.utils import set_random_seed
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.metrics import (AsyncMetricsCollector,
                                           SpecDecodeWorkerMetrics)
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, split_num_cache_blocks_evenly)

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
                                     acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)

    # Make the draft worker raise so execution stops right after the call
    # under test; the recorded call arguments are inspected afterwards.
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        actual_execute_model_data = args[0]
        assert actual_execute_model_data == execute_model_req


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
                                      acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
    worker.init_device()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    # Stop execution as soon as the target worker is invoked.
    exception_secret = 'artificial stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    seen_contexts: List[List[int]] = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for _, kwargs in call_args_list:
        seq_group_metadata_list = kwargs[
            "execute_model_req"].seq_group_metadata_list

        assert len(seq_group_metadata_list) == (k + 1) * batch_size
        for seq_group_metadata in seq_group_metadata_list:
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())

    # For every sequence, the target model is expected to see the context
    # extended by each prefix of the draft tokens (including the empty one).
    expected_seen_contexts: List[List[int]] = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    # Order is not part of the contract being checked; compare sorted.
    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
                                             acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the spec decode sampler (rejection or
    typical acceptance, depending on the parametrization) with correct
    inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    # Stop execution as soon as the sampler is invoked.
    exception_secret = 'artificial stop'

    spec_decode_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    assert len(spec_decode_sampler.call_args_list) == 1
    _, kwargs = spec_decode_sampler.call_args_list[0]
    actual = SimpleNamespace(**kwargs)

    # The last of the k + 1 target tokens per sequence is the bonus token;
    # the first k target probabilities are scored against the draft tokens.
    assert torch.equal(actual.bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(
        actual.target_probs,
        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
    assert torch.equal(actual.draft_token_ids, proposal_token_ids)
    assert torch.equal(actual.draft_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
                                  acceptance_sampler_method: str):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    # Overwrite a random-length suffix of every row with -1. The suffix
    # length is drawn from [minimum_accepted_tokens, k + 1].
    minimum_accepted_tokens = 1
    for i in range(batch_size):
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output
    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))

    expected_output = create_sampler_output_list(
        token_ids=spec_decode_sampler_output.transpose(0, 1),
        probs=[None] * (k + 1),
        logprobs=[None] * (k + 1))

    # Only the first sequence id of each sequence group is used as a key.
    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in seq_group_metadata_list
    ]
    actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }
    expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }

    # Regroup the step-major outputs by sequence id.
    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(actual_output_by_seq) | set(expected_output_by_seq)
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                # The worker emitted fewer samples than k + 1 for this
                # sequence; the missing steps must be -1 in the expectation.
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
                          acceptance_sampler_method: str):
    """Verify SpecDecodeWorker collects metrics.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    # Overwrite a random-length suffix of every row with -1. The suffix
    # length is drawn from [minimum_accepted_tokens, k + 1].
    minimum_accepted_tokens = 1
    for i in range(batch_size):
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1
    spec_decode_sampler.return_value = spec_decode_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    # k may be passed positionally or by keyword. Guard the positional
    # lookup so a keyword-only call does not raise IndexError before the
    # keyword branch is evaluated.
    assert (bool(args) and args[0] == k) or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
                       acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
                           acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[
        0].sampled_token_ids is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=spec_decode_sampler,
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )
    worker.init_device()

    draft_worker.init_device.assert_called_once()

    target_worker.init_device.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    spec_decode_sampler.init_gpu_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              metrics_collector=metrics_collector)

    kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
    worker.initialize_cache(**kwargs)

    draft_worker.initialize_cache.assert_called_once_with(**kwargs)
    target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
                                        available_cpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int,
                                        acceptance_sampler_method: str):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.determine_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    # Pass metrics_collector by keyword, like the other tests in this file,
    # so the mock cannot end up bound to some other positional parameter.
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        metrics_collector=metrics_collector)

    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

    target_worker.determine_num_available_blocks.assert_called_once()
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
    # Bytes used by both caches after the split must fit in the bytes the
    # target cache alone would have used with all available blocks.
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)


@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'
    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i]
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)
    # Generate a random sample of sequence indexes with bonus tokens
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)
    # Create a mask that is True for every index NOT in
    # seq_indexes_with_bonus_tokens (those entries are cleared to False).
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False
    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    #    1. Sequence IDs that were already present in
    #       _seq_with_bonus_token_in_last_step but were not part of the current
    #       batch are retained.
    #    2. Of the sequence IDs present in the current batch, only those with a
    #       bonus token are retained in _seq_with_bonus_token_in_last_step.
    #       Sequence IDs that are present in the current batch but do not have
    #       bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = {
        assigned_seq_ids[i]
        for i in seq_indexes_with_bonus_tokens
    }
    additional_sequence_ids = \
        set(range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping


@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    # Pass metrics_collector by keyword, like the other tests in this file,
    # so the mock cannot end up bound to some other positional parameter.
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector=metrics_collector)
    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = {
        'request-1': {1, 2, 3},
        'request-2': {4, 5, 6, 7},
        'request-3': {8, 9},
        'request-4': {10, 11},
    }
    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
    # Make the draft worker raise so execute_model stops once it reaches the
    # proposal step; the internal state is inspected afterwards.
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping
    assert worker._request_id_seq_id_mapping == {
        'request-2': {4, 5, 6, 7},
        'request-4': {10, 11},
    }
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == {4, 5, 10}