Browse Source

feat: add model and sampler tests

AlpinDale 1 year ago
parent
commit
3738f09886
2 changed files with 215 additions and 0 deletions
  1. 35 0
      tests/models/test_models.py
  2. 180 0
      tests/samplers/test_samplers.py

+ 35 - 0
tests/models/test_models.py

@@ -0,0 +1,35 @@
+import pytest
+
+MODELS = [
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/gpt-neox-20b",
+    "meta-llama/llama-2-7b-hf",
+    "/mistralai/Mistral-7B-v0.1",
+]
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    del hf_model
+
+    aphrodite_model = aphrodite_runner(model, dtype=dtype)
+    aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+    del aphrodite_model
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
+        assert hf_output_str == aphrodite_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: {aphrodite_output_str!r}")
+        assert hf_output_ids == aphrodite_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")

+ 180 - 0
tests/samplers/test_samplers.py

@@ -0,0 +1,180 @@
+import pytest
+import random
+from typing import Tuple
+from unittest.mock import patch
+
+import torch
+
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.utils import set_random_seed
+from aphrodite.common.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from aphrodite.task_handler.worker import Worker
+
+
+class MockLogitsSampler(Sampler):
+
+    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
+        super().__init__(vocab_size=vocab_size)
+        self.fake_logits = fake_logits
+
+    def forward(self, *args, **kwargs):
+        with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
+                    lambda x, y: x):
+            with patch("aphrodite.modeling.layers.sampler._get_logits",
+                      lambda *args, **kwargs: self.fake_logits):
+                return super().forward(*args, **kwargs)
+            
+def _prepare_test(
+    batch_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                             device='cuda',
+                             dtype=torch.float16)
+    fake_logits = torch.full((batch_size, vocab_size),
+                             1e-2,
+                             device=input_tensor.device,
+                             dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    worker = Worker(None, None, None)
+    worker.block_size = 16
+    return input_tensor, fake_logits, sampler, worker
+
+RANDOM_SEEDS = list(range(128))
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_all_greedy(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f'test_{i}',
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0, ),
+                block_tables={0: [1]},
+            ))
+    
+    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampler_output = sampler(embedding=None,
+                            hidden_states=input_tensor,
+                            input_metadata=input_metadata)
+    expected = torch.argmax(fake_logits, dim=-1)
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output:
+            assert nth_output.output_token == expected[i].item()
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_all_random(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+    
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1.0,
+                    n=random.randint(1, 10),
+                ),
+                block_tables={0: [1]},
+            ))
+    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampler_output = sampler(embedding=None,
+                             hidden_states=input_tensor,
+                             input_metadata=input_metadata)
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output:
+            assert nth_output.output_token == i
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_all_beam(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=0,
+                    best_of=2,
+                    use_beam_search=True,
+                ),
+                block_tables={0: [1]},
+            ))
+    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampler(embedding=None,
+            hidden_states=input_tensor,
+            input_metadata=input_metadata)
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_mixed(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+
+    seq_group_metadata_list = []
+    expected_tokens = []
+    for i in range(batch_size):
+        n = 1
+        sampling_type = random.randint(0, 2)
+        if sampling_type == 0:
+            sampling_params = SamplingParams(temperature=0)
+        elif sampling_type == 1:
+            n = random.randint(1, 10)
+            sampling_params = SamplingParams(
+                temperature=random.random() + 0.1,
+                top_p=min(random.random() + 0.1, 1),
+                top_k=random.randint(0, 10) or -1,
+                top_a=min(random.random() + 0.1, 2),
+                tfs=min(random.random() + 0.1, 1),
+                eta_cutoff=random.randint(0, 10) or 0,
+                epsilon_cutoff=random.randint(0, 10) or 0,
+                typical_p=min(random.random() + 0.1, 1),
+                presence_penalty=random.randint(0, 1),
+                frequency_penalty=random.randint(0, 1),
+                repetition_penalty=min(random.random() + 0.1, 1),
+                n=n,
+            )
+        else:
+            sampling_params = SamplingParams(temperature=0,
+                                             use_beam_search=True,
+                                             best_of=2)
+        
+        for idx in range(n):
+            fake_logits[i, i + idx] = 1e2
+            expected_tokens.append(i + idx)
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+            ))
+        
+    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampler_output = sampler(embedding=None,
+                             hidden_states=input_tensor,
+                             input_metadata=input_metadata)
+    for i, sequence_output in enumerate(sampler_output):
+        if seq_group_metadata_list[i].sampling_params.use_beam_search:
+            continue
+        for nth_output in sequence_output:
+            assert nth_output.output_token in expected_tokens