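"""Tests for MultiStepOutputProcessor.

These tests drive the processor with mocked collaborators and verify token
appending, max_tokens truncation, and eos handling.
"""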
import random
from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SequenceOutput, SequenceStatus)
from aphrodite.common.utils import Counter
from aphrodite.engine.output_processor.multi_step import (
    MultiStepOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer

from ...core.utils import create_seq_group


@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
    """Verify multi-step decoding appends token ids correctly.

    We feed new token ids through the processor and verify they were all
    appended to the sequence. Note that ignore_eos=True.
    """
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()
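    # The processor takes a list of schedulers (presumably one per virtual
    # engine in the full engine); a single mocked scheduler suffices here.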
    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=1024,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            max_tokens=seq_output_len + num_new_tokens,
            ignore_eos=True,
        ),
    )
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
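    # Each CompletionSequenceGroupOutput models one decode step producing a
    # single sampled token, so the list below represents num_new_tokens steps.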
    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
    output_processor.process_outputs(seq_group, outputs)
    assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
                             seq_output_len: int, max_tokens: int):
    """Verify tokens emitted after max_tokens is reached are dropped and not
    appended to the sequence.
    """
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )
    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=max_tokens),
    )
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence not to exceed max_tokens in length.
    assert seq.get_len() == seq_prompt_len + max_tokens

    # Expect only the tokens that fit under max_tokens were appended.
    expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
    assert (seq.get_token_ids()[-len(expected_appended_tokens):] ==
            expected_appended_tokens)


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                               seq_output_len: int, seed: int):
    """Verify the eos token id is appended to the sequence, but subsequent
    tokens are dropped (not appended to the sequence).
    """
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()
    eos_token_id = 100
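    # 100 lies outside range(num_new_tokens), so it cannot collide with the
    # generated token ids below (checked by the assert further down).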
    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space for all new tokens.
            max_tokens=seq_output_len + num_new_tokens,
        ),
    )
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
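    # Overwrite a random position with eos; the processor should truncate the
    # sequence there.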
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id
    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to not go beyond the provided eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)

    # Expect the tokens up to and including eos were appended.
    expected_appended_tokens = new_token_ids[:eos_index + 1]
    assert (seq.get_token_ids()[-len(expected_appended_tokens):] ==
            expected_appended_tokens)


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                              seq_output_len: int, seed: int):
    """When the sampling parameters specify that the eos token id should be
    ignored, ensure all token ids are appended even if eos is emitted.
    """
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()
    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space for all new tokens.
            max_tokens=seq_output_len + num_new_tokens,
            ignore_eos=True,
        ),
    )
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id
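    # eos is injected exactly as in test_respects_eos_token_id, but with
    # ignore_eos=True it must have no truncating effect.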
    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to go beyond eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

    # Expect all new tokens, including eos, were appended.
    expected_appended_tokens = new_token_ids[:num_new_tokens]
    assert (seq.get_token_ids()[-len(expected_appended_tokens):] ==
            expected_appended_tokens)


def mock_tokenizer(eos_token_id=1000):
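    """Return a tokenizer mock exposing only the eos_token_id attribute."""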
    tokenizer = MagicMock(spec=PreTrainedTokenizer)
    tokenizer.eos_token_id = eos_token_id
    return tokenizer
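
# A minimal sketch of how these tests are typically invoked (the file path is
# an assumption, not confirmed by the source):
#   pytest tests/engine/output_processor/test_multi_step.py -q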