
fix: test units (#201)

* update async engine tests

* update kernel tests

* add test for mistral

* update logprobs test

* samplers will be this way for now

* move the benchmarks to their own dir

* add regression test

* update conftest

* formatting
AlpinDale 1 year ago
Parent
Current commit
e1f3fd1e02

+ 4 - 12
aphrodite/endpoints/ooba/api_server.py

@@ -2,7 +2,7 @@ import argparse
 import json
 from typing import AsyncGenerator
 
-from fastapi import (BackgroundTasks, Header, FastAPI, HTTPException, Request)
+from fastapi import (BackgroundTasks, FastAPI, HTTPException, Request)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 import uvicorn
@@ -32,12 +32,10 @@ app.add_middleware(
 parser = argparse.ArgumentParser()
 parser.add_argument("--host", type=str, default="localhost")
 parser.add_argument("--port", type=int, default=2242)
-parser.add_argument("--api-keys", nargs="*", default=["EMPTY"])
 parser.add_argument("--served-model-name", type=str, default=None)
 parser = AsyncEngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 engine_args = AsyncEngineArgs.from_cli_args(args)
-valid_api_keys = args.api_keys
 if args.served_model_name is not None:
     served_model = args.served_model_name
 else:
@@ -45,8 +43,7 @@ else:
 
 
 @app.post("/api/v1/generate")
-async def generate(
-    request: Request, x_api_key: str = Header(None)) -> Response:
+async def generate(request: Request) -> Response:
     """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
@@ -54,9 +51,6 @@ async def generate(
     - stream: whether to stream the results or not.
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
-    if x_api_key is None or x_api_key not in valid_api_keys:
-        raise HTTPException(status_code=401,
-                            detail="Unauthorized. Please acquire an API key.")
 
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
@@ -135,11 +129,9 @@ async def generate(
 
 
 @app.get("/api/v1/model")
-async def get_model_name(x_api_key: str = Header(None)) -> JSONResponse:
+async def get_model_name() -> JSONResponse:
     """Return the model name based on the EngineArgs configuration."""
-    if x_api_key is None or x_api_key not in valid_api_keys:
-        raise HTTPException(status_code=401,
-                            detail="Unauthorized. Please acquire an API key.")
+
     if engine is not None:
         result = {"result": f"aphrodite/{served_model}"}
         return JSONResponse(content=result)
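
With the API-key check removed, both routes can be called with no x-api-key header. A minimal sketch of exercising them against a locally running server (host and port are the argparse defaults above; response shapes follow the handlers in this file):

import requests

BASE = "http://localhost:2242"  # defaults from the argument parser above

# /api/v1/generate now only needs the JSON body: prompt, stream, and sampling params.
resp = requests.post(f"{BASE}/api/v1/generate",
                     json={"prompt": "Hello", "max_tokens": 16,
                           "temperature": 0, "stream": False})
print(resp.json())

# /api/v1/model is likewise unauthenticated and reports the served model name.
print(requests.get(f"{BASE}/api/v1/model").json())  # {"result": "aphrodite/<model>"}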

+ 2 - 3
tests/async_engine/api_server_async_aphrodite.py

@@ -1,4 +1,4 @@
-"""API server with some extra logging for testing."""
+"""aphrodite.endpoints.ooba.api_server with some extra logging for testing."""
 import argparse
 from typing import Any, Dict
 
@@ -14,7 +14,6 @@ app = aphrodite.endpoints.ooba.api_server.app
 
 class AsyncAphroditeWithStats(AsyncAphrodite):
 
-    # pylint: disable=redefined-outer-name
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._num_aborts = 0
@@ -36,7 +35,7 @@ def stats() -> Response:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=2242)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 

+ 18 - 18
tests/async_engine/test_api_server.py

@@ -8,15 +8,11 @@ import pytest
 import requests
 
 
-def _query_server(prompt: str) -> dict:
-    headers = {
-        "x-api-key": "EMPTY"  # change as needed
-    }
-    response = requests.post("http://localhost:2242/api/v1/generate",
-                             headers=headers,
+def _query_server(prompt: str, max_tokens: int = 5) -> dict:
+    response = requests.post("http://localhost:2242/generate",
                              json={
                                  "prompt": prompt,
-                                 "max_tokens": 100,
+                                 "max_tokens": max_tokens,
                                  "temperature": 0,
                                  "ignore_eos": True
                              })
@@ -24,11 +20,14 @@ def _query_server(prompt: str) -> dict:
     return response.json()
 
 
+def _query_server_long(prompt: str) -> dict:
+    return _query_server(prompt, max_tokens=500)
+
+
 @pytest.fixture
 def api_server():
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
-    # pylint: disable=consider-using-with
     uvicorn_process = subprocess.Popen([
         sys.executable, "-u",
         str(script_path), "--model", "EleutherAI/pythia-70m-deduped"
@@ -37,7 +36,6 @@ def api_server():
     uvicorn_process.terminate()
 
 
-# pylint: disable=redefined-outer-name, unused-argument
 def test_api_server(api_server):
     """
     Run the API server and test it.
@@ -50,14 +48,14 @@ def test_api_server(api_server):
     """
     with Pool(32) as pool:
         # Wait until the server is ready
-        prompts = ["Hello world"] * 1
+        prompts = ["warm up"] * 1
         result = None
         while not result:
-            # pylint: disable=bare-except
             try:
-                for result in pool.map(_query_server, prompts):
+                for r in pool.map(_query_server, prompts):
+                    result = r
                     break
-            except:
+            except requests.exceptions.ConnectionError:
                 time.sleep(1)
 
         # Actual tests start here
@@ -66,28 +64,30 @@ def test_api_server(api_server):
             assert result
 
         num_aborted_requests = requests.get(
-            "http://localhost:2242/stats").json()["num_aborted_requests"]
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
         assert num_aborted_requests == 0
 
         # Try with 100 prompts
-        prompts = ["Hello world"] * 100
+        prompts = ["test prompt"] * 100
         for result in pool.map(_query_server, prompts):
             assert result
 
+    with Pool(32) as pool:
         # Cancel requests
-        pool.map_async(_query_server, prompts)
+        prompts = ["canceled requests"] * 100
+        pool.map_async(_query_server_long, prompts)
         time.sleep(0.01)
         pool.terminate()
         pool.join()
 
         # check cancellation stats
         num_aborted_requests = requests.get(
-            "http://localhost:2242/stats").json()["num_aborted_requests"]
+            "http://localhost:8000/stats").json()["num_aborted_requests"]
         assert num_aborted_requests > 0
 
     # check that server still runs after cancellations
     with Pool(32) as pool:
         # Try with 100 prompts
-        prompts = ["Hello world"] * 100
+        prompts = ["test prompt after canceled"] * 100
         for result in pool.map(_query_server, prompts):
             assert result

+ 119 - 0
tests/async_engine/test_openai_server.py

@@ -0,0 +1,119 @@
+from argparse import Namespace
+from dataclasses import dataclass
+
+import pytest
+from fastapi.testclient import TestClient
+
+from aphrodite.endpoints.openai.api_server import *
+
+# Define models, templates, and their corresponding expected outputs
+MODEL_TEMPLATE_GENERATION_OUTPUT = [
+    ("EleutherAI/pythia-70m-deduped", None, True,
+     "Hello</s>Hi there!</s>What is the capital of</s>"),
+    ("EleutherAI/pythia-70m-deduped", None, False,
+     "Hello</s>Hi there!</s>What is the capital of</s>"),
+    ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
+     True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+"""),
+    ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
+     False, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of""")
+]
+
+TEST_MESSAGES = [
+    {
+        'role': 'user',
+        'content': 'Hello'
+    },
+    {
+        'role': 'assistant',
+        'content': 'Hi there!'
+    },
+    {
+        'role': 'user',
+        'content': 'What is the capital of'
+    },
+]
+client = TestClient(app)
+
+
+@dataclass
+class MockTokenizer:
+    chat_template = None
+
+
+def test_load_chat_template():
+    # Testing chatml template
+    template = "../../examples/chatml_template.jinja"
+    mock_args = Namespace(chat_template=template)
+    tokenizer = MockTokenizer()
+
+    # Call the function with the mocked args
+    load_chat_template(mock_args, tokenizer)
+
+    template_content = tokenizer.chat_template
+
+    # Test assertions
+    assert template_content is not None
+    # Hard coded value for chatml_template.jinja
+    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
+
+
+def test_no_load_chat_template():
+    # Testing chatml template
+    template = "../../examples/does_not_exist"
+    mock_args = Namespace(chat_template=template)
+    tokenizer = MockTokenizer()
+
+    # Call the function with the mocked args
+    load_chat_template(mock_args, tokenizer=tokenizer)
+    template_content = tokenizer.chat_template
+
+    # Test assertions
+    assert template_content is not None
+    # Hard coded value for chatml_template.jinja
+    assert template_content == """../../examples/does_not_exist"""
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model,template,add_generation_prompt,expected_output",
+    MODEL_TEMPLATE_GENERATION_OUTPUT)
+async def test_get_gen_prompt(model, template, add_generation_prompt,
+                              expected_output):
+    # Initialize the tokenizer
+    tokenizer = get_tokenizer(tokenizer_name=model)
+
+    mock_args = Namespace(chat_template=template)
+    load_chat_template(mock_args, tokenizer)
+
+    # Create a mock request object using keyword arguments
+    mock_request = ChatCompletionRequest(
+        model=model,
+        messages=TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt)
+
+    # Call the function and get the result
+    result = tokenizer.apply_chat_template(
+        conversation=mock_request.messages,
+        tokenize=False,
+        add_generation_prompt=mock_request.add_generation_prompt)
+
+    # Test assertion
+    assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
+
+
+def test_health_endpoint():
+    response = client.get("/health")
+    assert response.status_code == 200

+ 0 - 0
tests/attention.py → tests/benchmarks/attention.py


+ 0 - 0
tests/latency.py → tests/benchmarks/latency.py


+ 0 - 0
tests/serving.py → tests/benchmarks/serving.py


+ 0 - 0
tests/throughput.py → tests/benchmarks/throughput.py


+ 24 - 11
tests/conftest.py

@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import pytest
@@ -7,21 +8,33 @@ from transformers import AutoModelForCausalLM
 from aphrodite import LLM, SamplingParams
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
-_TEST_PROMPTS = [
-    # pylint: disable=line-too-long
-    "Develop a detailed method for integrating a blockchain-based distributed ledger system into a pre-existing finance management application. The focus should be on ensuring security, transparency, and real-time updates of transactions.",
-    "Design an AI-powered predictive analytics engine capable of identifying trends and patterns from unstructured data sets. The engine should be adaptable to different industry requirements such as healthcare, finance, and marketing.",
-    "Construct a comprehensive model for a multi-cloud architecture that can smoothly transition between different cloud platforms (AWS, Google Cloud, Azure) without any interruption in service or loss of data.",
-    "Propose a strategy for integrating Quantum Computing capabilities into existing high-performance computing (HPC) systems. The approach should consider potential challenges and solutions of Quantum-HPC integration.",
-    "Create a robust cybersecurity framework for an Internet of Things (IoT) ecosystem. The framework should be capable of detecting, preventing, and mitigating potential security breaches.",
-    "Develop a scalable high-frequency trading algorithm that uses machine learning to predict and respond to microtrends in financial markets. The algorithm should be capable of processing real-time data and executing trades within milliseconds.",
-    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
-]
+_TEST_DIR = os.path.dirname(__file__)
+_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
+_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
+
+
+def _read_prompts(filename: str) -> str:
+    prompts = []
+    with open(filename, "r") as f:
+        prompt = f.readline()
+        prompts.append(prompt)
+    return prompts
 
 
 @pytest.fixture
 def example_prompts() -> List[str]:
-    return _TEST_PROMPTS
+    prompts = []
+    for filename in _TEST_PROMPTS:
+        prompts += _read_prompts(filename)
+    return prompts
+
+
+@pytest.fixture
+def example_long_prompts() -> List[str]:
+    prompts = []
+    for filename in _LONG_PROMPTS:
+        prompts += _read_prompts(filename)
+    return prompts
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
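
The prompt fixtures are now backed by text files under tests/prompts/ instead of an inline list. A hedged sketch of how a test might consume them (the fixture names come from this diff; the test body itself is illustrative and not part of the commit):

# Hypothetical test; pytest injects the fixtures defined in conftest.py above.
def test_prompt_fixtures(example_prompts, example_long_prompts):
    # Each fixture returns a list of prompt strings loaded from the files above.
    assert all(isinstance(p, str) and p for p in example_prompts)
    assert all(isinstance(p, str) and p for p in example_long_prompts)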

+ 9 - 8
tests/kernels/conftest.py

@@ -12,6 +12,7 @@ def create_kv_caches(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: str,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -23,19 +24,19 @@ def create_kv_caches(
     for _ in range(num_layers):
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=dtype,
-                                device='cuda')
+                                device=device)
         key_cache.uniform_(-scale, scale)
         key_caches.append(key_cache)
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    values_caches = []
+    value_caches = []
     for _ in range(num_layers):
-        values_cache = torch.empty(size=value_cache_shape,
-                                   dtype=dtype,
-                                   device='cuda')
-        values_cache.uniform_(-scale, scale)
-        values_caches.append(values_cache)
-    return key_caches, values_caches
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device=device)
+        value_cache.uniform_(-scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
 
 
 @pytest.fixture()
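
The cache factory now takes an explicit device string as its last argument. An illustrative call, mirroring how the attention and cache tests below invoke the fixture (they receive it as kv_cache_factory; the values here are arbitrary):

import torch

# Hypothetical test; positional order follows the call sites in
# tests/kernels/test_attention.py and test_cache.py:
# (num_blocks, block_size, num_layers, num_heads, head_size, dtype, seed, device)
def test_kv_cache_device(kv_cache_factory):
    key_caches, value_caches = kv_cache_factory(1024, 16, 1, 8, 64,
                                                torch.half, 0, "cuda:0")
    assert key_caches[0].device == torch.device("cuda:0")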

+ 25 - 23
tests/kernels/test_activation.py

@@ -1,39 +1,35 @@
-"""Test suite for the activation kernel."""
 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation
 
-from aphrodite._C import ops as activation_ops
+from aphrodite.modeling.layers.activation import FastGELU, NewGELU, SiluAndMul
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [7, 38, 2048]
-D = [512, 4096, 5120, 13824]  # arbitrary values for testing
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]
-
-
-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
@@ -41,19 +37,22 @@ def test_silu_and_mul(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
@@ -61,16 +60,19 @@ def test_gelu_new(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.rand(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    gpu_id = f"cuda:{device}"
+    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
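
These kernel tests now go through the layer modules instead of the raw ops, comparing each layer's custom-kernel forward against its pure-PyTorch _forward reference. A minimal standalone sketch of that pattern (layer and method names are the ones used above; it assumes a CUDA build, as the tests do):

import torch
from aphrodite.modeling.layers.activation import SiluAndMul

x = torch.randn(8, 2 * 512, dtype=torch.half, device="cuda:0")
layer = SiluAndMul()
out = layer(x)               # custom CUDA kernel path
ref_out = layer._forward(x)  # plain-PyTorch reference used by the tests
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)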

+ 22 - 20
tests/kernels/test_attention.py

@@ -6,14 +6,14 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from aphrodite._C import ops as attention_ops
+from aphrodite._C import ops
 from aphrodite.common.utils import get_max_shared_memory_bytes
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 128  # Arbitrary values for testing
+NUM_BLOCKS = 40000  # Arbitrary values for testing
 PARTITION_SIZE = 512
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -24,6 +24,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 def ref_masked_attention(
@@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device="cuda").int()
+            position_ids = torch.arange(context_len, device=query.device).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
@@ -105,6 +106,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -115,35 +117,33 @@ def test_paged_attention(
     block_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=gpu_id)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
 
     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -154,23 +154,23 @@ def test_paged_attention(
             for _ in range(max_num_blocks_per_seq)
         ]
         block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v1":
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             output,
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             block_tables,
             context_lens,
@@ -194,7 +194,7 @@ def test_paged_attention(
             device=output.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             output,
             exp_sums,
             max_logits,
@@ -202,7 +202,7 @@ def test_paged_attention(
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             block_tables,
             context_lens,
@@ -211,7 +211,7 @@ def test_paged_attention(
             alibi_slopes,
         )
     else:
-        assert False, f"Unknown version: {version}"
+        raise AssertionError(f"Unknown version: {version}")
 
     # Run the reference implementation.
     ref_output = torch.empty_like(query)
@@ -252,7 +252,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device="cuda")
+        attn_mask = attn_mask.to(dtype=dtype, device=query.device)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -272,6 +272,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
@@ -279,11 +280,12 @@ def test_multi_query_kv_attention(
     head_size: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
     # a smaller MAX_SEQ_LEN here.
@@ -297,7 +299,7 @@ def test_multi_query_kv_attention(
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
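
For context on the signature change above: the head_mapping tensor that was removed expanded the KV-head indices over the query heads; the kernels now receive num_kv_heads and presumably build that mapping internally. A tiny illustration of what the removed tensor contained (head counts are illustrative only):

import torch

num_query_heads, num_kv_heads = 8, 2
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_query_heads // num_kv_heads)
# -> tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)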

+ 29 - 24
tests/kernels/test_cache.py

@@ -6,14 +6,15 @@ import torch
 from aphrodite._C import cache_ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
-NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024]  # Arbitrary values for testing
-NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -24,6 +25,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@@ -35,43 +37,44 @@ def test_copy_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
     src_blocks = random.sample(range(num_blocks), num_mappings)
-    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
-    block_mapping = {}
+    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
+    copy_src = []
+    copy_dst = []
     for i in range(num_mappings):
-        src = src_blocks[i]
-        dst1 = dst_blocks[2 * i]
-        dst2 = dst_blocks[2 * i + 1]
-        block_mapping[src] = [dst1, dst2]
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i])
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i + 1])
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed)
+                                                head_size, dtype, seed, gpu_id)
 
     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)
 
     # Run the reference implementation.
-    for src, dsts in block_mapping.items():
-        for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst] = cloned_key_cache[src]
-            for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst] = cloned_value_cache[src]
+    for src, dst in zip(copy_src, copy_dst):
+        for cloned_key_cache in cloned_key_caches:
+            cloned_key_cache[dst].copy_(cloned_key_cache[src])
+        for cloned_value_cache in cloned_value_caches:
+            cloned_value_cache[dst].copy_(cloned_value_cache[src])
 
     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@@ -88,6 +91,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
+    gpu_id = f"cuda:{device}"
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
 
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     _, key, value = qkv.unbind(dim=1)
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Clone the KV caches.
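
copy_blocks changes here from taking a {src: [dst, ...]} dict to taking two parallel lists, as the updated call above shows. A small worked example of the flattening the test now performs (block numbers are illustrative):

block_mapping = {0: [3, 4], 1: [5, 6]}   # old dict-style mapping
copy_src, copy_dst = [], []
for src, dsts in block_mapping.items():
    for dst in dsts:
        copy_src.append(src)
        copy_dst.append(dst)
assert copy_src == [0, 0, 1, 1]
assert copy_dst == [3, 4, 5, 6]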

+ 28 - 39
tests/kernels/test_layernorm.py

@@ -1,61 +1,50 @@
 import pytest
 import torch
-import torch.nn as nn
 
-from aphrodite._C import ops as layernorm_ops
+from aphrodite.modeling.layers.layernorm import RMSNorm
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+ADD_RESIDUAL = [False, True]
 SEEDS = [0]
-
-
-class RefRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        weight = torch.empty(hidden_size)
-        weight.normal_(mean=1.0, std=0.1)
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
+    add_residual: bool,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
-    scale = float(hidden_size**-0.5)
-    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-    x.uniform_(-scale, scale)
-    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
-
-    out = torch.empty_like(x)
-    layernorm_ops.rms_norm(
-        out,
-        x,
-        ref.weight.data,
-        ref.variance_epsilon,
-    )
-    ref_out = ref(x)
-
-    print("out: ", out)
-    print("ref_out: ", ref_out)
-    assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-3)
+    gpu_id = f"cuda:{device}"
+    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    # NOTE: The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_out = layer._forward(x, residual)
+    out = layer(x, residual)
+    # NOTE: LayerNorm operators (including RMS) typically have larger
+    # numerical errors than other operators because they involve reductions.
+    # Therefore, we use a larger tolerance.
+    if add_residual:
+        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+    else:
+        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)

+ 29 - 136
tests/kernels/test_pos_encoding.py

@@ -1,119 +1,41 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
-from aphrodite._C import pos_encoding_ops
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
-
-
-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
     dtype: torch.dtype,
     seed: int,
+    device: int,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -121,55 +43,26 @@ def test_rotary_embedding(
         rotary_dim = head_size
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    gpu_id = f"cuda:{device}"
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype=dtype, device=gpu_id)
+
+    positions = torch.randint(0,
+                              max_position, (batch_size, seq_len),
+                              device=gpu_id)
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
-                        device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    pos_encoding_ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    # pylint: disable=not-callable
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+                        device=gpu_id)
+    key = torch.randn_like(query)
 
+    # NOTE: The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)

+ 4 - 0
tests/models/test_models.py

@@ -2,6 +2,10 @@ import pytest
 
 MODELS = [
     "EleutherAI/pythia-70m-deduped",
+    "meta-llama/Llama-2-7b-hf",
+    "Deci/DeciLM-7b",
+    "tiiuae/falcon-7b",
+    "microsoft/phi-2",
 ]
 
 

+ 3 - 0
tests/prompts/example.txt

@@ -0,0 +1,3 @@
+Tell me your favorite story.
+Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the same performance as Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN, which parallelizes computations during training and maintains constant computational and memory complexity during inference, leading to the first non-transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks.
+トランスフォーマーは、ほぼすべての自然言語処理に革命をもたらしました

+ 7 - 0
tests/prompts/summary.txt

@@ -0,0 +1,7 @@
+Develop a detailed method for integrating a blockchain-based distributed ledger system into a pre-existing finance management application. The focus should be on ensuring security, transparency, and real-time updates of transactions.
+Design an AI-powered predictive analytics engine capable of identifying trends and patterns from unstructured data sets. The engine should be adaptable to different industry requirements such as healthcare, finance, and marketing.
+Construct a comprehensive model for a multi-cloud architecture that can smoothly transition between different cloud platforms (AWS, Google Cloud, Azure) without any interruption in service or loss of data.
+Propose a strategy for integrating Quantum Computing capabilities into existing high-performance computing (HPC) systems. The approach should consider potential challenges and solutions of Quantum-HPC integration.
+Create a robust cybersecurity framework for an Internet of Things (IoT) ecosystem. The framework should be capable of detecting, preventing, and mitigating potential security breaches.
+Develop a scalable high-frequency trading algorithm that uses machine learning to predict and respond to microtrends in financial markets. The algorithm should be capable of processing real-time data and executing trades within milliseconds.
+Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'

+ 4 - 4
tests/samplers/test_logprobs.py

@@ -45,13 +45,13 @@ def test_get_prompt_logprobs(
             for token_id, logprob in aphrodite_prompt_logprob_dict.items():
                 torch.testing.assert_close(logprob,
                                            hf_logprob[0][i][token_id].item(),
-                                           atol=1e-1,
-                                           rtol=1e-1)
+                                           atol=1e-2,
+                                           rtol=1e-2)
         aphrodite_sample_logprobs = aphrodite_result.outputs[0].logprobs
         for i, aphrodite_sample_logprob_dict in enumerate(
                 aphrodite_sample_logprobs):
             for token_id, logprob in aphrodite_sample_logprob_dict.items():
                 torch.testing.assert_close(logprob,
                                            hf_logprob[i][-1][token_id].item(),
-                                           atol=1e-1,
-                                           rtol=1e-1)
+                                           atol=1e-2,
+                                           rtol=1e-2)

+ 126 - 73
tests/samplers/test_samplers.py

@@ -1,14 +1,15 @@
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.utils import set_random_seed
-from aphrodite.common.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from aphrodite.task_handler.worker import Worker
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
+                                       SequenceGroupMetadata)
+from aphrodite.task_handler.model_runner import ModelRunner
 
 
 class MockLogitsSampler(Sampler):
@@ -19,15 +20,15 @@ class MockLogitsSampler(Sampler):
 
     def forward(self, *args, **kwargs):
         with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
-                   lambda x, y: x):
-            with patch("aphrodite.modeling.layers.sampler._get_logits",
+                   lambda x, y: x), patch(
+                       "aphrodite.modeling.layers.sampler._get_logits",
                        lambda *args, **kwargs: self.fake_logits):
-                return super().forward(*args, **kwargs)
+            return super().forward(*args, **kwargs)
 
 
 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +38,8 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner
 
 
 RANDOM_SEEDS = list(range(128))
@@ -49,27 +49,31 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=0, ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0, ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
 
@@ -77,30 +81,36 @@ def test_sampler_all_greedy(seed: int):
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=1.0,
-                                      n=random.randint(1, 10),
-                                  ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1.0,
+                    n=random.randint(1, 10),
+                ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
 
@@ -108,37 +118,47 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    # pylint: disable=unused-variable
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=SamplingParams(
-                                      temperature=0,
-                                      best_of=2,
-                                      use_beam_search=True,
-                                  ),
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=0,
+                    best_of=2,
+                    use_beam_search=True,
+                ),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
+    # no assertion here as I am not sure how to determine whether
+    # the outputs are expected - in other words, this just tests
+    # whether there are no exceptions in the sampler
+    # when handling an all-beam search case.
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -150,39 +170,72 @@ def test_sampler_mixed(seed: int):
                 temperature=random.random() + 0.1,
                 top_p=min(random.random() + 0.1, 1),
                 top_k=random.randint(0, 10) or -1,
-                top_a=min(random.random() + 0.1, 2),
-                tfs=min(random.random() + 0.1, 1),
-                eta_cutoff=random.randint(0, 10) or 0,
-                epsilon_cutoff=random.randint(0, 10) or 0,
-                typical_p=min(random.random() + 0.1, 1),
-                presence_penalty=random.randint(0, 1),
-                frequency_penalty=random.randint(0, 1),
-                repetition_penalty=min(random.random() + 0.1, 1),
                 n=n,
+                presence_penalty=random.randint(0, 1),
             )
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-
         for idx in range(n):
             fake_logits[i, i + idx] = 1e2
             expected_tokens.append(i + idx)
         seq_group_metadata_list.append(
-            SequenceGroupMetadata(request_id=f"test_{i}",
-                                  is_prompt=True,
-                                  seq_data={0: SequenceData([1, 2, 3])},
-                                  sampling_params=sampling_params,
-                                  block_tables={0: [1]},
-                                  persistent_data={}))
-
-    # pylint: disable=protected-access
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_logits_processors(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
+
+    # This sample logits processor gives infinite score to the i-th token,
+    # where i is the length of the input sequence.
+    # We therefore expect the output token sequence to be [0, 1, 2, ...]
+    def pick_ith(token_ids, logits):
+        logits[len(token_ids)] = float("inf")
+        return logits
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0,
+                                               logits_processors=[pick_ith]),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    sampler_output = sampler(embedding=None,
+                             hidden_states=input_tensor,
+                             sampling_metadata=sampling_metadata)
+    for _, sequence_output in enumerate(sampler_output):
+        for idx, nth_output in enumerate(sequence_output.samples):
+            assert nth_output.output_token == idx
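
The new logits-processor test drives the Sampler directly. A hedged sketch of the same idea through the public API instead (SamplingParams(logits_processors=...) and LLM.generate appear elsewhere in this commit; the model choice and output handling are illustrative):

from aphrodite import LLM, SamplingParams

def pick_ith(token_ids, logits):
    # Give infinite score to the token whose id equals the current length.
    logits[len(token_ids)] = float("inf")
    return logits

llm = LLM(model="EleutherAI/pythia-70m-deduped")
params = SamplingParams(temperature=0, max_tokens=4,
                        logits_processors=[pick_ith])
outputs = llm.generate(["Hello"], sampling_params=params)
print(outputs[0].outputs[0].text)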

+ 26 - 0
tests/test_regression.py

@@ -0,0 +1,26 @@
+"""Containing tests that check for regressions in Aphrodite's behavior.
+
+It should include tests that are reported by users and making sure they
+will never happen again.
+
+"""
+from aphrodite import LLM, SamplingParams
+
+
+def test_duplicated_ignored_sequence_group():
+
+    sampling_params = SamplingParams(temperature=0.01,
+                                     top_p=0.1,
+                                     max_tokens=256)
+    llm = LLM(model="EleutherAI/pythia-70m-deduped",
+              max_num_batched_tokens=4096,
+              tensor_parallel_size=1)
+    prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+
+    assert len(prompts) == len(outputs)
+
+
+if __name__ == "__main__":
+    import pytest
+    pytest.main([__file__])

+ 51 - 0
tests/worker/test_model_runner.py

@@ -0,0 +1,51 @@
+import random
+import torch
+
+from aphrodite.common.sequence import (SamplingParams, SequenceData,
+                                       SequenceGroupMetadata)
+from aphrodite.task_handler.model_runner import ModelRunner
+
+
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
+
+    batch_size = random.randint(1, 256)
+    prompt_lens = []
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        prompt_len = i % (model_runner.block_size - 1) + 1
+        prompt_lens.append(prompt_len)
+        seq_data = list(range(prompt_len))
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData(seq_data)},
+                sampling_params=SamplingParams(temperature=0),
+                block_tables={0: [1]},
+                persistent_data={},
+            ))
+
+    expected_selected_token_indices = []
+    selected_token_start_idx = 0
+    max_seq_len = max(prompt_lens)
+    for prompt_len in prompt_lens:
+        expected_selected_token_indices.append(selected_token_start_idx +
+                                               prompt_len - 1)
+        selected_token_start_idx += max_seq_len
+    input_tokens, input_positions, return_prompt_lens = (
+        model_runner._prepare_prompt(seq_group_metadata_list))
+    assert return_prompt_lens == prompt_lens
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
+    torch.testing.assert_close(input_tokens, input_positions)
+
+    actual = sampling_metadata.selected_token_indices
+    expected = torch.tensor(expected_selected_token_indices,
+                            device=actual.device,
+                            dtype=actual.dtype)
+    torch.testing.assert_close(actual, expected)
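
A small worked example of the selected-token-index arithmetic checked above: rows are padded to the longest prompt, so the last token of prompt i lands at flat index i * max_seq_len + prompt_len - 1 (numbers are illustrative):

prompt_lens = [3, 5]
max_seq_len = max(prompt_lens)
expected, start = [], 0
for prompt_len in prompt_lens:
    expected.append(start + prompt_len - 1)
    start += max_seq_len
assert expected == [2, 9]  # 0*5 + 3 - 1 and 1*5 + 5 - 1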