import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock

import pytest
import torch

from aphrodite.common.sequence import ExecuteModelRequest, SequenceOutput
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.utils import set_random_seed
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.metrics import (AsyncMetricsCollector,
                                           SpecDecodeWorkerMetrics)
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, split_num_cache_blocks_evenly)

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker
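
# The tests below exercise SpecDecodeWorker against MagicMock proposer/scorer
# workers and a mocked metrics collector. Several tests raise a sentinel
# ValueError ('artificial stop') from a mocked method so that execution halts
# right after the call under inspection, letting the test examine its
# arguments without running the rest of the pipeline.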


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
                                     acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the draft worker with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
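    # Raising a sentinel error from the draft worker stops execution right
    # after the proposal call, so only that call needs to be inspected.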
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1

    for args, _ in call_args_list:
        actual_execute_model_data = args[0]
        assert actual_execute_model_data == execute_model_req


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
                                      acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        draft_worker,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
    worker.init_device()

    vocab_size = 32_000

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')
    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    exception_secret = 'artificial stop'
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))

    seen_contexts: List[List[int]] = []

    call_args_list = target_worker.execute_model.call_args_list
    assert len(call_args_list) == 1
    for _, kwargs in call_args_list:
        seq_group_metadata_list = kwargs[
            "execute_model_req"].seq_group_metadata_list

        assert len(seq_group_metadata_list) == (k + 1) * batch_size
        for seq_group_metadata in seq_group_metadata_list:
            for seq_data in seq_group_metadata.seq_data.values():
                seen_contexts.append(seq_data.get_token_ids())
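    # The scorer must see every prefix of the draft tokens: for each sequence,
    # one context per i in [0, k], i.e. prompt + previous output + the first i
    # draft tokens. Ordering is not guaranteed, so both sides are sorted
    # before comparison.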
    expected_seen_contexts: List[List[int]] = []

    for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist()):

        for i in range(len(draft_tokens) + 1):
            expected_seen_contexts.append(prompt + prev_generated +
                                          draft_tokens[:i])

    seen_contexts.sort()
    expected_seen_contexts.sort()
    assert expected_seen_contexts == seen_contexts


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
                                             acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the spec decode sampler with
    correct inputs. Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    exception_secret = 'artificial stop'
    spec_decode_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=k))
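    # The target output is laid out flat as batch_size * (k + 1) positions.
    # Reshaping to (batch_size, k + 1) recovers one row per sequence: the last
    # column is the bonus token, and the first k columns are scored against
    # the draft proposals.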
    assert len(spec_decode_sampler.call_args_list) == 1
    _, kwargs = spec_decode_sampler.call_args_list[0]
    actual = SimpleNamespace(**kwargs)

    assert torch.equal(actual.bonus_token_ids,
                       target_token_ids.reshape(batch_size, k + 1)[:, -1:])
    assert torch.equal(
        actual.target_probs,
        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
    assert torch.equal(actual.draft_token_ids, proposal_token_ids)
    assert torch.equal(actual.draft_probs, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
                                  acceptance_sampler_method: str):
    """Verify SpecDecodeWorker formats sampler output correctly.
    Everything else is mocked out.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
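    # Emulate the sampler's convention for rejected positions: everything
    # after the last accepted token in a row is set to -1. At least one token
    # per sequence is always accepted.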
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))

    expected_output = create_sampler_output_list(
        token_ids=spec_decode_sampler_output.transpose(0, 1),
        probs=[None for _ in range(k + 1)],
        logprobs=[None for _ in range(k + 1)])

    seq_ids = [
        next(iter(seq_group_metadata.seq_data.keys()))
        for seq_group_metadata in seq_group_metadata_list
    ]
    actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }
    expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
        seq_id: []
        for seq_id in seq_ids
    }

    for step in output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                actual_output_by_seq[seq_id].append(sample)

    for step in expected_output:
        for seq_group in step:
            for sample in seq_group.samples:
                seq_id = sample.parent_seq_id
                expected_output_by_seq[seq_id].append(sample)

    all_seen_seq_ids = set(
        list(actual_output_by_seq.keys()) +
        list(expected_output_by_seq.keys()))
    for seq_id in all_seen_seq_ids:
        actual_by_step = actual_output_by_seq[seq_id]
        expected_by_step = expected_output_by_seq[seq_id]

        for i in range(k + 1):
            if i >= len(actual_by_step):
                assert expected_by_step[i].output_token == -1
                continue
            assert actual_by_step[i].output_token == expected_by_step[
                i].output_token


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
                          acceptance_sampler_method: str):
    """Verify SpecDecodeWorker collects metrics."""
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker,
                               vocab_size=vocab_size,
                               use_spec=False)
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              spec_decode_sampler,
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    worker.init_device()

    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    spec_decode_sampler_output = torch.randint(low=0,
                                               high=vocab_size,
                                               size=(batch_size, k + 1),
                                               dtype=torch.int64,
                                               device='cuda')
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        spec_decode_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    spec_decode_sampler.return_value = spec_decode_sampler_output
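    # The collector may or may not have metrics ready; both cases are
    # parametrized, and whatever it returns (a metrics object or None) must be
    # attached unchanged to the first sampler output.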
    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k))
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    assert args[0] == k or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
                       acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when k is zero. This happens during prefill.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)
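    # With zero lookahead slots there is nothing to speculate: the request is
    # passed straight through to both workers, a single step is returned, and
    # no GPU tensor references are kept on the output.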
    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[0].sampled_token_ids is None, (
        "expect gpu tensor references to be None")

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
                           acceptance_sampler_method: str):
    """Verify that the SpecDecodeWorker calls the draft and target workers
    when the input batch is empty. This can happen if the engine communicates
    to the workers information without scheduling a batch.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    sampler_output = MagicMock(spec=SamplerOutput)
    sampler_output.hidden_states = None
    target_worker.execute_model.return_value = [sampler_output]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=mock_spec_decode_sampler(
            acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prev_output_token_len=0)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

    out = worker.execute_model(execute_model_req=execute_model_req)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].sampled_token_probs is None, (
        "expect gpu tensor references to be None")
    assert out[0].sampled_token_ids is None, (
        "expect gpu tensor references to be None")

    draft_worker.execute_model.assert_called_once_with(execute_model_req)
    target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
    well as other GPU initialization.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
    spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(
        proposer_worker=draft_worker,
        scorer_worker=target_worker,
        spec_decode_sampler=spec_decode_sampler,
        disable_logprobs=False,
        metrics_collector=metrics_collector,
    )
    worker.init_device()

    draft_worker.init_device.assert_called_once()
    target_worker.init_device.assert_called_once()

    metrics_collector.init_gpu_tensors.assert_called_once()
    spec_decode_sampler.init_gpu_tensors.assert_called_once()


@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
    """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
    workers.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              metrics_collector=metrics_collector)

    kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
    worker.initialize_cache(**kwargs)

    draft_worker.initialize_cache.assert_called_once_with(**kwargs)
    target_worker.initialize_cache.assert_called_once_with(**kwargs)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
                                        available_cpu_blocks: int,
                                        target_cache_block_size_bytes: int,
                                        draft_kv_size_bytes: int,
                                        acceptance_sampler_method: str):
    """Verify SpecDecodeWorker correctly profiles num available GPU blocks.
    Specifically, it should run profiling in the scorer worker, and then evenly
    split the blocks between proposer and scorer worker.
    """
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.determine_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    target_worker.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
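    # Only the scorer worker is profiled; the GPU blocks it reports are then
    # re-split so that the draft and target KV caches together fit in the
    # same byte budget. The CPU block count is passed through unchanged.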
    worker = SpecDecodeWorker(
        draft_worker, target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)

    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

    target_worker.determine_num_available_blocks.assert_called_once()
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Verify split_num_cache_blocks_evenly does not exceed original memory
    allocation in bytes.
    """
    num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
                                               draft_kv_size_bytes,
                                               available_gpu_blocks)
    assert (num_blocks * target_cache_block_size_bytes) + (
        num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
                                              target_cache_block_size_bytes)


@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
    """
    Verify that a call to _create_output_sampler_list correctly updates
    seq_with_bonus_token_in_last_step.

    seq_with_bonus_token_in_last_step is an internal data structure in
    SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
    tokens by the target model in their last forward pass. This state is
    maintained only for models relying on the KV cache, such as those using
    the MultiStepWorker.
    """
    batch_size = 10
    k = 5
    vocab_size = 10000
    num_sequences_with_bonus_tokens = 5
    target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
    target_worker.device = 'cuda'

    set_random_seed(1)
    draft_worker = mock_worker(cls=MultiStepWorker)
    draft_worker.device = 'cuda'
    # The sequence_ids attached to each sequence in the batch.
    # The sequence at index i has seq_id assigned_seq_ids[i].
    assigned_seq_ids = list(range(batch_size))
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 seq_ids=assigned_seq_ids,
                                                 prev_output_token_len=10)
    target_token_logprobs = torch.rand(batch_size, (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    accepted_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, (k + 1)),
                                       dtype=torch.int64,
                                       device='cuda')
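    # accepted_token_ids has shape (batch_size, k + 1); the final column is
    # the bonus-token position. Setting it to -1 below marks a sequence as
    # having no bonus token in this step.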
    expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
    for seq_group_metadata in seq_group_metadata_list:
        for seq_id in seq_group_metadata.seq_data:
            expected_request_id_seq_ids_mapping[
                seq_group_metadata.request_id].add(seq_id)
    # Generate a random sample of sequence indexes with bonus tokens.
    seq_indexes_with_bonus_tokens = random.sample(
        range(batch_size), num_sequences_with_bonus_tokens)
    # Create a mask that is True for indices not in
    # seq_indexes_with_bonus_tokens.
    mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
    mask[seq_indexes_with_bonus_tokens] = False
    # Set the last token ID to -1 for all indices not in
    # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
    # those indices.
    accepted_token_ids[mask, -1:] = -1
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
    # This set includes all sequence IDs in the batch as well as an additional
    # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
    # the range [0, batch_size + num_extra_sequence_ids).
    num_extra_sequence_ids = 10
    worker._seq_with_bonus_token_in_last_step = set(
        range(batch_size + num_extra_sequence_ids))
    worker._create_output_sampler_list(
        seq_group_metadata_list=seq_group_metadata_list,
        accepted_token_ids=accepted_token_ids,
        target_logprobs=target_token_logprobs,
        k=k,
        stage_times=(0, 0, 0))
    # Verify that _seq_with_bonus_token_in_last_step contains the following:
    # 1. Sequence IDs that were already present in
    #    _seq_with_bonus_token_in_last_step but were not part of the current
    #    batch are retained.
    # 2. Of the sequence IDs present in the current batch, only those with a
    #    bonus token are retained in _seq_with_bonus_token_in_last_step.
    #    Sequence IDs that are present in the current batch but do not have
    #    bonus tokens are removed from _seq_with_bonus_token_in_last_step.
    expected_seq_ids_with_bonus_tokens = \
        set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
    additional_sequence_ids = \
        set(range(batch_size, batch_size + num_extra_sequence_ids))
    assert worker._seq_with_bonus_token_in_last_step == \
        expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
    assert worker._request_id_seq_id_mapping == \
        expected_request_id_seq_ids_mapping


@torch.inference_mode()
def test_handle_finished_requests():
    """
    Test to verify that finished request IDs are appropriately processed to
    update the internal state of the SpecDecodeWorker.

    This test initializes the SpecDecodeWorker with mock data, marks certain
    requests as finished, and ensures that the corresponding sequence IDs are
    correctly removed from the internal mappings.
    """
    batch_size = 32
    k = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker, target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              metrics_collector)
    # Initialize the request_id_seq_id_mapping mapping dict with a few fake
    # request ids and corresponding sequence ids.
    worker._request_id_seq_id_mapping = \
        {'request-1': {1, 2, 3}, 'request-2': {4, 5, 6, 7},
         'request-3': {8, 9}, 'request-4': {10, 11}}
    # Initialize seq_with_bonus_token_in_last_step with a few fake
    # sequence ids.
    worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}

    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    # Mark requests with ids request-1 and request-3 as finished.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        finished_requests_ids=['request-1', 'request-3'])

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)

    # Verify that request-1 and request-3 are removed from
    # request_id_seq_id_mapping.
    assert worker._request_id_seq_id_mapping == \
        {'request-2': {4, 5, 6, 7}, 'request-4': {10, 11}}
    # Verify that all sequence ids corresponding to 'request-1'
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4, 5, 10}