@@ -1,14 +1,15 @@
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch

+import pytest
 import torch

 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.utils import set_random_seed
-from aphrodite.common.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from aphrodite.task_handler.worker import Worker
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
+                                       SequenceGroupMetadata)
+from aphrodite.task_handler.model_runner import ModelRunner


 class MockLogitsSampler(Sampler):
@@ -19,15 +20,17 @@ class MockLogitsSampler(Sampler):

     def forward(self, *args, **kwargs):
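+        # patch both sampler helpers in one with-statement: hidden states
+        # pass through unpruned, and fake_logits stands in for _get_logits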
         with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
-                   lambda x, y: x):
-            with patch("aphrodite.modeling.layers.sampler._get_logits",
+                   lambda x, y: x), patch(
+                       "aphrodite.modeling.layers.sampler._get_logits",
                        lambda *args, **kwargs: self.fake_logits):
-                return super().forward(*args, **kwargs)
+            return super().forward(*args, **kwargs)


 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +40,10 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
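+    # ModelRunner replaces the old Worker here; None placeholders suffice
+    # for its config args, which _prepare_sample does not need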
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner


 RANDOM_SEEDS = list(range(128))
@@ -49,27 +53,35 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
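+    # record each prompt's length; _prepare_sample uses these to locate
+    # the logit row each sequence samples from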
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=0, ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0, ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
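+    # the sampler now consumes SamplingMetadata instead of InputMetadata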
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
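+        # sampler output entries now carry their samples in a .samples field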
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()


@@ -77,30 +89,36 @@ def test_sampler_all_greedy(seed: int):
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     for i in range(batch_size):
         fake_logits[i, i] = 1e2

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=1.0,
-                                      n=random.randint(1, 10),
-                                  ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1.0,
+                    n=random.randint(1, 10),
+                ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i


@@ -108,37 +126,46 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    # pylint: disable=unused-variable
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=0,
-                                      best_of=2,
-                                      use_beam_search=True,
-                                  ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=0,
+                    best_of=2,
+                    use_beam_search=True,
+                ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
+    # no assertion here: the expected beam search output is hard to
+    # determine analytically, so this only checks that the sampler
+    # handles an all-beam-search batch without raising exceptions.


 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -150,39 +177,76 @@ def test_sampler_mixed(seed: int):
                 temperature=random.random() + 0.1,
                 top_p=min(random.random() + 0.1, 1),
                 top_k=random.randint(0, 10) or -1,
-                top_a=min(random.random() + 0.1, 2),
-                tfs=min(random.random() + 0.1, 1),
-                eta_cutoff=random.randint(0, 10) or 0,
-                epsilon_cutoff=random.randint(0, 10) or 0,
-                typical_p=min(random.random() + 0.1, 1),
-                presence_penalty=random.randint(0, 1),
-                frequency_penalty=random.randint(0, 1),
-                repetition_penalty=min(random.random() + 0.1, 1),
                 n=n,
+                presence_penalty=random.randint(0, 1),
             )
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-
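+        # pin the top-n logits for row i so every non-beam sample comes
+        # from a known set of expected token ids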
         for idx in range(n):
             fake_logits[i, i + idx] = 1e2
             expected_tokens.append(i + idx)
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=sampling_params,
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_logits_processors(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
+
+    # This sample logits processor gives an infinite score to the i-th
+    # token, where i is the number of tokens generated so far; we therefore
+    # expect the output token sequence to be [0, 1, 2, ...].
+    def pick_ith(token_ids, logits):
+        logits[len(token_ids)] = float("inf")
+        return logits
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0,
+                                               logits_processors=[pick_ith]),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    sampler_output = sampler(embedding=None,
+                             hidden_states=input_tensor,
+                             sampling_metadata=sampling_metadata)
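+    # nothing has been generated yet, so pick_ith should force token 0;
+    # idx is 0 for the single greedy sample in each sequence group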
+    for sequence_output in sampler_output:
+        for idx, nth_output in enumerate(sequence_output.samples):
+            assert nth_output.output_token == idx