
ci: update & overhaul test units (#769)

* wip

* wip

* wip

* wip

* wip

* wip

* le kernel tests

* formatting

* lora tests

* metrics tests

* model loader tests

* model tests

* multimodal tests

* formatting

* add plugin test

* add prefix caching tests

* add prompt adapter tests

* add quantization tests

* update sampler tests

* add spec decode tests

* add tensorizer test

* update tensorizer test

* update tokenizer tests

* add weight loading test

* add worker tests
AlpinDale 3 months ago
parent
commit
c6c91edab7
100 files changed, 13411 additions, 903 deletions
  1. aphrodite/assets/__init__.py (+0, -0)
  2. aphrodite/assets/base.py (+53, -0)
  3. aphrodite/assets/image.py (+30, -0)
  4. aphrodite/common/config.py (+2, -2)
  5. aphrodite/connections.py (+167, -0)
  6. aphrodite/endpoints/api_server.py (+154, -0)
  7. pyproject.toml (+2, -2)
  8. tests/async_aphrodite/__init__.py (+0, -0)
  9. tests/async_aphrodite/api_server_async_aphrodite.py (+51, -0)
  10. tests/async_aphrodite/test_api_server.py (+116, -0)
  11. tests/async_aphrodite/test_async_aphrodite.py (+150, -0)
  12. tests/async_aphrodite/test_chat_template.py (+101, -0)
  13. tests/async_aphrodite/test_openapi_server_ray.py (+109, -0)
  14. tests/async_aphrodite/test_request_tracker.py (+68, -0)
  15. tests/basic_correctness/__init__.py (+0, -0)
  16. tests/basic_correctness/test_basic_correctness.py (+67, -0)
  17. tests/basic_correctness/test_chunked_prefill.py (+153, -0)
  18. tests/basic_correctness/test_cpu_offload.py (+6, -0)
  19. tests/basic_correctness/test_preemption.py (+271, -0)
  20. tests/compile/test_full_graph.py (+20, -0)
  21. tests/conftest.py (+584, -60)
  22. tests/core/__init__.py (+0, -0)
  23. tests/core/block/__init__.py (+0, -0)
  24. tests/core/block/conftest.py (+12, -0)
  25. tests/core/block/e2e/__init__.py (+0, -0)
  26. tests/core/block/e2e/conftest.py (+68, -0)
  27. tests/core/block/e2e/test_correctness.py (+558, -0)
  28. tests/core/block/e2e/test_correctness_sliding_window.py (+163, -0)
  29. tests/core/block/test_block_manager_v2.py (+445, -0)
  30. tests/core/block/test_block_table.py (+577, -0)
  31. tests/core/block/test_common.py (+42, -0)
  32. tests/core/block/test_cpu_gpu_block_allocator.py (+94, -0)
  33. tests/core/block/test_naive_block.py (+145, -0)
  34. tests/core/block/test_prefix_caching_block.py (+708, -0)
  35. tests/core/test_block_manager.py (+598, -0)
  36. tests/core/test_chunked_prefill_scheduler.py (+583, -0)
  37. tests/core/test_scheduler.py (+852, -0)
  38. tests/core/test_scheduler_encoder_decoder.py (+99, -0)
  39. tests/core/utils.py (+211, -0)
  40. tests/distributed/__init__.py (+0, -0)
  41. tests/distributed/test_basic_distributed_correctness.py (+81, -0)
  42. tests/distributed/test_basic_distributed_correctness_enc_dec.py (+102, -0)
  43. tests/distributed/test_chunked_prefill_distributed.py (+70, -0)
  44. tests/distributed/test_comm_ops.py (+200, -0)
  45. tests/distributed/test_custom_all_reduce.py (+115, -0)
  46. tests/distributed/test_distributed_oot.py (+6, -0)
  47. tests/distributed/test_multimodal_broadcast.py (+57, -0)
  48. tests/distributed/test_pipeline_parallel.py (+92, -0)
  49. tests/distributed/test_pipeline_partition.py (+34, -0)
  50. tests/distributed/test_pp_cudagraph.py (+30, -0)
  51. tests/distributed/test_pynccl.py (+243, -0)
  52. tests/distributed/test_same_node.py (+13, -0)
  53. tests/distributed/test_shm_broadcast.py (+35, -28)
  54. tests/distributed/test_utils.py (+35, -0)
  55. tests/endpoints/__init__.py (+0, -0)
  56. tests/endpoints/conftest.py (+89, -0)
  57. tests/endpoints/llm/__init__.py (+0, -0)
  58. tests/endpoints/llm/test_encode.py (+142, -0)
  59. tests/endpoints/llm/test_generate.py (+161, -0)
  60. tests/endpoints/llm/test_generate_multiple_loras.py (+67, -0)
  61. tests/endpoints/llm/test_guided_generate.py (+142, -0)
  62. tests/endpoints/openai/__init__.py (+0, -0)
  63. tests/endpoints/openai/test_audio.py (+355, -0)
  64. tests/endpoints/openai/test_basic.py (+52, -0)
  65. tests/endpoints/openai/test_chat.py (+842, -0)
  66. tests/endpoints/openai/test_completion.py (+832, -0)
  67. tests/endpoints/openai/test_embedding.py (+136, -0)
  68. tests/endpoints/openai/test_encoder_decoder.py (+50, -0)
  69. tests/endpoints/openai/test_guided_processors.py (+72, -0)
  70. tests/endpoints/openai/test_metrics.py (+179, -0)
  71. tests/endpoints/openai/test_models.py (+60, -0)
  72. tests/endpoints/openai/test_mp_api_server.py (+38, -0)
  73. tests/endpoints/openai/test_oot_registration.py (+42, -0)
  74. tests/endpoints/openai/test_return_tokens_as_ids.py (+83, -0)
  75. tests/endpoints/openai/test_run_batch.py (+102, -0)
  76. tests/endpoints/openai/test_serving_chat.py (+82, -0)
  77. tests/endpoints/openai/test_shutdown.py (+47, -0)
  78. tests/endpoints/openai/test_tokenization.py (+152, -0)
  79. tests/endpoints/openai/test_vision.py (+261, -0)
  80. tests/endpoints/test_llm_generate.py (+0, -30)
  81. tests/endpoints/test_openai_server.py (+0, -614)
  82. tests/endpoints/test_outlines.py (+0, -75)
  83. tests/engine/__init__.py (+0, -0)
  84. tests/engine/output_processor/__init__.py (+0, -0)
  85. tests/engine/output_processor/test_multi_step.py (+272, -0)
  86. tests/engine/output_processor/test_stop_checker.py (+85, -0)
  87. tests/engine/test_args.py (+24, -0)
  88. tests/engine/test_computed_prefix_block.py (+34, -0)
  89. tests/engine/test_computed_prefix_blocks.py (+34, -0)
  90. tests/engine/test_custom_executor.py (+90, -0)
  91. tests/engine/test_detokenization.py (+32, -0)
  92. tests/engine/test_detokenize.py (+0, -55)
  93. tests/engine/test_multiproc_workers.py (+177, -0)
  94. tests/engine/test_skip_tokenizer_init.py (+23, -0)
  95. tests/engine/test_stop_reason.py (+62, -0)
  96. tests/engine/test_stop_string.py (+112, -0)
  97. tests/kernels/__init__.py (+0, -0)
  98. tests/kernels/allclose_default.py (+18, -0)
  99. tests/kernels/conftest.py (+7, -37)
  100. tests/kernels/quant_utils.py (+83, -0)

+ 0 - 0
aphrodite/assets/__init__.py


+ 53 - 0
aphrodite/assets/base.py

@@ -0,0 +1,53 @@
+"""Assets for testing. vLLM conveniently has a bucket of public assets
+we can use."""
+import os
+from functools import lru_cache
+from pathlib import Path
+from typing import Optional
+
+from aphrodite.connections import global_http_connection
+
+
+def get_default_cache_root():
+    return os.getenv(
+        "XDG_CACHE_HOME",
+        os.path.join(os.path.expanduser("~"), ".cache"),
+    )
+
+vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
+APHRODITE_ASSETS_CACHE = os.path.expanduser(
+    os.getenv(
+        "APHRODITE_ASSETS_CACHE",
+        os.path.join(get_default_cache_root(), "aphrodite", "assets"),
+    ))
+APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT",
+                                              5))
+
+def get_cache_dir() -> Path:
+    """Get the path to the cache for storing downloaded assets."""
+    path = Path(APHRODITE_ASSETS_CACHE)
+    path.mkdir(parents=True, exist_ok=True)
+
+    return path
+
+
+@lru_cache
+def get_vllm_public_assets(filename: str,
+                           s3_prefix: Optional[str] = None) -> Path:
+    """
+    Download an asset file from ``s3://vllm-public-assets``
+    and return the path to the downloaded file.
+    """
+    asset_directory = get_cache_dir() / "vllm_public_assets"
+    asset_directory.mkdir(parents=True, exist_ok=True)
+
+    asset_path = asset_directory / filename
+    if not asset_path.exists():
+        if s3_prefix is not None:
+            filename = s3_prefix + "/" + filename
+        global_http_connection.download_file(
+            f"{vLLM_S3_BUCKET_URL}/{filename}",
+            asset_path,
+            timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
+
+    return asset_path
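A minimal usage sketch for the helper above (illustrative, not part of the diff): the first call downloads the file into the cache directory, later calls return the cached path. The filename and prefix are the ones used by aphrodite/assets/image.py below.

    from aphrodite.assets.base import get_vllm_public_assets

    # Downloads vision_model_images/stop_sign.jpg from the public vLLM S3
    # bucket on first use; subsequent calls hit the on-disk cache (and the
    # lru_cache memo) instead of the network.
    path = get_vllm_public_assets(filename="stop_sign.jpg",
                                  s3_prefix="vision_model_images")
    print(path)  # <APHRODITE_ASSETS_CACHE>/vllm_public_assets/stop_sign.jpg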

+ 30 - 0
aphrodite/assets/image.py

@@ -0,0 +1,30 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from PIL import Image
+
+from aphrodite.assets.base import get_vllm_public_assets
+
+VLM_IMAGES_DIR = "vision_model_images"
+
+
+@dataclass(frozen=True)
+class ImageAsset:
+    name: Literal["stop_sign", "cherry_blossom"]
+
+    @property
+    def pil_image(self) -> Image.Image:
+
+        image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
+                                            s3_prefix=VLM_IMAGES_DIR)
+        return Image.open(image_path)
+
+    @property
+    def image_embeds(self) -> torch.Tensor:
+        """
+        Image embeddings, only used for testing purposes with llava 1.5.
+        """
+        image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
+                                            s3_prefix=VLM_IMAGES_DIR)
+        return torch.load(image_path)
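A short sketch of how ImageAsset is meant to be consumed by the multimodal tests (illustrative usage, assuming network access to the public asset bucket):

    from aphrodite.assets.image import ImageAsset

    asset = ImageAsset("stop_sign")

    # PIL image backed by the downloaded .jpg file.
    image = asset.pil_image
    print(image.size)

    # Precomputed image embeddings (a .pt tensor), used for llava-1.5 tests.
    embeds = asset.image_embeds
    print(embeds.shape)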

+ 2 - 2
aphrodite/common/config.py

@@ -674,7 +674,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: float,
         cache_dtype: str,
-        is_attention_free: bool,
+        is_attention_free: bool = False,
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
@@ -1038,7 +1038,7 @@ class SchedulerConfig:
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
-                 is_attention_free: bool,
+                 is_attention_free: bool = False,
                  use_v2_block_manager: bool = False,
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,

+ 167 - 0
aphrodite/connections.py

@@ -0,0 +1,167 @@
+from pathlib import Path
+from typing import Mapping, MutableMapping, Optional
+from urllib.parse import urlparse
+
+import aiohttp
+import requests
+
+from aphrodite.version import __version__ as APHRODITE_VERSION
+
+
+class HTTPConnection:
+    """Helper class to send HTTP requests."""
+
+    def __init__(self, *, reuse_client: bool = True) -> None:
+        super().__init__()
+
+        self.reuse_client = reuse_client
+
+        self._sync_client: Optional[requests.Session] = None
+        self._async_client: Optional[aiohttp.ClientSession] = None
+
+    def get_sync_client(self) -> requests.Session:
+        if self._sync_client is None or not self.reuse_client:
+            self._sync_client = requests.Session()
+
+        return self._sync_client
+
+    # NOTE: We intentionally use an async function even though it is not
+    # required, so that the client is only accessible inside an async event loop
+    async def get_async_client(self) -> aiohttp.ClientSession:
+        if self._async_client is None or not self.reuse_client:
+            self._async_client = aiohttp.ClientSession()
+
+        return self._async_client
+
+    def _validate_http_url(self, url: str):
+        parsed_url = urlparse(url)
+
+        if parsed_url.scheme not in ("http", "https"):
+            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
+                             "must have scheme 'http' or 'https'.")
+
+    def _headers(self, **extras: str) -> MutableMapping[str, str]:
+        return {"User-Agent": f"Aphrodite/{APHRODITE_VERSION}", **extras}
+
+    def get_response(
+        self,
+        url: str,
+        *,
+        stream: bool = False,
+        timeout: Optional[float] = None,
+        extra_headers: Optional[Mapping[str, str]] = None,
+    ):
+        self._validate_http_url(url)
+
+        client = self.get_sync_client()
+        extra_headers = extra_headers or {}
+
+        return client.get(url,
+                          headers=self._headers(**extra_headers),
+                          stream=stream,
+                          timeout=timeout)
+
+    async def get_async_response(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+        extra_headers: Optional[Mapping[str, str]] = None,
+    ):
+        self._validate_http_url(url)
+
+        client = await self.get_async_client()
+        extra_headers = extra_headers or {}
+
+        return client.get(url,
+                          headers=self._headers(**extra_headers),
+                          timeout=timeout)
+
+    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.content
+
+    async def async_get_bytes(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> bytes:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.read()
+
+    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.text
+
+    async def async_get_text(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> str:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.text()
+
+    def get_json(self, url: str, *, timeout: Optional[float] = None) -> str:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.json()
+
+    async def async_get_json(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> str:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.json()
+
+    def download_file(
+        self,
+        url: str,
+        save_path: Path,
+        *,
+        timeout: Optional[float] = None,
+        chunk_size: int = 128,
+    ) -> Path:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            with save_path.open("wb") as f:
+                for chunk in r.iter_content(chunk_size):
+                    f.write(chunk)
+
+        return save_path
+
+    async def async_download_file(
+        self,
+        url: str,
+        save_path: Path,
+        *,
+        timeout: Optional[float] = None,
+        chunk_size: int = 128,
+    ) -> Path:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            with save_path.open("wb") as f:
+                async for chunk in r.content.iter_chunked(chunk_size):
+                    f.write(chunk)
+
+        return save_path
+
+
+global_http_connection = HTTPConnection()
+"""The global :class:`HTTPConnection` instance used by Aphrodite."""

+ 154 - 0
aphrodite/endpoints/api_server.py

@@ -0,0 +1,154 @@
+"""
+NOTE: This API server is used only for demonstrating usage of AsyncAphrodite
+and simple performance benchmarks. It is not intended for production use.
+For production use, we recommend using our OpenAI compatible server.
+We are also not going to accept PRs modifying this file, please
+change `aphrodite/endpoints/openai/api_server.py` instead.
+"""
+import asyncio
+import json
+import ssl
+from argparse import Namespace
+from typing import Any, AsyncGenerator, Optional
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.utils import (FlexibleArgumentParser,
+                                    iterate_with_cancellation, random_uuid)
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.server.launch import serve_http
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+app = FastAPI()
+engine = None
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    """Generate completion for the request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - stream: whether to stream the results or not.
+    - other fields: the sampling parameters (See `SamplingParams` for details).
+    """
+    request_dict = await request.json()
+    prompt = request_dict.pop("prompt")
+    stream = request_dict.pop("stream", False)
+    sampling_params = SamplingParams(**request_dict)
+    request_id = random_uuid()
+
+    assert engine is not None
+    results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = iterate_with_cancellation(
+        results_generator, is_cancelled=request.is_disconnected)
+
+    # Streaming case
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            text_outputs = [
+                prompt + output.text for output in request_output.outputs
+            ]
+            ret = {"text": text_outputs}
+            yield (json.dumps(ret) + "\0").encode("utf-8")
+
+    if stream:
+        return StreamingResponse(stream_results())
+
+    # Non-streaming case
+    final_output = None
+    try:
+        async for request_output in results_generator:
+            final_output = request_output
+    except asyncio.CancelledError:
+        return Response(status_code=499)
+
+    assert final_output is not None
+    prompt = final_output.prompt
+    text_outputs = [prompt + output.text for output in final_output.outputs]
+    ret = {"text": text_outputs}
+    return JSONResponse(ret)
+
+
+def build_app(args: Namespace) -> FastAPI:
+    global app
+
+    app.root_path = args.root_path
+    return app
+
+
+async def init_app(
+    args: Namespace,
+    llm_engine: Optional[AsyncAphrodite] = None,
+) -> FastAPI:
+    app = build_app(args)
+
+    global engine
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncAphrodite.from_engine_args(
+                  engine_args))
+
+    return app
+
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncAphrodite] = None,
+                     **uvicorn_kwargs: Any) -> None:
+
+    app = await init_app(args, llm_engine)
+
+    shutdown_task = await serve_http(
+        app,
+        engine=engine,
+        host=args.host,
+        port=args.port,
+        log_level=args.log_level,
+        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+        ssl_keyfile=args.ssl_keyfile,
+        ssl_certfile=args.ssl_certfile,
+        ssl_ca_certs=args.ssl_ca_certs,
+        ssl_cert_reqs=args.ssl_cert_reqs,
+        **uvicorn_kwargs,
+    )
+
+    await shutdown_task
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+    parser.add_argument("--port", type=int, default=2242)
+    parser.add_argument("--ssl-keyfile", type=str, default=None)
+    parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser.add_argument("--ssl-ca-certs",
+                        type=str,
+                        default=None,
+                        help="The CA certificates file")
+    parser.add_argument(
+        "--ssl-cert-reqs",
+        type=int,
+        default=int(ssl.CERT_NONE),
+        help="Whether client certificate is required (see stdlib ssl module's)"
+    )
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument("--log-level", type=str, default="debug")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    asyncio.run(run_server(args))
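To make the request format described in the /generate docstring concrete, an illustrative client call (assuming the server is running on the default port 2242; the test suite below issues essentially the same request):

    import requests

    response = requests.post(
        "http://localhost:2242/generate",
        json={
            "prompt": "Hello, my name is",
            "stream": False,
            # any remaining fields are forwarded to SamplingParams
            "max_tokens": 16,
            "temperature": 0.0,
        },
    )
    response.raise_for_status()
    print(response.json()["text"])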

+ 2 - 2
pyproject.toml

@@ -46,8 +46,8 @@ ignore = [
 ]
 
 [tool.codespell]
-ignore-words-list = "dout, te, indicies, ist, subtile, wit, whit, beseige, devlop"
-skip = "./tests/,./aphrodite/endpoints/kobold/klite.embd,./kernels/,./tests/benchmarks/sonnet.txt,./docs/"
+ignore-words-list = "dout, te, indicies, ist, subtile, wit, whit, beseige, devlop, serie, vor, holliday, discus, tennant, carin, parma, mor, slac, revered, chanel, sammon, nast, shepard, insead, bloc, clea"
+skip = "./tests/,./aphrodite/endpoints/kobold/klite.embd,./kernels/,./tests/benchmarks/sonnet.txt,./docs/,./tests/lora/data/long_context_test_data.py"
 
 [tool.isort]
 use_parentheses = true

+ 0 - 0
tests/async_aphrodite/__init__.py


+ 51 - 0
tests/async_aphrodite/api_server_async_aphrodite.py

@@ -0,0 +1,51 @@
+"""aphrodite.endpoints.api_server with some extra logging for testing."""
+from typing import Any, Dict, Iterable
+
+import uvicorn
+from fastapi.responses import JSONResponse, Response
+
+import aphrodite.endpoints.api_server
+from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+
+app = aphrodite.endpoints.api_server.app
+
+
+class AsyncAphroditeWithStats(AsyncAphrodite):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._num_aborts = 0
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        ids = list(request_ids)
+        self._num_aborts += len(ids)
+        await super()._engine_abort(ids)
+
+    def testing_stats(self) -> Dict[str, Any]:
+        return {"num_aborted_requests": self._num_aborts}
+
+
+@app.get("/stats")
+def stats() -> Response:
+    """Get the statistics of the engine."""
+    return JSONResponse(engine.testing_stats())
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncAphroditeWithStats.from_engine_args(engine_args)
+    aphrodite.endpoints.api_server.engine = engine
+    uvicorn.run(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level="debug",
+        timeout_keep_alive=aphrodite.endpoints.api_server.TIMEOUT_KEEP_ALIVE)
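The extra /stats route is what tests/async_aphrodite/test_api_server.py polls to check abort counts; an illustrative query (substitute the port the server was actually started with; this script defaults to 8000):

    import requests

    stats = requests.get("http://localhost:8000/stats").json()
    print(stats["num_aborted_requests"])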

+ 116 - 0
tests/async_aphrodite/test_api_server.py

@@ -0,0 +1,116 @@
+import os
+import subprocess
+import sys
+import time
+from multiprocessing import Pool
+from pathlib import Path
+
+import pytest
+import requests
+
+
+def _query_server(prompt: str, max_tokens: int = 5) -> dict:
+    response = requests.post("http://localhost:2242/generate",
+                             json={
+                                 "prompt": prompt,
+                                 "max_tokens": max_tokens,
+                                 "temperature": 0,
+                                 "ignore_eos": True
+                             })
+    response.raise_for_status()
+    return response.json()
+
+
+def _query_server_long(prompt: str) -> dict:
+    return _query_server(prompt, max_tokens=500)
+
+
+@pytest.fixture
+def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
+               worker_use_ray: bool):
+    script_path = Path(__file__).parent.joinpath(
+        "api_server_async_engine.py").absolute()
+    commands = [
+        sys.executable, "-u",
+        str(script_path), "--model", "facebook/opt-125m", "--host",
+        "127.0.0.1", "--tokenizer-pool-size",
+        str(tokenizer_pool_size)
+    ]
+
+    # Copy the environment variables and append
+    # `APHRODITE_ALLOW_ENGINE_USE_RAY=1` to prevent
+    # `--engine-use-ray` from raising an exception due to its deprecation
+    env_vars = os.environ.copy()
+    env_vars["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"
+
+    if engine_use_ray:
+        commands.append("--engine-use-ray")
+    if worker_use_ray:
+        commands.append("--worker-use-ray")
+    uvicorn_process = subprocess.Popen(commands, env=env_vars)
+    yield
+    uvicorn_process.terminate()
+
+
+@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
+@pytest.mark.parametrize("worker_use_ray", [False, True])
+@pytest.mark.parametrize("engine_use_ray", [False, True])
+def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
+                    engine_use_ray: bool):
+    """
+    Run the API server and test it.
+
+    We run both the server and requests in separate processes.
+
+    We test that the server can handle incoming requests, including
+    multiple requests at the same time, and that it can handle requests
+    being cancelled without crashing.
+    """
+    with Pool(32) as pool:
+        # Wait until the server is ready
+        prompts = ["warm up"] * 1
+        result = None
+        while not result:
+            try:
+                for r in pool.map(_query_server, prompts):
+                    result = r
+                    break
+            except requests.exceptions.ConnectionError:
+                time.sleep(1)
+
+        # Actual tests start here
+        # Try with 1 prompt
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+        num_aborted_requests = requests.get(
+            "http://localhost:2242/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests == 0
+
+        # Try with 100 prompts
+        prompts = ["test prompt"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result
+
+    with Pool(32) as pool:
+        # Cancel requests
+        prompts = ["canceled requests"] * 100
+        pool.map_async(_query_server_long, prompts)
+        time.sleep(0.01)
+        pool.terminate()
+        pool.join()
+
+        # check cancellation stats
+        # give it some time to update the stats
+        time.sleep(1)
+
+        num_aborted_requests = requests.get(
+            "http://localhost:2242/stats").json()["num_aborted_requests"]
+        assert num_aborted_requests > 0
+
+    # check that server still runs after cancellations
+    with Pool(32) as pool:
+        # Try with 100 prompts
+        prompts = ["test prompt after canceled"] * 100
+        for result in pool.map(_query_server, prompts):
+            assert result

+ 150 - 0
tests/async_aphrodite/test_async_aphrodite.py

@@ -0,0 +1,150 @@
+import asyncio
+import os
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from aphrodite import SamplingParams
+from aphrodite.common.config import ParallelConfig
+from aphrodite.engine.async_aphrodite import AsyncAphrodite, AsyncEngineArgs
+
+from ..utils import wait_for_gpu_memory_to_clear
+
+
+@dataclass
+class RequestOutput:
+    request_id: int
+    finished: bool = False
+
+
+class MockEngine:
+
+    def __init__(self):
+        self.step_calls = 0
+        self.add_request_calls = 0
+        self.abort_request_calls = 0
+        self.request_id = None
+        # Ugly, remove dependency when possible
+        self.parallel_config = ParallelConfig(1, 1, False)
+
+    async def step_async(self, virtual_engine):
+        # PP size is 1, ignore virtual engine
+        self.step_calls += 1
+        return [RequestOutput(
+            request_id=self.request_id)] if self.request_id else []
+
+    async def process_model_inputs_async(self, *args, **kwargs):
+        pass
+
+    async def stop_remote_worker_execution_loop_async(self):
+        pass
+
+    def generate(self, request_id):
+        self.request_id = request_id
+
+    def stop_generating(self):
+        self.request_id = None
+
+    def add_request(self, **kwargs):
+        del kwargs  # Unused
+        self.add_request_calls += 1
+        print(f'Request calls: {self.add_request_calls}')
+
+    async def add_request_async(self, **kwargs):
+        self.add_request_calls += 1
+        return
+
+    def abort_request(self, request_id):
+        del request_id  # Unused
+        self.abort_request_calls += 1
+
+    def has_unfinished_requests(self):
+        return self.request_id is not None
+
+    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
+        return self.request_id is not None
+
+
+class MockAsyncAphrodite(AsyncAphrodite):
+
+    def _init_engine(self, *args, **kwargs):
+        return MockEngine()
+
+
+@pytest.mark.asyncio
+async def test_new_requests_event():
+    engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
+    engine.start_background_loop()
+    await asyncio.sleep(0.01)
+    assert engine.engine.step_calls == 0
+
+    await engine.add_request("1", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 1
+    assert engine.engine.step_calls == 1
+
+    await engine.add_request("2", "", None)
+    engine.engine.generate("2")
+    await asyncio.sleep(0)
+    await asyncio.sleep(0)
+    await asyncio.sleep(0)
+    assert engine.engine.add_request_calls == 2
+    assert engine.engine.step_calls >= 2
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls >= 3
+    engine.engine.stop_generating()
+    await asyncio.sleep(0.001)
+    old_step_calls = engine.engine.step_calls
+    await asyncio.sleep(0.001)
+    assert engine.engine.step_calls == old_step_calls
+
+    await engine.add_request("3", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == old_step_calls + 1
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == old_step_calls + 1
+
+    # Allow deprecated engine_use_ray to not raise exception
+    os.environ["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"
+
+    engine = MockAsyncAphrodite(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_model_config() is not None
+    assert engine.get_tokenizer() is not None
+    assert engine.get_decoding_config() is not None
+
+    os.environ.pop("APHRODITE_ALLOW_ENGINE_USE_RAY")
+
+
+def test_asyncio_run():
+    wait_for_gpu_memory_to_clear(
+        devices=list(range(torch.cuda.device_count())),
+        threshold_bytes=2 * 2**30,
+        timeout_s=60,
+    )
+
+    engine = AsyncAphrodite.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m"))
+
+    async def run(prompt: str):
+        sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=32,
+        )
+
+        async for output in engine.generate(prompt,
+                                            sampling_params,
+                                            request_id=prompt):
+            final_output = output
+        return final_output
+
+    async def generate():
+        return await asyncio.gather(
+            run("test0"),
+            run("test1"),
+        )
+
+    results = asyncio.run(generate())
+    assert len(results) == 2

+ 101 - 0
tests/async_aphrodite/test_chat_template.py

@@ -0,0 +1,101 @@
+import pytest
+
+from aphrodite.endpoints.chat_utils import (apply_chat_template,
+                                            load_chat_template)
+from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+from ..utils import APHRODITE_PATH
+
+chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
+assert chatml_jinja_path.exists()
+
+# Define models, templates, and their corresponding expected outputs
+MODEL_TEMPLATE_GENERATON_OUTPUT = [
+    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of""")
+]
+
+TEST_MESSAGES = [
+    {
+        'role': 'user',
+        'content': 'Hello'
+    },
+    {
+        'role': 'assistant',
+        'content': 'Hi there!'
+    },
+    {
+        'role': 'user',
+        'content': 'What is the capital of'
+    },
+]
+
+
+def test_load_chat_template():
+    # Testing chatml template
+    template_content = load_chat_template(chat_template=chatml_jinja_path)
+
+    # Test assertions
+    assert template_content is not None
+    # Hard-coded value for chatml.jinja
+    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
+
+
+def test_no_load_chat_template_filelike():
+    # Testing chatml template
+    template = "../../examples/does_not_exist"
+
+    with pytest.raises(ValueError, match="looks like a file path"):
+        load_chat_template(chat_template=template)
+
+
+def test_no_load_chat_template_literallike():
+    # Testing chatml template
+    template = "{{ messages }}"
+
+    template_content = load_chat_template(chat_template=template)
+
+    assert template_content == template
+
+
+@pytest.mark.parametrize(
+    "model,template,add_generation_prompt,expected_output",
+    MODEL_TEMPLATE_GENERATON_OUTPUT)
+def test_get_gen_prompt(model, template, add_generation_prompt,
+                        expected_output):
+    # Initialize the tokenizer
+    tokenizer = get_tokenizer(tokenizer_name=model)
+    template_content = load_chat_template(chat_template=template)
+
+    # Create a mock request object using keyword arguments
+    mock_request = ChatCompletionRequest(
+        model=model,
+        messages=TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt)
+
+    # Call the function and get the result
+    result = apply_chat_template(
+        tokenizer,
+        conversation=mock_request.messages,
+        chat_template=mock_request.chat_template or template_content,
+        add_generation_prompt=mock_request.add_generation_prompt,
+    )
+
+    # Test assertion
+    assert result == expected_output, (
+        f"The generated prompt does not match the expected output for "
+        f"model {model} and template {template}")

+ 109 - 0
tests/async_aphrodite/test_openapi_server_ray.py

@@ -0,0 +1,109 @@
+import openai  # use the official client for correctness check
+import pytest
+
+from ..utils import APHRODITE_PATH, RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
+assert chatml_jinja_path.exists()
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--max-model-len",
+        "2048",
+        "--enforce-eager",
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
+    # Allow `--engine-use-ray`, otherwise launching the server throws
+    # an error because it tries to use a deprecated feature
+    env_dict = {"APHRODITE_ALLOW_ENGINE_USE_RAY": "1"}
+    with RemoteOpenAIServer(MODEL_NAME, args,
+                            env_dict=env_dict) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+async def test_check_models(client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+
+
+@pytest.mark.asyncio
+async def test_single_completion(client: openai.AsyncOpenAI):
+    completion = await client.completions.create(model=MODEL_NAME,
+                                                 prompt="Hello, my name is",
+                                                 max_tokens=5,
+                                                 temperature=0.0)
+
+    assert completion.id is not None
+    assert len(completion.choices) == 1
+    assert len(completion.choices[0].text) >= 5
+    assert completion.choices[0].finish_reason == "length"
+    assert completion.usage == openai.types.CompletionUsage(
+        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 5
+
+
+@pytest.mark.asyncio
+async def test_single_chat_session(client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert chat_completion.id is not None
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0

+ 68 - 0
tests/async_aphrodite/test_request_tracker.py

@@ -0,0 +1,68 @@
+import pytest
+
+from aphrodite.common.outputs import RequestOutput
+from aphrodite.engine.async_aphrodite import RequestTracker
+
+
+@pytest.mark.asyncio
+async def test_request_tracker():
+    tracker = RequestTracker()
+    stream_1 = tracker.add_request("1")
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert not tracker.new_requests_event.is_set()
+    assert len(new) == 1
+    assert new[0]["request_id"] == "1"
+    assert not aborted
+    assert not stream_1.finished
+
+    stream_2 = tracker.add_request("2")
+    stream_3 = tracker.add_request("3")
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert not tracker.new_requests_event.is_set()
+    assert len(new) == 2
+    assert new[0]["request_id"] == "2"
+    assert new[1]["request_id"] == "3"
+    assert not aborted
+    assert not stream_2.finished
+    assert not stream_3.finished
+
+    # request_ids must be unique
+    with pytest.raises(KeyError):
+        tracker.add_request("1")
+    assert not tracker.new_requests_event.is_set()
+
+    tracker.abort_request("1")
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert len(aborted) == 1
+    assert "1" in aborted
+    assert not new
+    assert stream_1.finished
+
+    stream_4 = tracker.add_request("4")
+    tracker.abort_request("4")
+    assert tracker.new_requests_event.is_set()
+    await tracker.wait_for_new_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
+    # aborted new requests will cancel each other out -
+    # there's no need for them to propagate into the
+    # engine
+    assert not aborted
+    assert not new
+    assert stream_4.finished
+
+    stream_5 = tracker.add_request("5")
+    assert tracker.new_requests_event.is_set()
+    tracker.process_request_output(
+        RequestOutput("2", "output", [], [], [], finished=True))
+    await tracker.wait_for_new_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert not tracker.new_requests_event.is_set()
+    assert not aborted
+    assert len(new) == 1
+    assert new[0]["request_id"] == "5"
+    assert stream_2.finished
+    assert not stream_5.finished

+ 0 - 0
tests/basic_correctness/__init__.py


+ 67 - 0
tests/basic_correctness/test_basic_correctness.py

@@ -0,0 +1,67 @@
+"""Compare the short outputs of HF and Aphrodite when using greedy sampling.
+
+Run `pytest tests/basic_correctness/test_basic_correctness.py`.
+"""
+import os
+import weakref
+
+import pytest
+
+from aphrodite import LLM
+from aphrodite.common.utils import is_hip
+
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+
+
+def test_aphrodite_gc_ed():
+    """Verify aphrodite instance is GC'ed when it is deleted"""
+    llm = LLM("facebook/opt-125m")
+    weak_llm = weakref.ref(llm)
+    del llm
+    # If there's any circular reference to aphrodite, this fails
+    # because llm instance is not GC'ed.
+    assert weak_llm() is None
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    enforce_eager: bool,
+) -> None:
+
+    if backend == "FLASHINFER" and is_hip():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    os.environ["APHRODITE_ATTENTION_BACKEND"] = backend
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with aphrodite_runner(model,
+                     dtype=dtype,
+                     enforce_eager=enforce_eager,
+                     gpu_memory_utilization=0.7) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
+                                                            max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )

+ 153 - 0
tests/basic_correctness/test_chunked_prefill.py

@@ -0,0 +1,153 @@
+"""Compare the outputs of HF and Aphrodite when using greedy sampling.
+
+It tests chunked prefill. Chunked prefill can be enabled by
+enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
+prefill requests are chunked.
+
+Run `pytest tests/models/test_chunked_prefill.py`.
+"""
+
+import pytest
+
+from ..models.utils import check_logprobs_close, check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+E5M2_KV_MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-chat-hf",
+]
+E4M3_KV_MODELS = [
+    "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
+    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
+]
+KV_CACHE_QUANTIZATION_PATHS = {
+    "meta-llama/Llama-2-7b-chat-hf":
+    "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
+}
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only when testing locally.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Checks that greedy decode outputs match exactly between the HuggingFace
+    model and the Aphrodite runner when chunked prefill is enabled.
+    """
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=True,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+    ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
+                                                            max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )
+
+
+@pytest.mark.parametrize("kv_cache_dtype,model",
+                         [("fp8_e5m2", m)
+                          for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
+                                                      for m in E4M3_KV_MODELS])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only when testing locally.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_models_with_fp8_kv_cache(
+    aphrodite_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    model: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Only checks that log probs match between the chunked-prefill and
+    non-chunked-prefill versions of the Aphrodite model runner.
+
+    This test is used when there is a discrepancy in kernels or
+    numerics (e.g. when using lower-precision types like FP8).
+    """
+    NUM_LOG_PROBS = 8
+
+    if model == "facebook/opt-125m":
+        pytest.skip(
+            "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
+        )
+
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    extra_kwargs = {}
+    if model in KV_CACHE_QUANTIZATION_PATHS:
+        extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
+            model]
+
+    with aphrodite_runner(
+            model,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+            kv_cache_dtype=kv_cache_dtype,
+            **extra_kwargs,
+    ) as aphrodite_model:
+        no_chunked_prefill_outputs = aphrodite_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    with aphrodite_runner(
+            model,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=True,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+            kv_cache_dtype=kv_cache_dtype,
+            **extra_kwargs,
+    ) as aphrodite_model:
+        chunked_prefill_outputs = aphrodite_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=no_chunked_prefill_outputs,
+        outputs_1_lst=chunked_prefill_outputs,
+        name_0="no_chunked_prefill",
+        name_1="chunked_prefill",
+    )

+ 6 - 0
tests/basic_correctness/test_cpu_offload.py

@@ -0,0 +1,6 @@
+from ..utils import compare_two_settings
+
+
+def test_cpu_offload():
+    compare_two_settings("meta-llama/Llama-2-7b-hf", [],
+                         ["--cpu-offload-gb", "4"])

+ 271 - 0
tests/basic_correctness/test_preemption.py

@@ -0,0 +1,271 @@
+"""Compare the short outputs of HF and Aphrodite when using greedy sampling.
+
+APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this
+test.
+
+Run `APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
+pytest tests/basic_correctness/test_preemption.py`.
+"""
+import pytest
+from prometheus_client import REGISTRY
+
+from aphrodite import SamplingParams
+from aphrodite.processing.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
+                                            ENABLE_ARTIFICIAL_PREEMPT)
+
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
+assert ENABLE_ARTIFICIAL_PREEMPT is True, (
+    "Use an env var APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
+    "`APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
+    "tests/basic_correctness/test_preemption.py`")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("chunked_prefill_token_size", [16])
+def test_chunked_prefill_recompute(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+) -> None:
+    """Ensure that chunked prefill works with preemption."""
+    max_num_seqs = min(chunked_prefill_token_size, 256)
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=enable_chunked_prefill,
+            max_num_seqs=max_num_seqs,
+    ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
+                                                            max_tokens)
+        assert (
+            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
+                < ARTIFICIAL_PREEMPTION_MAX_CNT)
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
+        assert hf_output_str == aphrodite_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: "
+            f"{aphrodite_output_str!r}")
+        assert hf_output_ids == aphrodite_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_preemption(
+    caplog_aphrodite,
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """By default, recompute preemption is enabled"""
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            disable_log_stats=False,
+    ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
+                                                            max_tokens)
+        assert (
+            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
+                < ARTIFICIAL_PREEMPTION_MAX_CNT)
+        total_preemption = (
+            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )
+
+    assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
+            "is not enough KV cache space." in caplog_aphrodite.text)
+    # Ensure the count bucket of request-level histogram metrics matches
+    # the number of requests as a simple sanity check to ensure metrics are
+    # generated
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "aphrodite:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("beam_width", [4])
+def test_swap(
+    caplog_aphrodite,
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Use beam search enables swapping."""
+    example_prompts = example_prompts[:1]
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
+                                                   max_tokens)
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            swap_space=10,
+            disable_log_stats=False,
+    ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_beam_search(
+            example_prompts, beam_width, max_tokens)
+        assert (
+            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
+                < ARTIFICIAL_PREEMPTION_MAX_CNT)
+        total_preemption = (
+            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
+
+    for i in range(len(example_prompts)):
+        hf_output_ids, _ = hf_outputs[i]
+        aphrodite_output_ids, _ = aphrodite_outputs[i]
+        assert len(hf_output_ids) == len(aphrodite_output_ids)
+        for j in range(len(hf_output_ids)):
+            assert hf_output_ids[j] == aphrodite_output_ids[j], (
+                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
+                f"Aphrodite: {aphrodite_output_ids}")
+
+    assert ("is preempted by PreemptionMode.SWAP mode because there "
+            "is not enough KV cache space." in caplog_aphrodite.text)
+    # Ensure the count bucket of request-level histogram metrics matches
+    # the number of requests as a simple sanity check to ensure metrics are
+    # generated
+    preemption_metrics = None
+    for m in REGISTRY.collect():
+        if m.name == "aphrodite:num_preemptions":
+            preemption_metrics = m
+    assert preemption_metrics is not None
+    total_recorded_preemption = 0
+    for sample in preemption_metrics.samples:
+        total_recorded_preemption += sample.value
+    assert total_preemption == total_recorded_preemption
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("beam_width", [4])
+def test_swap_infeasible(
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Verify infeasible swap request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    example_prompts = example_prompts[:1]
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            swap_space=10,
+            block_size=BLOCK_SIZE,
+            # Since beam search has more than 1 sequence, prefill +
+            # decode blocks are not enough to finish.
+            num_gpu_blocks_override=prefill_blocks + decode_blocks,
+            max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+    ) as aphrodite_model:
+        sampling_params = SamplingParams(n=beam_width,
+                                         use_beam_search=True,
+                                         temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         ignore_eos=True)
+        req_outputs = aphrodite_model.model.generate(
+            example_prompts,
+            sampling_params=sampling_params,
+        )
+        assert (
+            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
+                < ARTIFICIAL_PREEMPTION_MAX_CNT)
+
+    # Verify the request is ignored and does not hang.
+    assert req_outputs[0].outputs[0].finish_reason == "length"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_preemption_infeasible(
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """Verify infeasible preemption request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
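+    # With max_tokens=96 and BLOCK_SIZE=16, the override below grants only
+    # 2 + 6 // 2 = 5 GPU blocks (80 tokens of KV), so even a single sequence
+    # cannot finish and must be ignored rather than retried forever.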
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            block_size=BLOCK_SIZE,
+            # Not enough gpu blocks to complete a single sequence.
+            # preemption should happen, and the sequence should be
+            # ignored instead of hanging forever.
+            num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
+            max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+    ) as aphrodite_model:
+        sampling_params = SamplingParams(max_tokens=max_tokens,
+                                         ignore_eos=True)
+        req_outputs = aphrodite_model.model.generate(
+            example_prompts,
+            sampling_params=sampling_params,
+        )
+
+        assert (
+            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
+                < ARTIFICIAL_PREEMPTION_MAX_CNT)
+
+    # Verify the requests are ignored and do not hang.
+    for req_output in req_outputs:
+        outputs = req_output.outputs
+        assert len(outputs) == 1
+        assert outputs[0].finish_reason == "length"

+ 20 - 0
tests/compile/test_full_graph.py

@@ -0,0 +1,20 @@
+import os
+
+import pytest
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_full_graph(model):
+    # make sure these models can be captured in full graph mode
+    os.environ["APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+
+    from aphrodite import LLM, SamplingParams
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0)
+    llm = LLM(model=model)
+    llm.generate(prompts, sampling_params)

+ 584 - 60
tests/conftest.py

@@ -1,12 +1,38 @@
+import contextlib
+import gc
+import json
 import os
-from typing import List, Optional, Tuple
+import sys
+import tempfile
+from collections import UserList
+from enum import Enum
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
+                    TypeVar, Union)
 
 import pytest
 import torch
-from transformers import AutoModelForCausalLM
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub import snapshot_download
+from loguru import logger
+from PIL import Image
+from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+                          BatchFeature)
 
 from aphrodite import LLM, SamplingParams
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.assets.image import ImageAsset
+from aphrodite.common.config import TokenizerPoolConfig
+from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import (STR_DTYPE_TO_TORCH_DTYPE,
+                                    cuda_device_count_stateless, identity,
+                                    is_cpu)
+from aphrodite.connections import global_http_connection
+from aphrodite.distributed import (destroy_distributed_environment,
+                                   destroy_model_parallel)
+from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                              to_enc_dec_tuple_list, zip_enc_dec_prompts)
 
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -19,6 +45,80 @@ def _read_prompts(filename: str) -> List[str]:
         return prompts
 
 
+class _ImageAssetPrompts(TypedDict):
+    stop_sign: str
+    cherry_blossom: str
+
+
+if sys.version_info < (3, 9):
+    # UserList cannot be subscripted
+    class _ImageAssetsBase(UserList):
+        pass
+else:
+
+    class _ImageAssetsBase(UserList[ImageAsset]):
+        pass
+
+
+class _ImageAssets(_ImageAssetsBase):
+
+    def __init__(self) -> None:
+        super().__init__([
+            ImageAsset("stop_sign"),
+            ImageAsset("cherry_blossom"),
+        ])
+
+    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
+        """
+        Convenience method to define the prompt for each test image.
+
+        The order of the returned prompts matches the order of the
+        assets when iterating through this object.
+        """
+        return [prompts["stop_sign"], prompts["cherry_blossom"]]
+
+
+IMAGE_ASSETS = _ImageAssets()
+"""Singleton instance of :class:`_ImageAssets`."""
+
+
+@pytest.fixture(autouse=True)
+def init_test_http_connection():
+    # pytest_asyncio may use a different event loop per test
+    # so we need to make sure the async client is created anew
+    global_http_connection.reuse_client = False
+
+
+def cleanup():
+    destroy_model_parallel()
+    destroy_distributed_environment()
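+    # destroy_process_group() raises an AssertionError when no default group
+    # was ever initialized, so tolerate that for tests that never used
+    # torch.distributed.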
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+    gc.collect()
+    if not is_cpu():
+        torch.cuda.empty_cache()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    """Allow subdirectories to skip global cleanup by overriding this fixture.
+    This can provide a ~10x speedup for non-GPU unit tests since they don't need
+    to initialize torch.
+    """
+
+    if request.node.get_closest_marker("skip_global_cleanup"):
+        return False
+
+    return True
+
+
+@pytest.fixture(autouse=True)
+def cleanup_fixture(should_do_global_cleanup_after_test: bool):
+    yield
+    if should_do_global_cleanup_after_test:
+        cleanup()
+
+
 @pytest.fixture
 def example_prompts() -> List[str]:
     prompts = []
@@ -27,6 +127,46 @@ def example_prompts() -> List[str]:
     return prompts
 
 
+class DecoderPromptType(Enum):
+    """For encoder/decoder models only."""
+    CUSTOM = 1
+    NONE = 2
+    EMPTY_STR = 3
+
+
+@pytest.fixture
+def example_encoder_decoder_prompts(
+) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
+    '''
+    Returns a dict mapping each :class:`DecoderPromptType` to a list of
+    (encoder prompt, decoder prompt) pairs, where each pair of same-index
+    entries corresponds to one test case.
+
+    The decoder prompts are, per type:
+
+    * CUSTOM: the encoder prompts in reverse order
+    * EMPTY_STR: an empty string for every entry
+    * NONE: ``None`` for every entry
+    '''
+
+    encoder_prompts = []
+    for filename in _TEST_PROMPTS:
+        encoder_prompts += _read_prompts(filename)
+
+    custom_decoder_prompts = encoder_prompts[::-1]
+    empty_str_decoder_prompts = [""] * len(encoder_prompts)
+    none_decoder_prompts = [None] * len(encoder_prompts)
+
+    # NONE decoder prompt type
+    return {
+        DecoderPromptType.NONE:
+        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
+        DecoderPromptType.EMPTY_STR:
+        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
+        DecoderPromptType.CUSTOM:
+        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
+    }
+
+
 @pytest.fixture
 def example_long_prompts() -> List[str]:
     prompts = []
@@ -35,46 +175,113 @@ def example_long_prompts() -> List[str]:
     return prompts
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
+@pytest.fixture(scope="session")
+def image_assets() -> _ImageAssets:
+    return IMAGE_ASSETS
+
+
+_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
 
 
 class HfRunner:
 
+    def wrap_device(self, input: _T) -> _T:
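+        # Move HF models / encoded inputs onto CUDA when a GPU is available;
+        # fall back to CPU for CPU-only builds.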
+        if not is_cpu():
+            return input.to("cuda")
+        else:
+            return input.to("cpu")
+
     def __init__(
         self,
         model_name: str,
-        tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        *,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        is_embedding_model: bool = False,
+        is_vision_model: bool = False,
+        is_encoder_decoder_model: bool = False,
+        postprocess_inputs: Callable[[BatchEncoding],
+                                     BatchEncoding] = identity,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
-        self.model = AutoModelForCausalLM.from_pretrained(
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
+        self.model_name = model_name
+
+        if is_embedding_model:
+            # Lazy init required for AMD CI
+            from sentence_transformers import SentenceTransformer
+            self.model = self.wrap_device(
+                SentenceTransformer(
+                    model_name,
+                    device="cpu",
+                ).to(dtype=torch_dtype))
+        else:
+            if is_vision_model:
+                auto_cls = AutoModelForVision2Seq
+            elif is_encoder_decoder_model:
+                auto_cls = AutoModelForSeq2SeqLM
+            else:
+                auto_cls = AutoModelForCausalLM
+
+            model_kwargs = model_kwargs if model_kwargs is not None else {}
+            self.model = self.wrap_device(
+                auto_cls.from_pretrained(
+                    model_name,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=True,
+                    **model_kwargs,
+                ))
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
             torch_dtype=torch_dtype,
             trust_remote_code=True,
-        ).cuda()
-        if tokenizer_name is None:
-            tokenizer_name = model_name
-        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
+        )
+
+        try:
+            # don't put this import at the top level
+            # it will call torch.cuda.device_count()
+            from transformers import AutoProcessor  # noqa: F401
+            self.processor = AutoProcessor.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            )
+        except Exception as exc:
+            logger.warning(
+                "Unable to auto-load HuggingFace processor for model (%s). "
+                "Using tokenizer instead. Reason: %s", model_name, exc)
+            self.processor = self.tokenizer
+
+        self.postprocess_inputs = postprocess_inputs
 
     def generate(
         self,
         prompts: List[str],
-        **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
-        for prompt in prompts:
-            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        if images:
+            assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
+        for i, prompt in enumerate(prompts):
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
+
             output_ids = self.model.generate(
-                input_ids.cuda(),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -87,21 +294,24 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
-                                max_new_tokens=max_tokens)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+                                max_new_tokens=max_tokens,
+                                images=images,
+                                **kwargs)
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -121,19 +331,31 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
-        all_logprobs = []
-        for prompt in prompts:
-            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        all_logprobs: List[List[torch.Tensor]] = []
+        for i, prompt in enumerate(prompts):
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
+
             output = self.model.generate(
-                input_ids.cuda(),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
                 output_hidden_states=True,
                 return_dict_in_generate=True,
+                **kwargs,
             )
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for hidden_states in output.hidden_states:
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -143,15 +365,162 @@ class HfRunner:
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
 
+    def _hidden_states_to_logprobs(
+        self,
+        hidden_states,
+        num_logprobs,
+    ) -> Tuple[List[Dict[int, float]], int]:
+        seq_logprobs: List[torch.Tensor] = []
+        output_len = len(hidden_states)
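+        # HF generate() emits one hidden-states tuple per decoding step, so
+        # the number of entries equals the generated output length.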
+        for hidden_state in hidden_states:
+            last_hidden_states = hidden_state[-1][0]
+            logits = torch.matmul(
+                last_hidden_states,
+                self.model.get_output_embeddings().weight.t(),
+            )
+            if getattr(self.model.get_output_embeddings(), "bias",
+                       None) is not None:
+                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+            seq_logprobs.append(logprobs)
+
+        # convert to dict
+        seq_logprobs_lst: List[Dict[int, float]] = []
+        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+            # drop prompt logprobs
+            if tok_idx == 0:
+                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+            topk = tok_logprobs.topk(num_logprobs)
+
+            tok_logprobs_dct = {}
+            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                tok_logprobs_dct[token_id.item()] = logprob.item()
+
+            seq_logprobs_lst.append(tok_logprobs_dct)
+
+        return (
+            seq_logprobs_lst,
+            output_len,
+        )
 
-@pytest.fixture
+    def generate_greedy_logprobs_limit(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+        images: Optional[List[Image.Image]] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
+
+        for i, prompt in enumerate(prompts):
+            processor_kwargs: Dict[str, Any] = {
+                "text": prompt,
+                "return_tensors": "pt",
+            }
+            if images is not None and images[i] is not None:
+                processor_kwargs["images"] = images[i]
+
+            inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
+
+            output = self.model.generate(
+                **self.wrap_device(inputs),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+                **kwargs,
+            )
+
+            (
+                seq_logprobs_lst,
+                output_len,
+            ) = self._hidden_states_to_logprobs(output.hidden_states,
+                                                num_logprobs)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_len = len(seq_logprobs_lst)
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
+    def generate_encoder_decoder_greedy_logprobs_limit(
+        self,
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+        max_tokens: int,
+        num_logprobs: int,
+        **kwargs: Any,
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        '''
+        Greedy logprobs generation for Aphrodite encoder/decoder models
+        '''
+
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
+
+        for (encoder_prompt,
+             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
+            encoder_input_ids = self.wrap_device(
+                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
+            decoder_input_ids = (
+                None if decoder_prompt is None else self.wrap_device(
+                    self.tokenizer(decoder_prompt,
+                                   return_tensors="pt").input_ids))
+
+            output = self.model.generate(
+                encoder_input_ids,
+                decoder_input_ids=decoder_input_ids,
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+                **kwargs,
+            )
+
+            (
+                seq_logprobs_lst,
+                output_len,
+            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
+                                                num_logprobs)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
+    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
+        return self.model.encode(prompts)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        del self.model
+        cleanup()
+
+
+@pytest.fixture(scope="session")
 def hf_runner():
     return HfRunner
 
@@ -162,9 +531,16 @@ class AphroditeRunner:
         self,
         model_name: str,
         tokenizer_name: Optional[str] = None,
+        # Use a smaller max model length; otherwise bigger models cannot run
+        # due to the KV cache size limit.
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
+        block_size: int = 16,
+        enable_chunked_prefill: bool = False,
+        swap_space: int = 4,
+        enforce_eager: Optional[bool] = False,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -172,9 +548,13 @@ class AphroditeRunner:
             tokenizer=tokenizer_name,
             trust_remote_code=True,
             dtype=dtype,
-            swap_space=0,
+            swap_space=swap_space,
+            enforce_eager=enforce_eager,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
+            max_model_len=max_model_len,
+            block_size=block_size,
+            enable_chunked_prefill=enable_chunked_prefill,
             **kwargs,
         )
 
@@ -182,48 +562,90 @@ class AphroditeRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
-        req_outputs = self.model.generate(prompts,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        if images is not None:
+            assert len(prompts) == len(images)
+
+        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
+        if images is not None:
+            for i, image in enumerate(images):
+                inputs[i]["multi_modal_data"] = {"image": image}
+
+        req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
-                output_ids = sample.token_ids
+                output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
                 req_sample_output_strs.append(prompt_str + output_str)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
-    def generate_w_logprobs(
+    def _final_steps_generate_w_logprobs(
         self,
-        prompts: List[str],
-        sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
-        assert sampling_params.logprobs is not None
-
-        req_outputs = self.model.generate(prompts,
-                                          sampling_params=sampling_params)
-        outputs = []
+        req_outputs: List[RequestOutput],
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
-                output_ids = sample.token_ids
+                output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
             outputs.append((output_ids, output_str, output_logprobs))
         return outputs
 
+    def generate_w_logprobs(
+        self,
+        prompts: List[str],
+        sampling_params: SamplingParams,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        assert sampling_params.logprobs is not None
+
+        if images is not None:
+            assert len(prompts) == len(images)
+
+        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
+        if images is not None:
+            for i, image in enumerate(images):
+                inputs[i]["multi_modal_data"] = {"image": image}
+
+        req_outputs = self.model.generate(inputs,
+                                          sampling_params=sampling_params)
+        return self._final_steps_generate_w_logprobs(req_outputs)
+
+    def generate_encoder_decoder_w_logprobs(
+        self,
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+        sampling_params: SamplingParams,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        '''
+        Logprobs generation for Aphrodite encoder/decoder models
+        '''
+
+        assert sampling_params.logprobs is not None
+        req_outputs = self.model.generate(encoder_decoder_prompts,
+                                          sampling_params=sampling_params)
+        return self._final_steps_generate_w_logprobs(req_outputs)
+
     def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        outputs = self.generate(prompts, greedy_params)
+        outputs = self.generate(prompts, greedy_params, images=images)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
 
@@ -232,11 +654,37 @@ class AphroditeRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                max_tokens=max_tokens,
+                                                logprobs=num_logprobs,
+                                                stop_token_ids=stop_token_ids)
+        outputs = self.generate_w_logprobs(prompts,
+                                           greedy_logprobs_params,
+                                           images=images)
+
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
+    def generate_encoder_decoder_greedy_logprobs(
+        self,
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        '''
+        Greedy logprobs generation for Aphrodite encoder/decoder models
+        '''
         greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                use_beam_search=False,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
-        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
+
+        outputs = self.generate_encoder_decoder_w_logprobs(
+            encoder_decoder_prompts, greedy_logprobs_params)
 
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
@@ -246,7 +694,7 @@ class AphroditeRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,
@@ -254,7 +702,83 @@ class AphroditeRunner:
         outputs = self.generate(prompts, beam_search_params)
         return outputs
 
+    def encode(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
 
-@pytest.fixture
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        del self.model
+        cleanup()
+
+
+@pytest.fixture(scope="session")
 def aphrodite_runner():
     return AphroditeRunner
+
+
+def get_tokenizer_pool_config(tokenizer_group_type):
+    if tokenizer_group_type is None:
+        return None
+    if tokenizer_group_type == "ray":
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type="ray",
+                                   extra_config={})
+    if isinstance(tokenizer_group_type, type):
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type=tokenizer_group_type,
+                                   extra_config={})
+    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
+
+
+@pytest.fixture()
+def temporary_enable_log_propagate():
+    import logging
+    logger = logging.getLogger("aphrodite")
+    logger.propagate = True
+    yield
+    logger.propagate = False
+
+
+@pytest.fixture()
+def caplog_aphrodite(temporary_enable_log_propagate, caplog):
+    # To capture aphrodite log, we should enable propagate=True temporarily
+    # because caplog depends on logs propagated to the root logger.
+    yield caplog
+
+
+@pytest.fixture(scope="session")
+def num_gpus_available():
+    """Get number of GPUs without initializing the CUDA context
+    in current process."""
+
+    return cuda_device_count_stateless()
+
+
+temp_dir = tempfile.gettempdir()
+_dummy_path = os.path.join(temp_dir, "dummy_opt")
+
+
+@pytest.fixture
+def dummy_opt_path():
+    json_path = os.path.join(_dummy_path, "config.json")
+    if not os.path.exists(_dummy_path):
+        snapshot_download(repo_id="facebook/opt-125m",
+                          local_dir=_dummy_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+        assert os.path.exists(json_path)
+        with open(json_path, "r") as f:
+            config = json.load(f)
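+        # Re-point the architecture at the out-of-tree "MyOPTForCausalLM"
+        # class so tests that register an out-of-tree model can load it.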
+        config["architectures"] = ["MyOPTForCausalLM"]
+        with open(json_path, "w") as f:
+            json.dump(config, f)
+    return _dummy_path

+ 0 - 0
tests/core/__init__.py


+ 0 - 0
tests/core/block/__init__.py


+ 12 - 0
tests/core/block/conftest.py

@@ -0,0 +1,12 @@
+import pytest
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test() -> bool:
+    """Disable the global cleanup fixture for tests in this directory. This
+    provides a ~10x speedup for unit tests that don't load a model to GPU.
+
+    This requires that tests in this directory clean up after themselves if they
+    use the GPU.
+    """
+    return False

+ 0 - 0
tests/core/block/e2e/__init__.py


+ 68 - 0
tests/core/block/e2e/conftest.py

@@ -0,0 +1,68 @@
+from typing import Callable, Iterable, Optional
+
+import pytest
+
+from aphrodite import LLM
+from aphrodite.modeling.utils import set_random_seed
+
+from ....conftest import cleanup
+
+
+@pytest.fixture
+def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                           baseline_llm_kwargs, seed):
+    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                                baseline_llm_kwargs, seed)
+
+
+@pytest.fixture
+def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                       test_llm_kwargs, seed):
+    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                                test_llm_kwargs, seed)
+
+
+def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+                         distinct_llm_kwargs, seed):
+    kwargs = {
+        **common_llm_kwargs,
+        **per_test_common_llm_kwargs,
+        **distinct_llm_kwargs,
+    }
+
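+    # Construct the LLM lazily inside a generator so that model construction
+    # only happens once the test iterates, and cleanup() runs as soon as the
+    # test releases it.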
+    def generator_inner():
+        llm = LLM(**kwargs)
+
+        set_random_seed(seed)
+
+        yield llm
+        del llm
+        cleanup()
+
+    for llm in generator_inner():
+        yield llm
+        del llm
+
+
+def get_text_from_llm_generator(llm_generator: Iterable[LLM],
+                                prompts,
+                                sampling_params,
+                                llm_cb: Optional[Callable[[LLM],
+                                                          None]] = None):
+    for llm in llm_generator:
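+        # llm_cb lets the caller inspect the freshly constructed LLM (e.g. to
+        # verify its sliding-window configuration) before generation starts.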
+        if llm_cb:
+            llm_cb(llm)
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        text = [output.outputs[0].text for output in outputs]
+        del llm
+
+    return text
+
+
+def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
+    for llm in llm_generator:
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        token_ids = [output.outputs[0].token_ids for output in outputs]
+        del llm
+
+    return token_ids

+ 558 - 0
tests/core/block/e2e/test_correctness.py

@@ -0,0 +1,558 @@
+from itertools import cycle
+
+import pytest
+
+from aphrodite import SamplingParams
+
+from .conftest import get_token_ids_from_llm_generator
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Allow only 5 sequences of ~1024 tokens in worst case.
+        "block_size": 16,
+        "num_gpu_blocks_override": 5 * (64 + 1),
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "use_v2_block_manager": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
+                                               test_llm_generator, batch_size):
+    """Verify block manager v2 produces same outputs as block manager v1, even
+    when there is preemption.
+
+    This constructs two LLMs, each with a limited number of GPU blocks. The
+    limit is chosen such that, as the sequences in the batch grow, sequences
+    must be preempted and removed from the cache.
+
+    If the output token ids are equivalent, then we have confidence that the KV
+    cache is not corrupted in the v2 block manager.
+
+    NOTE: We want a significant number of generated tokens so that any incorrect
+    KV mapping has time to build up error.
+    """
+    output_len = 1024
+    temperature = 0.0
+
+    # We want to ensure equality even with preemption.
+    # Each fully grown sequence needs 1 + cdiv(output_len, block_size) = 65
+    # blocks, but only 5 sequences' worth of blocks are available, so the
+    # batch of 10 cannot all fit at once and preemption must occur.
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
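+    # Cycle the four base prompts until the batch contains `batch_size` of
+    # them.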
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids from block manager v1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids from block manager v2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Use a large block size to trigger more copy-on-writes.
+        "block_size": 32,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "use_v2_block_manager": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
+                                        test_llm_generator, batch_size):
+    """Verify beam search equality with block manager v1 and v2.
+
+    This requires copy-on-write; if the v1 and v2 outputs are the same, then
+    we have some confidence that copy-on-write is working.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+        use_beam_search=True,
+        best_of=2,
+    )
+
+    print('Getting token ids from block manager v1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids from block manager v2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # Our prompts will generate 128 tokens; since the prompts themselves are
+        # small, we don't need much KV space beyond 128.
+        "max_model_len": 160,
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Lookahead scheduling only supported in v2 block manager.
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            "block_size": 16,
+
+            # Allow only 2 sequences of ~128 tokens in worst case.
+            # Note 8 = 128/block_size
+            "num_gpu_blocks_override": 2 * (8 + 1),
+        },
+        {
+            "block_size": 8,
+
+            # Allow only 2 sequences of ~128 tokens in worst case.
+            # Note 16 = 128/block_size
+            "num_gpu_blocks_override": 2 * (16 + 2),
+        }
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "num_lookahead_slots": 0,
+}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            # We run one test with block_size < lookahead_slots, one test with
+            # block_size > lookahead_slots
+            "num_lookahead_slots": 10,
+            "preemption_mode": "swap",
+        },
+        {
+            "num_lookahead_slots": 10,
+            "preemption_mode": "recompute",
+        }
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
+                                                   test_llm_generator,
+                                                   batch_size):
+    """Verify Aphrodite produces the same output with greedy sampling, when
+    lookahead scheduling is used vs. not.
+
+    Lookahead scheduling is not expected to modify the output, as it simply
+    allocates empty slots ahead of the known token ids in a sliding fashion.
+
+    This test constrains the total number of blocks to force preemption. It also
+    varies the block size so that the lookahead size is less than and greater
+    than the block size.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids without lookahead scheduling')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with lookahead scheduling')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [
+        {
+            # Use a small model for a fast test.
+            "model": "facebook/opt-125m",
+
+            # skip cuda graph creation for fast test.
+            "enforce_eager": True,
+            "enable_chunked_prefill": True,
+        },
+    ])
+@pytest.mark.parametrize("per_test_common_llm_kwargs",
+                         [{
+                             "block_size": 8,
+                             "max_num_batched_tokens": 2,
+                             "max_num_seqs": 2,
+                         }, {
+                             "block_size": 8,
+                             "max_num_batched_tokens": 3,
+                             "max_num_seqs": 2,
+                         }, {
+                             "block_size": 8,
+                             "max_num_batched_tokens": 256,
+                             "max_num_seqs": 10,
+                         }])
+@pytest.mark.parametrize("baseline_llm_kwargs", [
+    {
+        "use_v2_block_manager": False,
+    },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "use_v2_block_manager": True,
+        "num_lookahead_slots": 0,
+    },
+    {
+        "use_v2_block_manager": True,
+        "num_lookahead_slots": 5,
+    },
+])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
+                                          test_llm_generator, batch_size):
+    """Verify that chunked prefill works with BlockManagerV2, with and without
+    lookahead scheduling.
+    """
+    output_len = 32
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        ("1 + " * 50) + " 1 = ",  # Longer prompt.
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with BlockManagerV1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with BlockManagerV2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Allow only 5 sequences of ~1024 tokens in worst case.
+        "block_size": 16,
+        "num_gpu_blocks_override": 5 * (64 + 1),
+
+        # Enable prefill cache
+        "enable_prefix_caching": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "use_v2_block_manager": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
+        baseline_llm_generator, test_llm_generator, batch_size):
+    """Verify block manager v2 produces same outputs as block manager v1, even
+    when there is preemption.
+
+    This constructs two LLMs, each with a limited number of GPU blocks. The
+    limit is chosen such that, as the sequences in the batch grow, sequences
+    must be preempted and removed from the cache.
+
+    If the output token ids are equivalent, then we have confidence that the KV
+    cache is not corrupted in the v2 block manager.
+
+    NOTE: We want a significant number of generated tokens so that any incorrect
+    KV mapping has time to build up error.
+    """
+    output_len = 1024
+    temperature = 0.0
+
+    # We want to ensure equality even with preemption.
+    # Each fully grown sequence needs 1 + cdiv(output_len, block_size) = 65
+    # blocks, but only 5 sequences' worth of blocks are available, so the
+    # batch of 10 cannot all fit at once and preemption must occur.
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids from block manager v1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids from block manager v2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Allow only 5 sequences of ~1024 tokens in worst case.
+        "block_size": 16,
+        "num_gpu_blocks_override": 5 * (64 + 1),
+
+        # Test APC in v2 block
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "enable_prefix_caching": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "enable_prefix_caching": True,
+    "preemption_mode": "swap"
+}, {
+    "enable_prefix_caching": True,
+    "preemption_mode": "recompute"
+}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
+                                             test_llm_generator, batch_size):
+    """Verify block manager v2 with auto prefix caching enabled produces same
+    outputs as auto prefix caching disabled, even when there is preemption.
+
+    This constructs two LLMs, each with a limited number of GPU blocks. The
+    limit is chosen such that, as the sequences in the batch grow, sequences
+    must be preempted and removed from the cache.
+
+    If the output token ids are equivalent, then we have confidence that auto
+    prefix caching itself does not corrupt the results.
+    """
+    output_len = 1024
+    temperature = 0.0
+
+    # We want to ensure equality even with preemption.
+    # Each fully grown sequence needs 1 + cdiv(output_len, block_size) = 65
+    # blocks, but only 5 sequences' worth of blocks are available, so the
+    # batch of 10 cannot all fit at once and preemption must occur.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with APC disabled')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with APC enabled')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # we keep the number of blocks small so that eviction kicks in quickly
+        "max_model_len": 48,
+        "block_size": 16,
+        "num_gpu_blocks_override": 3,
+
+        # Test APC in v2 block
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "enable_prefix_caching": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "enable_prefix_caching": True,
+}])
+@pytest.mark.parametrize("seed", [1])
+def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator,
+                                                  test_llm_generator):
+    """Verify that block manager v2 with auto prefix caching keeps working
+    normally even after eviction has started.
+    With APC enabled, all blocks are held as naive blocks at the beginning;
+    afterwards, blocks are managed by the evictor instead. On a cache hit
+    against a block held by the evictor, it can be reused; otherwise its KV
+    cache must be recomputed.
+    """
+    output_len = 10
+    temperature = 0.0
+
+    prompts = [
+        "You are a helpful assistant. Please answer truthfully and write "
+        "out your thinking step by step to be sure you get the right answer. "
+        "If you make a mistake, attempt to correct it. who are you?",
+        "You are a helpful assistant. Please answer truthfully and write out "
+        "your thinking step by step to be sure you get the right answer. You "
+        "are helpful and harmless and you follow ethical guidelines. "
+        "who are you?"
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with APC disabled')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with APC enabled')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids

+ 163 - 0
tests/core/block/e2e/test_correctness_sliding_window.py

@@ -0,0 +1,163 @@
+import random
+from typing import List
+
+import pytest
+
+from aphrodite import LLM, SamplingParams
+
+from .conftest import get_text_from_llm_generator
+
+# relatively small model with 4k sliding window
+MODEL = "bigcode/starcoder2-3b"
+BLOCK_SIZE = 16
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": MODEL,
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+        "block_size": BLOCK_SIZE,
+        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "use_v2_block_manager": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
+@pytest.mark.parametrize("batch_size", [5])
+@pytest.mark.parametrize("seed", [1])
+def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator,
+                                  batch_size, seed):
+    """
+    The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
+    asks for the value of one of them (which is outside the sliding window).
+    If we tell it upfront which one we are going to be looking for, then
+    it answers correctly (mostly).
+
+    Additionally, we compare the results of the v1 and v2 managers.
+    """
+    sampling_params = SamplingParams(
+        max_tokens=1024,
+        ignore_eos=True,
+        temperature=0.0,
+    )
+
+    prompts, answer, indices = prep_prompts(batch_size)
+
+    print('Getting text from block manager v1')
+    baseline_texts = get_text_from_llm_generator(baseline_llm_generator,
+                                                 prompts,
+                                                 sampling_params,
+                                                 llm_cb=check_window(prompts))
+
+    check_answers(indices, answer, baseline_texts)
+
+    print('Getting text from block manager v2')
+    test_texts = get_text_from_llm_generator(test_llm_generator, prompts,
+                                             sampling_params)
+    check_answers(indices, answer, test_texts)
+
+    cmp = [
+        expected_text == actual_text
+        for expected_text, actual_text in zip(baseline_texts, test_texts)
+    ]
+    print(cmp)
+    assert sum(cmp) > 0.7 * len(cmp)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": MODEL,
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+        "block_size": BLOCK_SIZE,
+        "num_gpu_blocks_override": 100000 // BLOCK_SIZE,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "enable_chunked_prefill": True
+}])
+@pytest.mark.parametrize("batch_size", [5])
+@pytest.mark.parametrize("seed", [1])
+def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
+    """
+    This is similar to test_sliding_window_retrival; however, it doesn't
+    compare against the v1 block manager since v1 doesn't support
+    chunked prefill with sliding window.
+
+    The results with and without chunked prefill are not the same due to
+    numerical instabilities.
+    """
+    sampling_params = SamplingParams(
+        max_tokens=10,
+        ignore_eos=True,
+        temperature=0.0,
+    )
+
+    prompts, answer, indices = prep_prompts(batch_size)
+
+    # We don't compare with the baseline model here, since the results are
+    # slightly different due to how the attention tail is handled.
+    test_texts = get_text_from_llm_generator(test_llm_generator,
+                                             prompts,
+                                             sampling_params,
+                                             llm_cb=check_window(prompts))
+    check_answers(indices, answer, test_texts)
+
+
+def prep_prompts(batch_size: int):
+    """
+    Generate prompts that make a bunch of assignments and
+    then ask for the value of one of them.
+    Each prompt is just under 10k tokens; the sliding window is 4k,
+    so the answer is outside the sliding window but should still be correct.
+    """
+    prompts: List[str] = []
+    answer: List[int] = []
+    indices: List[int] = []
+    random.seed(1)
+    for _ in range(batch_size):
+        idx = random.randint(30, 90)
+        indices.append(idx)
+        prompt = "```python\n# We set a number of variables, " + \
+                 f"x{idx} will be important later\n"
+        ln = random.randint(800, 1100)
+        for k in range(30, ln):
+            v = random.randint(10, 99)
+            if k == idx:
+                answer.append(v)
+            prompt += f"x{k} = {v}\n"
+        prompt += f"# Now, we check the value of x{idx}:\n"
+        prompt += f"assert x{idx} == "
+        prompts.append(prompt)
+    return prompts, answer, indices
+
+
+def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
+    answer2 = [int(text[0:2].strip()) for text in outputs]
+    print(list(zip(indices, zip(answer, answer2))))
+    numok = 0
+    for a1, a2 in zip(answer, answer2):
+        if a1 == a2:
+            numok += 1
+    frac_ok = numok / len(answer)
+    print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
+    assert frac_ok > 0.7
+
+
+def check_window(prompts: List[str]):
+
+    def inner(llm: LLM):
+        sliding_window = llm.llm_engine.model_config.get_sliding_window()
+        assert sliding_window and sliding_window > 0
+        assert any(
+            len(llm.get_tokenizer().tokenize(prompt)) > sliding_window
+            for prompt in prompts)
+
+    return inner

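As a quick, illustrative sanity check of the helper contract above (not part of the test file): every generated prompt ends with the assert on the tracked variable, and feeding the ground-truth answers back into check_answers passes trivially.

prompts, answer, indices = prep_prompts(batch_size=2)
for prompt, idx, val in zip(prompts, indices, answer):
    assert prompt.endswith(f"assert x{idx} == ")
    assert f"x{idx} = {val}\n" in prompt
# check_answers() parses the first two characters of each output; the
# ground-truth values are always two digits, so frac_ok == 1.0 here.
check_answers(indices, answer, [str(v) for v in answer])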
+ 445 - 0
tests/core/block/test_block_manager_v2.py

@@ -0,0 +1,445 @@
+import pytest
+
+from aphrodite.common.sequence import Logprob, SequenceStatus
+from aphrodite.common.utils import chunk_list
+from aphrodite.processing.block.utils import (
+    STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA)
+from aphrodite.processing.block_manager_v2 import BlockSpaceManagerV2
+from aphrodite.processing.interfaces import AllocStatus
+
+from ..utils import (create_dummy_prompt, create_seq_group,
+                     create_seq_group_encoder_decoder)
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
+@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
+                                num_gpu_blocks: int, watermark: float):
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        watermark=watermark,
+    )
+    num_watermark_blocks = int(watermark * num_gpu_blocks)
+
+    num_output_blocks_per_seq = 1
+
+    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
+    # the current implementation assumes all seqs are new prompts / don't have
+    # different output lens.
+    num_output_blocks = num_output_blocks_per_seq
+
+    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
+        seq_group = create_seq_group(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+        )
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+        can_allocate_result = block_manager.can_allocate(seq_group)
+
+        num_required_blocks = num_prompt_blocks + num_output_blocks
+
+        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
+            assert can_allocate_result == AllocStatus.NEVER
+        elif num_gpu_blocks >= num_required_blocks:
+            assert can_allocate_result == AllocStatus.OK
+        else:
+            assert can_allocate_result == AllocStatus.LATER
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160])
+@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_seq_group_encoder_decoder(block_size: int,
+                                                num_seqs_per_group: int,
+                                                num_gpu_blocks: int,
+                                                watermark: float):
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        watermark=watermark,
+    )
+    num_watermark_blocks = int(watermark * num_gpu_blocks)
+
+    num_output_blocks_per_seq = 1
+
+    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
+    # the current implementation assumes all seqs are new prompts / don't have
+    # different output lens.
+    num_output_blocks = num_output_blocks_per_seq
+
+    for bdx, num_prompt_blocks in enumerate(
+            range(1, num_gpu_blocks - num_output_blocks)):
+        num_cross_blocks_per_seq = num_prompt_blocks
+
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id=str(bdx))
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+        can_allocate_result = block_manager.can_allocate(seq_group)
+
+        num_required_blocks = num_prompt_blocks + \
+                              num_output_blocks + \
+                              num_cross_blocks_per_seq
+
+        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
+            assert can_allocate_result == AllocStatus.NEVER
+        elif num_gpu_blocks >= num_required_blocks:
+            assert can_allocate_result == AllocStatus.OK
+        else:
+            assert can_allocate_result == AllocStatus.LATER
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16])
+@pytest.mark.parametrize("num_seqs_per_group", [1])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
+                                                     num_seqs_per_group: int,
+                                                     num_gpu_blocks: int,
+                                                     watermark: float):
+    '''
+    SWA is short for Sliding Window Attention.
+
+    At the time of writing, block manager v2 does not support SWA.
+
+    However, even once SWA is implemented for block manager v2, a separate
+    workstream will most likely still be required to enable SWA for
+    encoder/decoder models.
+
+    Therefore this test enforces that one of the following cases holds true:
+    1. Block manager v2 does not support SWA at all (true at time of writing)
+    2. Block manager v2 fails with NotImplementedError when SWA is enabled
+       AND a SequenceGroup with an encoder sequence (i.e. in support of an
+       encoder/decoder model) is passed into can_allocate() as an argument
+
+    The setup for this test is a stripped-down version of
+    test_can_allocate_seq_group_encoder_decoder()
+    '''
+
+    with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
+        block_manager = BlockSpaceManagerV2(
+            block_size=block_size,
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=1024,
+            watermark=watermark,
+            sliding_window=5  # SWA
+        )
+
+        num_output_blocks_per_seq = 1
+        num_prompt_blocks = 1
+        num_output_blocks = num_output_blocks_per_seq
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id="0")
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+        block_manager.can_allocate(seq_group)
+
+    # Assert that either
+    # 1. Block manager v2 constructor fails with assertion that sliding window
+    #    is not yet supported (most likely near-term outcome at time of
+    #    writing), or
+    # 2. can_allocate() fails with NotImplementedError due to combination of
+    #    encoder/decoder and sliding window attention
+    if isinstance(exc_info.value, NotImplementedError):
+        assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+    elif isinstance(exc_info.value, AssertionError):
+        assert str(exc_info.value) == "Sliding window not yet supported"
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16])
+@pytest.mark.parametrize("num_seqs_per_group", [1])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_encoder_decoder_fails_with_prefix_cache(
+        block_size: int, num_seqs_per_group: int, num_gpu_blocks: int,
+        watermark: float):
+
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        watermark=watermark,
+        enable_caching=True  # Prefix cache
+    )
+
+    num_output_blocks_per_seq = 1
+    num_prompt_blocks = 1
+    num_output_blocks = num_output_blocks_per_seq
+    seq_group = create_seq_group_encoder_decoder(
+        seq_prompt_len=block_size * num_prompt_blocks,
+        seq_output_lens=[
+            block_size * num_output_blocks_per_seq
+            for _ in range(num_seqs_per_group)
+        ],
+        request_id="0")
+
+    assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+    # Assert that can_allocate() fails with NotImplementedError due to the
+    # combination of encoder/decoder and prefix cache.
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("prompt_len", [1, 7, 8])
+@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
+@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
+def test_append_slots(block_size, prompt_len, num_slots_to_append,
+                      num_lookahead_slots):
+    """Verify append_slots consumes the correct number of blocks from the block
+    table.
+    """
+
+    num_gpu_blocks = 1024
+    watermark = 0.1
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        watermark=watermark,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=prompt_len,
+        seq_output_lens=[0],
+    )
+
+    # Allocate seq
+    assert block_manager.can_allocate(seq_group)
+    block_manager.allocate(seq_group)
+
+    # Set seq to RUNNING
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    # Append tokens to the sequence
+    for token_id in range(num_slots_to_append):
+        seq.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Append slots for new tokens and lookahead slots.
+    free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
+    block_manager.append_slots(seq, num_lookahead_slots)
+    num_consumed_blocks = (free_blocks_before_append -
+                           block_manager.get_num_free_gpu_blocks())
+
+    # Expect consumed blocks to be new blocks required to support the new slots.
+    expected_consumed_blocks = len(
+        list(
+            chunk_list(
+                list(
+                    range(prompt_len + num_slots_to_append +
+                          num_lookahead_slots)),
+                block_size))) - len(
+                    list(chunk_list(list(range(prompt_len)), block_size)))
+    assert num_consumed_blocks == expected_consumed_blocks
+
+
+@pytest.mark.parametrize("block_size", [8])
+@pytest.mark.parametrize("num_cpu_blocks", [4])
+@pytest.mark.parametrize("num_gpu_blocks", [4])
+@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
+@pytest.mark.parametrize("enable_caching", [False, True])
+def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
+              enable_caching):
+    """Verify blocks number on src/desc device is correct after swapping in/out
+        sequence group (not missing or extra blocks).
+    """
+    block_manager = BlockSpaceManagerV2(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=enable_caching)
+    prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
+    prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    prompt.status = SequenceStatus.RUNNING
+    prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap seq group from GPU -> CPU.
+    gpu_blocks = block_manager.get_block_table(prompt)
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    mapping_keys = [key for key, _ in mapping]
+    assert mapping_keys == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    prompt.status = SequenceStatus.SWAPPED
+
+    # Swap seq group from CPU -> GPU.
+    assert block_manager.can_swap_in(seq_group, num_lookahead_slots)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    cpu_blocks = block_manager.get_block_table(prompt)
+    mapping_keys = [key for key, _ in mapping]
+    assert mapping_keys == [cpu_blocks[0]]
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
+
+
+@pytest.mark.parametrize("block_size", [8])
+@pytest.mark.parametrize("num_gpu_blocks", [4])
+@pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10])
+@pytest.mark.parametrize("enable_caching", [True, False])
+def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
+                  enable_caching):
+    """ Verify the block manager can correctly determine if a sequence group
+        can be swapped in/out.
+    """
+    num_cpu_blocks = num_gpu_blocks
+    block_manager = BlockSpaceManagerV2(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=enable_caching)
+    prompt, seq_group = create_dummy_prompt(
+        "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1)
+    prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+    prompt.status = SequenceStatus.RUNNING
+
+    # Swap seq group from GPU -> CPU.
+    gpu_blocks = block_manager.get_block_table(prompt)
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    mapping_keys = [key for key, _ in mapping]
+    assert mapping_keys == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    prompt.status = SequenceStatus.SWAPPED
+
+    # At this moment, we still have enough free blocks to swap in the seq group.
+    if num_lookahead_slots <= block_size:
+        assert block_manager.can_swap_in(seq_group,
+                                         num_lookahead_slots) == AllocStatus.OK
+    else:
+        assert block_manager.can_swap_in(
+            seq_group, num_lookahead_slots) == AllocStatus.NEVER
+
+    # During the swap out, 2 cached blocks were evicted from the GPU,
+    # so prompt1 can't be swapped in.
+    prompt2_len = 2 * block_size - 1
+    prompt2, seq_group2 = create_dummy_prompt(
+        "2",
+        prompt_length=prompt2_len,
+        prompt_tokens=[10000 + i for i in range(prompt2_len)])
+    prompt2.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group2)
+
+    # Swap seq group from CPU -> GPU.
+    if num_lookahead_slots <= block_size:
+        assert block_manager.can_swap_in(
+            seq_group, num_lookahead_slots) == AllocStatus.LATER
+    else:
+        assert block_manager.can_swap_in(
+            seq_group, num_lookahead_slots) == AllocStatus.NEVER
+
+
+# TODO: add comprehensive tests for swapping at allocator level.
+
+
+@pytest.mark.parametrize("block_size", [8, 16])
+@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
+@pytest.mark.parametrize("num_slots_to_append", [50])
+@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512])
+def test_sliding_window(block_size, prompt_len, num_slots_to_append,
+                        sliding_window):
+    """Verify append_slots consumes the correct number of blocks from the block
+    table.
+    """
+
+    num_gpu_blocks = 1024
+    watermark = 0.1
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        watermark=watermark,
+        sliding_window=sliding_window,
+    )
+
+    def check_used(min_n, max_n=None):
+        if max_n is None:
+            max_n = min_n
+        used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks()
+        #print("check", min_n, used, max_n)
+        assert min_n <= used
+        assert used <= max_n
+
+    def num_blocks(num_tokens):
+        return (num_tokens + block_size - 1) // block_size
+
+    check_used(0)
+
+    seq_group = create_seq_group(
+        seq_prompt_len=prompt_len,
+        seq_output_lens=[0],
+    )
+
+    check_used(0)
+
+    # Allocate seq
+    assert block_manager.can_allocate(seq_group)
+    block_manager.allocate(seq_group)
+
+    check_used(num_blocks(prompt_len))
+
+    # Set seq to RUNNING
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    seq.data.update_num_computed_tokens(prompt_len)
+    check_used(num_blocks(prompt_len))
+
+    # this is how we compute it in BlockSpaceManagerV2.__init__
+    sliding_blocks = (sliding_window // block_size) + 2
+    # plus one block for null block
+    sliding_blocks += 1
+
+    # Append tokens to the sequence
+    for token_id in range(num_slots_to_append):
+        seq.append_token_id(token_id, {token_id: Logprob(0.0)})
+        seq.data.update_num_computed_tokens(1)
+        block_manager.append_slots(seq, num_lookahead_slots=0)
+        if prompt_len < sliding_window + 10:
+            check_used(0, sliding_blocks + 1)
+        else:
+            check_used(sliding_blocks, sliding_blocks + 1)

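The can_allocate assertions above all encode the same watermark rule. The helper below is an illustrative restatement of how these tests compute the expected status for a freshly created manager (all GPU blocks free); it is not the BlockSpaceManagerV2 implementation.

from aphrodite.processing.interfaces import AllocStatus


def expected_alloc_status(num_required_blocks: int, num_gpu_blocks: int,
                          watermark: float) -> AllocStatus:
    num_watermark_blocks = int(watermark * num_gpu_blocks)
    if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
        # The request can never fit without dipping below the watermark.
        return AllocStatus.NEVER
    if num_gpu_blocks >= num_required_blocks:
        # Enough blocks are available right now.
        return AllocStatus.OK
    # Not enough blocks at the moment; allocation may succeed later.
    return AllocStatus.LATER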
+ 577 - 0
tests/core/block/test_block_table.py

@@ -0,0 +1,577 @@
+from typing import List
+
+import pytest
+
+from aphrodite.common.utils import Device, cdiv, chunk_list
+from aphrodite.processing.block.block_table import BlockTable
+from aphrodite.processing.block.cpu_gpu_block_allocator import (
+    CpuGpuBlockAllocator)
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+def test_allocate_naive(block_size: int, sequence_len: int):
+    """Test the allocation of blocks using the naive allocator.
+
+    This test creates a CpuGpuBlockAllocator with the specified block size and
+    number of blocks. It then allocates multiple BlockTables with varying
+    sequence lengths and verifies that the number of free blocks decreases as
+    expected after each allocation.
+    """
+    assert block_size > 1
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type="naive",
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
+
+    block_tables: List[BlockTable] = []
+    for i in range(5):
+        assert allocator.get_num_free_blocks(
+            device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
+
+        block_tables.append(
+            BlockTable(
+                block_size=block_size,
+                block_allocator=allocator,
+            ))
+        block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+def test_allocate_prefix_caching(block_size: int, sequence_len: int):
+    """Test the allocation of blocks using the prefix caching allocator.
+
+    This test creates a CpuGpuBlockAllocator with the specified block size and
+    number of blocks, using the prefix caching allocator. It then allocates
+    multiple BlockTables with varying sequence lengths and verifies that the
+    number of free blocks decreases as expected after each allocation.
+
+    The test expects all sequences to share allocations, except for their last
+    block, which may be mutable. It calculates the expected number of immutable
+    and mutable blocks per allocation based on the sequence length and block
+    size.
+    """
+    assert block_size > 1
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type="prefix_caching",
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    chunked_tokens = list(chunk_list(token_ids, block_size))
+    num_mutable_blocks_per_alloc = 0 if len(
+        chunked_tokens[-1]) == block_size else 1
+    num_immutable_blocks_per_alloc = len(
+        chunked_tokens) - num_mutable_blocks_per_alloc
+
+    block_tables: List[BlockTable] = []
+    for alloc_i in range(1, 6):
+
+        block_tables.append(
+            BlockTable(
+                block_size=block_size,
+                block_allocator=allocator,
+            ))
+        block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
+
+        # Expect all sequences to share allocations, except for their last block
+        # (which may be mutable).
+        assert allocator.get_num_free_blocks(
+            device=Device.GPU) == num_gpu_blocks - (
+                num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc *
+                (alloc_i))
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str,
+                       device: str):
+    """Test the allocation and freeing of blocks using different allocators and
+    devices.
+
+    This test creates a CpuGpuBlockAllocator with the specified block size,
+    number of blocks, allocator type, and device. It then allocates a BlockTable
+    multiple times with the same sequence and verifies that the number of free
+    blocks remains consistent after each allocation and freeing.
+    """
+    device = Device[device.upper()]
+
+    num_device_blocks = 1024
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_device_blocks,
+        num_cpu_blocks=num_device_blocks,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    for i in range(5):
+        block_table.allocate(token_ids=token_ids, device=device)
+        assert allocator.get_num_free_blocks(
+            device) == num_device_blocks - num_blocks_per_alloc
+        assert all(block_id is not None
+                   for block_id in block_table.physical_block_ids)
+
+        block_table.free()
+        assert allocator.get_num_free_blocks(device) == num_device_blocks
+
+
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("append_len", [1, 16, 129])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_append_token_ids_allocation(block_size: int, sequence_len: int,
+                                     append_len: int, allocator_type: str):
+    """Test the allocation behavior when appending token IDs to a BlockTable.
+
+    This test creates a CpuGpuBlockAllocator with the specified block size,
+    number of blocks, and allocator type. It then allocates a BlockTable with an
+    initial sequence and appends additional token IDs to it. The test verifies
+    that the number of allocated blocks before and after appending matches the
+    expected values.
+    """
+
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    token_ids_to_append = list(range(append_len))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    num_expected_blocks_before_append = len(
+        list(chunk_list(token_ids, block_size)))
+    num_expected_appended_blocks = len(
+        list(chunk_list(token_ids + token_ids_to_append,
+                        block_size))) - num_expected_blocks_before_append
+
+    block_table.allocate(token_ids=token_ids, device=Device.GPU)
+
+    assert len(
+        block_table.physical_block_ids) == num_expected_blocks_before_append
+    block_table.append_token_ids(token_ids_to_append)
+    assert len(
+        block_table.physical_block_ids
+    ) == num_expected_blocks_before_append + num_expected_appended_blocks
+
+
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("num_empty_slots", [1, 16, 129])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
+                                           num_empty_slots: int,
+                                           allocator_type: str):
+    """Test the allocation behavior when ensuring a certain number of empty
+    slots in a BlockTable.
+
+    This test creates a CpuGpuBlockAllocator with the specified block size,
+    number of blocks, and allocator type. It then allocates a BlockTable with an
+    initial sequence and ensures a certain number of empty slots. The test
+    verifies that the number of allocated blocks before and after ensuring empty
+    slots matches the expected values. It also checks that filling up the empty
+    slots does not consume additional blocks.
+    """
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    num_expected_blocks_before_append = len(
+        list(chunk_list(token_ids, block_size)))
+    num_expected_appended_blocks = len(
+        list(chunk_list(token_ids + [-1] * num_empty_slots,
+                        block_size))) - num_expected_blocks_before_append
+
+    block_table.allocate(token_ids=token_ids, device=Device.GPU)
+
+    # Assert that the empty slots consume the expected number of additional
+    # blocks.
+    assert len(
+        block_table.physical_block_ids) == num_expected_blocks_before_append
+    block_table.ensure_num_empty_slots(num_empty_slots)
+    assert len(
+        block_table.physical_block_ids
+    ) == num_expected_blocks_before_append + num_expected_appended_blocks
+
+    # Now, ensure no additional blocks consumed as we fill up the empty slots.
+    num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU)
+    block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
+    assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU)
+
+
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("sequence_len", [1, 9])
+@pytest.mark.parametrize("append_len", [1, 16, 129])
+@pytest.mark.parametrize("append_size", [1, 4, 129])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
+                                          append_len: int, allocator_type: str,
+                                          append_size: int):
+    """Verify token ids are correctly appended. Appends various amounts of
+    token ids in various append sizes, and verifies the final sequence is
+    correct.
+    """
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    token_ids_to_append = list(range(append_len))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+    block_table.allocate(token_ids=token_ids, device=Device.GPU)
+
+    appended_so_far: List[int] = []
+    for append in chunk_list(token_ids_to_append, append_size):
+        block_table.append_token_ids(append)
+        appended_so_far.extend(append)
+
+        assert block_table._get_all_token_ids() == token_ids + appended_so_far
+
+    assert block_table._get_all_token_ids() == token_ids + token_ids_to_append
+
+
+@pytest.mark.parametrize("seq_len", [1, 9, 129])
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_fork(seq_len: int, block_size: int, allocator_type: str):
+    """Create a sequence using the specified allocator.
+        1. Assert that after forking the sequence, the free block count is the
+            same.
+        2. Assert that the forked sequence has the same physical mappings.
+        3. Then free the original sequence; verify that the free block count is
+            the same.
+        4. Finally, free the forked sequence and verify that all blocks are
+            freed (the free block count returns to the total).
+    """
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(seq_len))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    block_table.allocate(token_ids)
+
+    num_free_blocks_before_fork = allocator.get_num_free_blocks(
+        device=Device.GPU)
+
+    forked_block_table = block_table.fork()
+
+    # Expect physical_block_ids and token_ids to match.
+    assert (block_table.physical_block_ids ==
+            forked_block_table.physical_block_ids)
+    assert block_table._get_all_token_ids(
+    ) == forked_block_table._get_all_token_ids()
+
+    # Do not expect any additional allocations.
+    assert allocator.get_num_free_blocks(
+        device=Device.GPU) == num_free_blocks_before_fork
+
+    # Free the original blocks. Assert num free blocks does not change, since
+    # refcount is nonzero.
+    block_table.free()
+    assert allocator.get_num_free_blocks(
+        device=Device.GPU) == num_free_blocks_before_fork
+
+    # Expect the forked block table to be unaffected by the free.
+    assert all(block_id is not None
+               for block_id in forked_block_table.physical_block_ids)
+
+    # Free the forked blocks. Assert num free blocks does change, since
+    # refcount is now zero.
+    forked_block_table.free()
+    assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks
+
+
+@pytest.mark.parametrize("block_size", [8])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("append_len", [1, 16, 129])
+@pytest.mark.parametrize("appender", ["forked", "original"])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_cow(block_size: int, sequence_len: int, append_len: int,
+             allocator_type: str, appender: str):
+    """Fork a sequence; append to the forked sequence; verify there's a CoW.
+    """
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    token_ids_to_append = list(range(append_len))
+
+    original_block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    num_expected_non_cow_blocks = cdiv(sequence_len, block_size)
+    num_expected_cow_blocks = cdiv(sequence_len + append_len,
+                                   block_size) - (sequence_len // block_size)
+
+    original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
+    original_block_ids = original_block_table.physical_block_ids[:]
+
+    print("original_block_ids = {}".format(original_block_ids))
+    forked_block_table = original_block_table.fork()
+
+    # Expect no additional allocation (copy on _write_).
+    assert allocator.get_num_free_blocks(
+        Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks)
+
+    if appender == "forked":
+        appender_block_table = forked_block_table
+        static_block_table = original_block_table
+    elif appender == "original":
+        appender_block_table = original_block_table
+        static_block_table = forked_block_table
+    else:
+        raise ValueError(f"unknown test config {appender=}")
+
+    # Write tokens.
+    appender_block_table.append_token_ids(token_ids_to_append)
+
+    # Expect the non-appending block table to have no change.
+    assert static_block_table.physical_block_ids == original_block_ids
+    assert appender_block_table.physical_block_ids != original_block_ids
+
+    # Expect the blocks changed during append to have a CoW.
+    assert allocator.get_num_free_blocks(
+        Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks +
+                                         num_expected_cow_blocks)
+
+    cows = allocator.clear_copy_on_writes()
+    if sequence_len % block_size > 0:
+        # If the last block in the sequence is not full, then when appending we
+        # expect a CoW.
+        assert cows
+
+        cow_block_id = sequence_len // block_size
+        expected_src = static_block_table.physical_block_ids[cow_block_id]
+        expected_dst = appender_block_table.physical_block_ids[cow_block_id]
+
+        assert (expected_src, expected_dst) in cows
+    else:
+        # Otherwise, there should be no copy-on-write.
+        assert not cows
+
+    static_block_table.free()
+    appender_block_table.free()
+
+    # After free, expect all blocks to be freed.
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+
+@pytest.mark.parametrize("block_size", [8])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("append_len", [1, 16, 129])
+@pytest.mark.parametrize("lookahead_slots", [1, 16, 129])
+@pytest.mark.parametrize("appender", ["forked", "original"])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_cow_lookahead_simple(block_size: int, sequence_len: int,
+                              append_len: int, lookahead_slots: int,
+                              allocator_type: str, appender: str):
+    """Similar to test_cow, except with lookahead allocation. The assertions are
+    less rigorous due to the complexity of the property under test.
+    """
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    token_ids_to_append = list(range(append_len))
+
+    original_block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
+
+    # Allocate lookahead slots.
+    original_block_table.ensure_num_empty_slots(lookahead_slots)
+    original_block_ids = original_block_table.physical_block_ids[:]
+
+    forked_block_table = original_block_table.fork()
+
+    if appender == "forked":
+        appender_block_table = forked_block_table
+        static_block_table = original_block_table
+    elif appender == "original":
+        appender_block_table = original_block_table
+        static_block_table = forked_block_table
+    else:
+        raise ValueError(f"unknown test config {appender=}")
+
+    # Write tokens.
+    appender_block_table.append_token_ids(token_ids_to_append)
+
+    # Expect the non-appending block table to have no change.
+    assert static_block_table.physical_block_ids == original_block_ids
+    assert appender_block_table.physical_block_ids != original_block_ids
+
+    cows = allocator.clear_copy_on_writes()
+
+    # Always expect copy-on-write
+    assert cows
+
+    if sequence_len % block_size > 0:
+        # If the last block in the sequence is not full, then when appending we
+        # expect a CoW.
+        assert cows
+
+        cow_block_id = sequence_len // block_size
+        expected_src = static_block_table.physical_block_ids[cow_block_id]
+        expected_dst = appender_block_table.physical_block_ids[cow_block_id]
+
+        assert (expected_src, expected_dst) in cows
+
+    static_block_table.free()
+    appender_block_table.free()
+
+    # After free, expect all blocks to be freed.
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+
+@pytest.mark.parametrize("block_size", [1, 8])
+@pytest.mark.parametrize("sequence_len", [1, 16, 129])
+@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
+@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
+                                            num_new_tokens: int,
+                                            num_lookahead_slots: int,
+                                            allocator_type: str):
+    """Verify correct calculation of get_num_blocks_touched_by_append_slots.
+
+    This is done by using copy-on-write, which requires any modified block to
+    be copied before write if the refcount > 1. We set the refcount>1 by forking
+    a sequence, then measure the free blocks before and after an append. If the
+    number of consumed blocks equals what
+    `get_num_blocks_touched_by_append_slots` returns, then the calculation is
+    correct.
+    """
+
+    num_gpu_blocks = 1024
+
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=0,
+        block_size=block_size,
+    )
+
+    token_ids = list(range(sequence_len))
+    token_ids_to_append = list(range(num_new_tokens))
+
+    block_table = BlockTable(
+        block_size=block_size,
+        block_allocator=allocator,
+    )
+
+    block_table.allocate(token_ids=token_ids, device=Device.GPU)
+
+    # Add lookahead before fork so both sequences have the same lookahead
+    # blocks.
+    block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)
+
+    # Fork sequence so that every block has refcount > 1.
+    _ = block_table.fork()
+
+    # Determine how many blocks should be touched.
+    expected_num_touched_blocks = (
+        block_table.get_num_blocks_touched_by_append_slots(
+            token_ids=token_ids_to_append,
+            num_lookahead_slots=num_lookahead_slots))
+
+    # Measure how many blocks are touched by measuring num_free_blocks before
+    # and after the append.
+    #
+    # We expect append_token_ids to CoW all mutated blocks that have refcount>1.
+    num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
+    block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
+    num_consumed_blocks = (num_free_blocks_before_append -
+                           allocator.get_num_free_blocks(Device.GPU))
+
+    # TODO(cade) ensure equality when num_lookahead_slots > 0.
+    # The reason we have < is because lookahead blocks are not copied eagerly;
+    # they are copied on first write. This will cause issues for beam search +
+    # speculative decoding. This is acceptable for now as it is a large effort
+    # to combine the two. To fix this, we can ensure single sequence ownership
+    # of lookahead blocks by appending empty slots to each block, which will
+    # trigger the CoW.
+    #
+    # Until then, we can accept that the consumed tokens are <= the expected
+    # tokens when appending with lookahead.
+    if num_lookahead_slots > 0:
+        assert num_consumed_blocks <= expected_num_touched_blocks
+    else:
+        assert num_consumed_blocks == expected_num_touched_blocks

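The expected block counts in the tests above are computed by chunking token ids with chunk_list; the same numbers fall out of a ceiling division, which is what cdiv computes. A tiny worked check (illustrative only):

from aphrodite.common.utils import cdiv, chunk_list

block_size = 8
for num_tokens in (1, 8, 9, 16, 129):
    via_chunks = len(list(chunk_list(list(range(num_tokens)), block_size)))
    assert via_chunks == cdiv(num_tokens, block_size)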
+ 42 - 0
tests/core/block/test_common.py

@@ -0,0 +1,42 @@
+import random
+
+import pytest
+
+from aphrodite.processing.block.common import RefCounter
+
+
+@pytest.mark.parametrize("seed", list(range(20)))
+@pytest.mark.parametrize("num_incrs", [1, 100])
+@pytest.mark.parametrize("num_blocks", [1024])
+def test_incr(seed: int, num_incrs: int, num_blocks: int):
+    random.seed(seed)
+
+    all_block_indices = list(range(num_blocks))
+    counter = RefCounter(all_block_indices=all_block_indices)
+
+    block_id = random.randint(0, num_blocks - 1)
+    for i in range(num_incrs):
+        value = counter.incr(block_id)
+        assert value == i + 1
+
+
+@pytest.mark.parametrize("seed", list(range(20)))
+@pytest.mark.parametrize("num_incrs", [1, 100])
+@pytest.mark.parametrize("num_blocks", [1024])
+def test_incr_decr(seed: int, num_incrs: int, num_blocks: int):
+    random.seed(seed)
+
+    all_block_indices = list(range(num_blocks))
+    counter = RefCounter(all_block_indices=all_block_indices)
+
+    block_id = random.randint(0, num_blocks - 1)
+    for i in range(num_incrs):
+        value = counter.incr(block_id)
+        assert value == i + 1
+
+    for i in range(num_incrs):
+        value = counter.decr(block_id)
+        assert value == num_incrs - (i + 1)
+
+    with pytest.raises(AssertionError):
+        counter.decr(block_id)

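For reference, the behaviour these two tests pin down can be captured by a counter like the sketch below. This is an illustrative stand-in only, not the aphrodite.processing.block.common.RefCounter implementation.

from typing import Dict, Iterable


class SimpleRefCounter:
    """Per-block reference counter with the semantics asserted above."""

    def __init__(self, all_block_indices: Iterable[int]):
        self._refcounts: Dict[int, int] = {i: 0 for i in all_block_indices}

    def incr(self, block_id: int) -> int:
        # Returns the new count (1 on the first increment).
        self._refcounts[block_id] += 1
        return self._refcounts[block_id]

    def decr(self, block_id: int) -> int:
        # Decrementing a zero count is a programming error.
        assert self._refcounts[block_id] > 0
        self._refcounts[block_id] -= 1
        return self._refcounts[block_id]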
+ 94 - 0
tests/core/block/test_cpu_gpu_block_allocator.py

@@ -0,0 +1,94 @@
+import pytest
+
+from aphrodite.common.utils import Device, chunk_list
+from aphrodite.processing.block.cpu_gpu_block_allocator import (
+    CpuGpuBlockAllocator)
+
+
+@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
+@pytest.mark.parametrize("num_gpu_blocks", [1024])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
+                                block_size: int, allocator_type: str):
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=num_cpu_blocks,
+        block_size=block_size,
+    )
+
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+    cpu_blocks = [
+        allocator.allocate_mutable_block(prev_block=None, device=Device.CPU)
+        for _ in range(num_cpu_blocks)
+    ]
+    assert allocator.get_num_free_blocks(Device.CPU) == 0
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+    gpu_blocks = [
+        allocator.allocate_mutable_block(prev_block=None, device=Device.GPU)
+        for _ in range(num_gpu_blocks)
+    ]
+    assert allocator.get_num_free_blocks(Device.CPU) == 0
+    assert allocator.get_num_free_blocks(Device.GPU) == 0
+
+    _ = [allocator.free(block) for block in cpu_blocks]
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == 0
+
+    _ = [allocator.free(block) for block in gpu_blocks]
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+
+@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
+@pytest.mark.parametrize("num_gpu_blocks", [1024])
+@pytest.mark.parametrize("block_size", [2])
+@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
+def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
+                                  block_size: int, allocator_type: str):
+    allocator = CpuGpuBlockAllocator.create(
+        allocator_type=allocator_type,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=num_cpu_blocks,
+        block_size=block_size,
+    )
+
+    unique_token_ids = list(
+        range((num_cpu_blocks + num_gpu_blocks) * block_size))
+    gpu_token_ids = list(
+        chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size))
+    cpu_token_ids = list(
+        chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size))
+
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+    cpu_blocks = [
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids,
+                                           device=Device.CPU)
+        for token_ids in cpu_token_ids
+    ]
+    assert allocator.get_num_free_blocks(Device.CPU) == 0
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
+
+    gpu_blocks = [
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids,
+                                           device=Device.GPU)
+        for token_ids in gpu_token_ids
+    ]
+    assert allocator.get_num_free_blocks(Device.CPU) == 0
+    assert allocator.get_num_free_blocks(Device.GPU) == 0
+
+    _ = [allocator.free(block) for block in cpu_blocks]
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == 0
+
+    _ = [allocator.free(block) for block in gpu_blocks]
+    assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
+    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks

+ 145 - 0
tests/core/block/test_naive_block.py

@@ -0,0 +1,145 @@
+from typing import List, Optional
+
+import pytest
+
+from aphrodite.processing.block.interfaces import Block, BlockAllocator
+from aphrodite.processing.block.naive_block import (NaiveBlock,
+                                                    NaiveBlockAllocator)
+
+
+class TestNaiveBlockAllocator:
+
+    @staticmethod
+    def create_allocate_lambda(allocate_type: str,
+                               allocator: NaiveBlockAllocator,
+                               prev_block: Optional[Block],
+                               token_ids: List[int]):
+        if allocate_type == "immutable":
+            allocate_block = lambda: allocator.allocate_immutable_block(
+                prev_block=prev_block, token_ids=token_ids)
+        elif allocate_type == "mutable":
+            allocate_block = lambda: allocator.allocate_mutable_block(
+                prev_block=prev_block)
+        else:
+            raise ValueError()
+
+        return allocate_block
+
+    @staticmethod
+    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_allocate_ooms(allocate_type: str, num_blocks: int,
+                           block_size: int):
+        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
+                                        num_blocks=num_blocks,
+                                        block_size=block_size)
+        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
+            allocate_type,
+            allocator,
+            prev_block=None,
+            token_ids=list(range(block_size)))
+
+        [allocate_block() for _ in range(num_blocks)]
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocate_block()
+
+    @staticmethod
+    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_free_prevents_oom(allocate_type: str, num_blocks: int,
+                               block_size: int):
+        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
+                                        num_blocks=num_blocks,
+                                        block_size=block_size)
+        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
+            allocate_type,
+            allocator,
+            prev_block=None,
+            token_ids=list(range(block_size)))
+
+        blocks = [allocate_block() for _ in range(num_blocks)]
+
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocate_block()
+
+        block_to_free = blocks.pop()
+
+        for _ in range(100):
+            block_id = block_to_free.block_id
+            allocator.free(block_to_free)
+            assert block_to_free.block_id is None
+
+            new_block = allocate_block()
+            assert new_block.block_id == block_id
+
+            with pytest.raises(BlockAllocator.NoFreeBlocksError):
+                allocate_block()
+
+            block_to_free = new_block
+
+    @staticmethod
+    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
+                                 block_size: int):
+        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
+                                        num_blocks=num_blocks,
+                                        block_size=block_size)
+        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
+            allocate_type,
+            allocator,
+            prev_block=None,
+            token_ids=list(range(block_size)))
+
+        assert allocator.get_num_free_blocks() == num_blocks
+
+        blocks = [allocate_block() for _ in range(num_blocks)]
+
+        for i, block in enumerate(blocks):
+            assert allocator.get_num_free_blocks() == i
+            allocator.free(block)
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [4])
+    @pytest.mark.parametrize("block_size", [8])
+    def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
+        """ Verify the allocator can correctly return the number of
+        blocks touched, with different lookahead slots.
+        """
+        allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
+                                            num_blocks=num_blocks,
+                                            block_size=block_size)
+        allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock,
+                                            num_blocks=num_blocks,
+                                            block_size=block_size)
+
+        # Create a chain of cacheable blocks in the dst
+        allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
+            "immutable",
+            allocator_src,
+            prev_block=None,
+            token_ids=list(range(block_size)))
+        src_blocks = [allocate_block() for _ in range(num_blocks - 1)]
+
+        # All blocks are cached
+        assert allocator_dst.get_num_blocks_touched(
+            src_blocks) == num_blocks - 1
+
+        # Insert one non-full block in the src
+        allocate_non_full_block = \
+            TestNaiveBlockAllocator.create_allocate_lambda(
+                "mutable", allocator_src,
+                prev_block=src_blocks[-1],token_ids=[]
+            )
+        src_blocks.append(allocate_non_full_block())
+        src_blocks[-1].append_token_ids([0])
+
+        assert allocator_dst.get_num_blocks_touched(
+            src_blocks, num_lookahead_slots=1) == num_blocks
+        assert allocator_dst.get_num_blocks_touched(
+            src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
+        assert allocator_dst.get_num_blocks_touched(
+            src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)

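The expected values in test_naive_block_get_num_blocks_touched follow from ceiling arithmetic over the partially filled last block. A small worked example (illustrative, reusing the same block_size=8 and num_blocks=4 parameters):

block_size = 8
num_full_src_blocks = 3   # the immutable blocks allocated in the test
tokens_in_last_block = 1  # the single token appended to the mutable block

for num_lookahead_slots in (1, block_size - 1, block_size):
    occupied = tokens_in_last_block + num_lookahead_slots
    extra_blocks = -(-occupied // block_size)  # ceiling division
    print(num_lookahead_slots, num_full_src_blocks + extra_blocks)
# Prints 1 -> 4, 7 -> 4, 8 -> 5, matching num_blocks and num_blocks + 1 above.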
+ 708 - 0
tests/core/block/test_prefix_caching_block.py

@@ -0,0 +1,708 @@
+import math
+import random
+from typing import List, Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from aphrodite.processing.block.interfaces import Block, BlockAllocator
+from aphrodite.processing.block.prefix_caching_block import (
+    PrefixCachingBlock, PrefixCachingBlockAllocator)
+
+
+class TestPrefixCachingBlock:
+
+    @staticmethod
+    @pytest.mark.parametrize("seed", list(range(10)))
+    @pytest.mark.parametrize("block_size", [1, 16])
+    @pytest.mark.parametrize("is_curr_block_full", [True, False])
+    def test_first_block_has_correct_content_hash(seed: int, block_size: int,
+                                                  is_curr_block_full: bool):
+        """Verify a block which is first in the sequence has the correct hash.
+        """
+        random.seed(seed)
+        num_to_fill = block_size if is_curr_block_full else random.randint(
+            0, block_size - 1)
+        token_ids = list(range(num_to_fill))
+        mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
+
+        block_with_prev = PrefixCachingBlock(prev_block=None,
+                                             token_ids=token_ids,
+                                             block_size=block_size,
+                                             allocator=mock_allocator)
+
+        if is_curr_block_full:
+            # Expect hash since block is full.
+            assert block_with_prev.content_hash == (
+                PrefixCachingBlock.hash_block_tokens(
+                    is_first_block=True,
+                    prev_block_hash=None,
+                    cur_block_token_ids=token_ids))
+        else:
+            # Do not expect hash since block is not full.
+            assert block_with_prev.content_hash is None
+
+    @staticmethod
+    @pytest.mark.parametrize("seed", list(range(10)))
+    @pytest.mark.parametrize("block_size", [1, 16])
+    @pytest.mark.parametrize("is_curr_block_full", [True, False])
+    @pytest.mark.parametrize("prev_block_has_hash", [True, False])
+    def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
+                                                is_curr_block_full: bool,
+                                                prev_block_has_hash: bool):
+        """Verify a block which is not first in the sequence has the correct
+        hash.
+        """
+
+        random.seed(seed)
+
+        previous_block = MagicMock(spec=PrefixCachingBlock)
+        prev_block_hash = random.randint(0, 1000)
+        previous_block.content_hash = (prev_block_hash
+                                       if prev_block_has_hash else None)
+
+        num_to_fill = block_size if is_curr_block_full else random.randint(
+            0, block_size - 1)
+        token_ids = list(range(num_to_fill))
+        mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
+
+        block_with_prev = PrefixCachingBlock(
+            prev_block=previous_block,
+            token_ids=token_ids,
+            block_size=block_size,
+            allocator=mock_allocator,
+        )
+
+        if is_curr_block_full and prev_block_has_hash:
+            # Expect hash since block is full and previous block has hash.
+            assert (block_with_prev.content_hash ==
+                    PrefixCachingBlock.hash_block_tokens(
+                        is_first_block=False,
+                        prev_block_hash=prev_block_hash,
+                        cur_block_token_ids=token_ids))
+        else:
+            # Do not expect hash since block is not full or the previous block
+            # does not have a hash.
+            assert block_with_prev.content_hash is None
+
+    @staticmethod
+    @pytest.mark.parametrize("block_size", [1, 2, 16])
+    @pytest.mark.parametrize("num_tokens", list(range(3)))
+    @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10])
+    def test_blocks_have_correct_hash_in_chain(block_size: int,
+                                               num_tokens: int,
+                                               num_empty_trailing_blocks: int):
+        """Create two chains of logical blocks with the same contents.
+        Assert the hashes are equal.
+        """
+        random.seed(0)
+
+        token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]
+
+        first_chain, second_chain = [
+            TestPrefixCachingBlock.create_chain(
+                block_size=block_size,
+                token_ids=token_ids,
+                num_empty_trailing_blocks=num_empty_trailing_blocks)
+            for _ in range(2)
+        ]
+
+        for first_chain_block, second_chain_block in zip(
+                first_chain, second_chain):
+            assert (first_chain_block.content_hash ==
+                    second_chain_block.content_hash)
+
+        if not first_chain or not second_chain:
+            assert first_chain == second_chain
+            assert num_tokens == 0
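+
+        # The comparison above only relies on content_hash being determined
+        # solely by (is_first_block, prev_block_hash, cur_block_token_ids),
+        # so two chains built from the same token ids hash identically block
+        # by block. Illustrative (not executed) chaining of the classmethod
+        # used in the tests above:
+        #
+        #   h0 = PrefixCachingBlock.hash_block_tokens(
+        #       is_first_block=True, prev_block_hash=None,
+        #       cur_block_token_ids=token_ids[:block_size])
+        #   h1 = PrefixCachingBlock.hash_block_tokens(
+        #       is_first_block=False, prev_block_hash=h0,
+        #       cur_block_token_ids=token_ids[block_size:2 * block_size])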
+
+    @staticmethod
+    def create_chain(block_size: int,
+                     token_ids: List[int],
+                     num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
+        """Helper method which creates a chain of blocks.
+        """
+        blocks: List[PrefixCachingBlock] = []
+        num_blocks = math.ceil(
+            len(token_ids) / block_size) + num_empty_trailing_blocks
+
+        if num_blocks == 0:
+            return []
+
+        allocator = MagicMock(spec=PrefixCachingBlockAllocator)
+
+        prev_block = None
+        for block_number in range(0, num_blocks):
+            prev_block = PrefixCachingBlock(
+                prev_block=prev_block,
+                token_ids=[],
+                block_size=block_size,
+                allocator=allocator,
+            )
+
+            tokens_to_append = token_ids[block_number *
+                                         block_size:(block_number + 1) *
+                                         block_size]
+            if tokens_to_append:
+                prev_block.append_token_ids(tokens_to_append)
+
+            blocks.append(prev_block)
+
+        return blocks
+
+
+class TestPrefixCachingBlockAllocator:
+
+    @staticmethod
+    def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
+                               prev_block: Optional[Block],
+                               token_ids: List[int]):
+        if allocate_type == "immutable":
+            allocate_block = lambda: allocator.allocate_immutable_block(
+                prev_block=prev_block, token_ids=token_ids)
+        elif allocate_type == "mutable":
+            allocate_block = lambda: allocator.allocate_mutable_block(
+                prev_block=prev_block)
+        else:
+            raise ValueError(f"Unknown allocate_type: {allocate_type}")
+
+        return allocate_block
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_allocate_mutable_ooms(num_blocks: int, block_size: int):
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
+            allocate_type="mutable",
+            allocator=allocator,
+            prev_block=None,
+            token_ids=list(range(block_size)),
+        )
+
+        [allocate_block() for _ in range(num_blocks)]
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocate_block()
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_allocate_immutable_does_not_oom_single_hash(
+            num_blocks: int, block_size: int):
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
+            allocate_type="immutable",
+            allocator=allocator,
+            prev_block=None,
+            token_ids=list(range(block_size)),
+        )
+
+        blocks = [allocate_block() for _ in range(num_blocks)]
+
+        # Expect no OOM. If these were mutable blocks, this would OOM.
+        non_oom_block = allocate_block()
+
+        # Expect all blocks to have same physical block index.
+        for block in blocks:
+            assert (block.block_id == non_oom_block.block_id)
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_allocate_immutable_ooms_many_hash(num_blocks: int,
+                                               block_size: int):
+        """Consume all blocks using many different hashes/block content.
+
+        Do this by creating a sequence that is very long.
+        Expect next block to OOM.
+        """
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks * block_size))
+
+        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Expect allocation with unseen hash to fail.
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocator.allocate_immutable_block(prev_block=chain[-1],
+                                               token_ids=list(
+                                                   range(block_size)))
+
+        # Expect mutable allocation to fail.
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocator.allocate_mutable_block(prev_block=chain[-1])
+
+        # Expect allocation of exact same chain to pass.
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Expect physical block indices to be the same in both chains.
+        assert chain and second_chain
+        for first_chain_block, second_chain_block in zip(chain, second_chain):
+            assert (first_chain_block.block_id == second_chain_block.block_id)
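+
+        # Sketch of why the second chain does not OOM: each allocation in
+        # the second chain has the same content hash as the corresponding
+        # block of the first chain, so it is expected to reuse the existing
+        # physical block and bump its refcount rather than drawing from the
+        # empty free pool.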
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1, 1024])
+    @pytest.mark.parametrize("block_size", [1, 16])
+    def test_free_prevents_oom(num_blocks: int, block_size: int):
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks * block_size))
+
+        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Expect mutable allocation to fail.
+        with pytest.raises(BlockAllocator.NoFreeBlocksError):
+            allocator.allocate_mutable_block(prev_block=None)
+
+        block_to_free = chain[-1]
+
+        # Expect free/allocate loop to succeed many times.
+        for i in range(100):
+            block_id = block_to_free.block_id
+            allocator.free(block_to_free)
+            assert block_to_free.block_id is None, i
+
+            new_block = allocator.allocate_mutable_block(prev_block=None)
+            assert new_block.block_id == block_id, i
+
+            with pytest.raises(BlockAllocator.NoFreeBlocksError):
+                allocator.allocate_mutable_block(prev_block=None)
+
+            block_to_free = new_block
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = random.randint(1, num_blocks - 1)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Free each block in the chain, asserting that the number of free
+        # blocks includes each newly freed block.
+        for i, block in enumerate(chain):
+            assert allocator.get_num_free_blocks() == (num_blocks -
+                                                       num_blocks_to_consume +
+                                                       i)
+            allocator.free(block)
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [4])
+    @pytest.mark.parametrize("block_size", [8])
+    def test_prefix_caching_block_get_num_blocks_touched(
+            num_blocks, block_size):
+        """ Verify the allocator can correctly return the number of
+        blocks touched, when there are cached prefixes and different
+        lookahead slots.
+        """
+        allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                    block_size=block_size)
+        allocator_dst = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                    block_size=block_size)
+
+        # Create token ids that will exhaust all blocks except the last
+        token_ids = list(range((num_blocks - 1) * block_size))
+
+        # Create a chain of cacheable blocks in the dst
+        cached_blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator_dst,
+        )
+
+        # Create a chain of the same blocks in the src
+        blocks_to_swap_in = \
+            TestPrefixCachingBlockAllocator.create_immutable_chain(
+                block_size=block_size,
+                token_ids=token_ids,
+                allocator=allocator_src,
+            )
+
+        # All blocks are cached
+        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0
+
+        # Free the first block in the dst
+        allocator_dst.free(cached_blocks[0])
+
+        # Now the first block is dangling, so the swapped-in blocks need
+        # to reclaim it in the dst
+        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1
+
+        # Insert one non-full block in the src
+        non_full_block = allocator_src.allocate_mutable_block(
+            blocks_to_swap_in[-1])
+        non_full_block.append_token_ids([0])
+        blocks_to_swap_in.append(non_full_block)
+        assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in,
+                                                    num_lookahead_slots=1) == 2
+        assert allocator_dst.get_num_blocks_touched(
+            blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2
+        assert allocator_dst.get_num_blocks_touched(
+            blocks_to_swap_in, num_lookahead_slots=block_size) == 3
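+
+        # Worked accounting for the assertions above (num_blocks=4,
+        # block_size=8): the dst already caches the full blocks, but one was
+        # freed, so 1 dst block must be reclaimed. The non-full src block
+        # holds 1 token; with up to block_size - 1 lookahead slots it fits
+        # in a single extra dst block (2 touched in total), while block_size
+        # lookahead slots spill into a second extra block (3 in total).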
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
+                                        seed: int):
+        """Verify sharing occurs by allocating two sequences that share prefixes
+        and incrementally freeing blocks.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = random.randint(1, num_blocks - 1)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Free each block in the first chain. Since all blocks are shared, the
+        # free count should stay constant.
+        for i, block in enumerate(first_chain):
+            assert allocator.get_num_free_blocks() == (num_blocks -
+                                                       num_blocks_to_consume)
+            allocator.free(block)
+
+        # Free each block in the second chain. Since the refcount is now zero,
+        # the free count should increment with each free.
+        for i, block in enumerate(second_chain):
+            assert allocator.get_num_free_blocks() == (num_blocks -
+                                                       num_blocks_to_consume +
+                                                       i)
+            allocator.free(block)
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
+                                           seed: int):
+        """Verify get_common_computed_block_ids could get correct result
+        by create two immutable chain sharing prefix at specified pos,
+        and compare whether we also could get right result
+        from get_common_computed_block_ids.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2,
+                                                block_size=block_size)
+        num_blocks_to_consume = random.randint(1, num_blocks - 1)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # After zero_point, second_chain's token ids are set to -1, which
+        # makes the second chain diverge from the first chain.
+        zero_point = random.randint(1, len(token_ids) - 1)
+        zero_point_blocks = zero_point // block_size
+        token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point)
+
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        first_computed_ids = [
+            first_chain[i].block_id for i in range(num_blocks_to_consume)
+        ]
+        second_computed_ids = [
+            second_chain[i].block_id for i in range(num_blocks_to_consume)
+        ]
+        res = allocator.get_common_computed_block_ids(
+            [first_computed_ids, second_computed_ids])
+
+        assert (len(res) == zero_point_blocks)
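+
+        # Only whole blocks strictly before zero_point are identical in both
+        # chains, so the common computed ids are expected to cover exactly
+        # zero_point // block_size blocks.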
+
+    # Test case asserting that, on promotion, a block whose content matches
+    # the first immutable block is released back into the hashless allocator
+    # while the first immutable block's refcount is incremented.
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [3])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(10)))
+    def test_alloc_promotion(num_blocks: int, block_size: int, seed: int):
+        random.seed(seed)
+
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        token_ids = list(range(block_size))
+
+        block = allocator.allocate_immutable_block(prev_block=None,
+                                                   token_ids=token_ids)
+
+        assert allocator._refcounter.get(block.block_id) == 1
+        m = allocator.allocate_mutable_block(prev_block=None)
+
+        block_id = m.block_id
+        for i in range(block_size):
+            m.append_token_ids([i])
+
+        # After the mutable block is promoted to immutable, a block with the
+        # same content hash already exists, so the promoted block's physical
+        # block is released into the hashless allocator and the first
+        # immutable block's refcount is increased by 1.
+        assert m.block_id == block.block_id
+        assert block_id in allocator._hashless_allocator._free_block_indices
+        assert allocator._refcounter.get(block.block_id) == 2
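+
+        # In short, promotion dedupes on content hash: filling the mutable
+        # block one token at a time reproduces the first block's contents,
+        # so on promotion it is expected to adopt the cached block's id,
+        # return its own physical block (block_id above) to the hashless
+        # free pool, and leave the cached block with a refcount of 2.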
+
+    # Test case mixing eviction and allocation to make sure they interact
+    # as expected.
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [3])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(10)))
+    def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int):
+        random.seed(seed)
+
+        all_blocks_list = [i for i in range(num_blocks)]
+        zero_ref = {i: 0 for i in range(num_blocks)}
+        one_ref = {i: 1 for i in range(num_blocks)}
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        token_ids = list(range(num_blocks * block_size))
+
+        # Verify initial/pre-alloc state
+
+        # Ensure all blocks are free inside hashless allocator
+        assert list(allocator._hashless_allocator._free_block_indices
+                    ) == all_blocks_list
+        # Ensure no blocks are actively tracked
+        assert len(allocator._block_tracker.keys()) == num_blocks
+        for block_id in range(num_blocks):
+            assert not allocator._block_tracker[block_id].active
+        # Ensure no cached blocks
+        assert len(allocator._cached_blocks.values()) == 0
+        # Ensure no evicted blocks
+        assert len(allocator.evictor.free_table.keys()) == 0
+        # Ensure 0s ref counts for all blocks
+        assert allocator._refcounter._refcounts == zero_ref
+
+        # Allocate immutable chains, each consisting of a single block
+        new_block = []
+        for i in range(num_blocks):
+            block = allocator.allocate_immutable_block(
+                prev_block=None,
+                token_ids=token_ids[block_size * i:block_size * (i + 1)])
+            new_block.append(block)
+
+        # Verify post-alloc state
+
+        # Ensure no blocks are free inside hashless allocator
+        assert (len(allocator._hashless_allocator._free_block_indices) == 0)
+        # Ensure all blocks are tracked
+        assert len(allocator._block_tracker.keys()) == num_blocks
+        for block_id in range(num_blocks):
+            assert allocator._block_tracker[block_id].active
+        # Ensure all blocks are cached (all promoted)
+        assert len(allocator._cached_blocks.values()) == num_blocks
+        # Ensure no evicted blocks
+        assert len(allocator.evictor.free_table.keys()) == 0
+        # Ensure 1s ref counts for all blocks
+        assert allocator._refcounter._refcounts == one_ref
+
+        # Free all blocks. Every block should now sit in the evictor, no
+        # block should be active in _block_tracker, all blocks should stay
+        # in _cached_blocks, and every refcount should be zero.
+        for block in new_block:
+            allocator.free(block)
+
+        # Verify post-free state
+
+        # Ensure no blocks are actively tracked
+        assert len(allocator._block_tracker.keys()) == num_blocks
+        for block_id in range(num_blocks):
+            assert not allocator._block_tracker[block_id].active
+        # Ensure no blocks in hashless allocator (all promoted)
+        assert len(allocator._hashless_allocator._free_block_indices) == 0
+        # Ensure all blocks are cached
+        assert list(allocator._cached_blocks.values()) == all_blocks_list
+        # Ensure all blocks are inside the evictor
+        assert list(allocator.evictor.free_table.keys()) == all_blocks_list
+        # Ensure 0s refcounts
+        assert allocator._refcounter._refcounts == zero_ref
+
+        # Allocate a mutable block. The first block should be evicted, its
+        # content hash reset to None, and its refcount set to 1.
+        mutable = allocator.allocate_mutable_block(prev_block=None)
+
+        assert mutable.block_id == 0
+        assert mutable.content_hash is None
+        assert allocator._block_tracker[0].active
+        assert allocator._refcounter.get(0) == 1
+        assert 0 not in allocator._cached_blocks
+        assert 0 not in allocator.evictor
+
+        # Since this mutable block has no hash yet, it should be released
+        # into the hashless allocator.
+        allocator.free(mutable)
+
+        assert not allocator._block_tracker[0].active
+        assert allocator._refcounter._refcounts == zero_ref
+        assert 0 not in allocator._cached_blocks
+        assert 0 not in allocator.evictor
+        assert 0 in allocator._hashless_allocator._free_block_indices
+
+        # When allocating an immutable block with the first block_size
+        # tokens, the free block comes from the hashless allocator, leaving
+        # no blocks there.
+        block = allocator.allocate_immutable_block(
+            prev_block=None, token_ids=token_ids[:block_size])
+
+        assert block.block_id == 0
+        assert len(allocator._hashless_allocator._free_block_indices) == 0
+        assert allocator._block_tracker[0].active
+        assert 0 in allocator._cached_blocks.values()
+        assert allocator._refcounter.get(0) == 1
+        assert 0 not in allocator.evictor
+
+        # Allocate a mutable block again; it is popped from the evictor
+        mutable = allocator.allocate_mutable_block(prev_block=None)
+        assert len(allocator._hashless_allocator._free_block_indices) == 0
+        assert mutable.block_id not in allocator.evictor.free_table
+        assert allocator._refcounter.get(mutable.block_id) == 1
+
+    # Test case where two last accessed times are equal
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_eviction_order(num_blocks: int, block_size: int, seed: int):
+        """This test case simulate the two chain created and free in order,
+        and together they would exhaust the initial freed blocks.
+
+        So the next block created after those two chain shall use the block
+        from the first chain as that block has long access time.
+        While first chain has two blocks, it shall pick up the last one, as
+        it has larger token number.
+        """
+
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = num_blocks + 1
+
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        num_blocks_in_first_chain = 2
+        num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
+        # First chain takes the first block
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[:num_tokens_in_first_chain],
+            allocator=allocator,
+        )
+        # Only the first chain's blocks should be allocated at this point
+        assert allocator.get_num_free_blocks() == (num_blocks -
+                                                   num_blocks_in_first_chain)
+
+        # Set the last accessed time of the first chain's blocks to 1
+        blocks_ids = [block.block_id for block in first_chain]
+        allocator.mark_blocks_as_accessed(blocks_ids, 1)
+
+        # Second chain takes the rest of the blocks
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[num_tokens_in_first_chain:-block_size],
+            allocator=allocator,
+        )
+
+        # There shouldn't be any free blocks left at this point
+        assert allocator.get_num_free_blocks() == 0
+
+        assert len(first_chain) == num_blocks_in_first_chain
+        last_block_id = first_chain[-1].block_id
+        # Free each block in the first chain.
+        for i, block in enumerate(first_chain):
+            allocator.free(block)
+
+        # Set the last accessed time on all of the blocks in the second chain
+        # to 2
+        blocks_ids = [block.block_id for block in second_chain]
+        allocator.mark_blocks_as_accessed(blocks_ids, 2)
+
+        # Free each block in the second chain.
+        for i, block in enumerate(second_chain):
+            allocator.free(block)
+
+        # Allocate a new block and check that it reuses the evicted block
+        # from the first chain, which has the older access time.
+        new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[-block_size:],
+            allocator=allocator,
+        )
+
+        assert new_block[0].block_id == last_block_id
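+
+        # Eviction-order sketch behind the final assertion: the first
+        # chain's blocks carry access time 1 and the second chain's carry
+        # access time 2, so the evictor is expected to prefer the older
+        # time. Between the two equally old blocks of the first chain, the
+        # one whose hash covers more tokens (the last block) is chosen,
+        # hence last_block_id.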
+
+    @staticmethod
+    def create_immutable_chain(
+        block_size: int,
+        token_ids: List[int],
+        allocator: PrefixCachingBlockAllocator,
+    ) -> List[PrefixCachingBlock]:
+        """Helper method which creates a chain of blocks.
+        """
+        blocks: List[Block] = []
+        num_blocks = math.ceil(len(token_ids) / block_size)
+
+        if num_blocks == 0:
+            return []
+
+        prev_block = None
+        for block_number in range(0, num_blocks):
+            block_token_ids = token_ids[block_number *
+                                        block_size:(block_number + 1) *
+                                        block_size]
+            prev_block = allocator.allocate_immutable_block(
+                prev_block=prev_block, token_ids=block_token_ids)
+            blocks.append(prev_block)
+
+        return blocks

+ 598 - 0
tests/core/test_block_manager.py

@@ -0,0 +1,598 @@
+import time
+from collections import defaultdict
+from typing import List
+
+import pytest
+
+from aphrodite import SamplingParams
+from aphrodite.common.block import PhysicalTokenBlock
+from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
+                                       SequenceStatus)
+from aphrodite.common.utils import Device
+from aphrodite.processing.block.utils import (
+    STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA)
+from aphrodite.processing.block_manager_v1 import (BlockSpaceManagerV1,
+                                                   UncachedBlockAllocator)
+from aphrodite.processing.interfaces import AllocStatus
+
+from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder
+
+
+def test_block_allocator_allocate():
+    block_size = 4
+    num_cpu_blocks = 4
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)
+
+    # Allocate all available cpu blocks.
+    num_free = num_cpu_blocks
+    assert cpu_allocator.get_num_free_blocks() == num_free
+    for _ in range(num_cpu_blocks):
+        block = cpu_allocator.allocate()
+        num_free -= 1
+
+        assert block not in cpu_allocator.free_blocks
+        assert cpu_allocator.get_num_free_blocks() == num_free
+
+    with pytest.raises(ValueError):
+        cpu_allocator.allocate()
+
+
+def test_block_allocator_free():
+    block_size = 4
+    num_cpu_blocks = 4
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)
+
+    # Allocate all available cpu blocks.
+    blocks: List[PhysicalTokenBlock] = []
+    for _ in range(num_cpu_blocks):
+        block = cpu_allocator.allocate()
+        blocks.append(block)
+        assert block not in cpu_allocator.free_blocks
+
+    # Free all allocated cpu blocks.
+    num_free = 0
+    assert cpu_allocator.get_num_free_blocks() == num_free
+    for block in blocks:
+        cpu_allocator.free(block)
+        num_free += 1
+        assert block in cpu_allocator.free_blocks
+        assert cpu_allocator.get_num_free_blocks() == num_free
+
+        with pytest.raises(ValueError):
+            cpu_allocator.free(block)
+
+
+def test_allocate():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same sequence group to all available gpu blocks.
+    for i in range(num_gpu_blocks):
+        _, seq_group = create_dummy_prompt(str(i), block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+    # Allocate same sequence group to all available gpu blocks.
+    # Use watermark to reserve one gpu block.
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=1 / num_gpu_blocks)
+    for i in range(num_gpu_blocks - 1):
+        _, seq_group = create_dummy_prompt(str(i), block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+
+def test_allocate_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same sequence group to all available gpu blocks.
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+    # Allocate same sequence group to all available gpu blocks.
+    # Use watermark to reserve one gpu block.
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=1 / num_gpu_blocks)
+    for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+
+def test_allocate_encoder_decoder_fails_with_swa():
+    # SWA short for sliding window attention
+
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        sliding_window=5)  # swa
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+    # Assert that allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+
+def test_allocate_encoder_decoder_fails_with_prefix_caching():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)  # Prefix cache
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+    # Assert that allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+
+def test_append_slot_single_seq():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate single seq to gpu block.
+    prompt, seq_group = create_dummy_prompt("1", block_size)
+    block_manager.allocate(seq_group)
+
+    # Nothing to append. Sequence has no new logical blocks.
+    assert block_manager.can_append_slots(seq_group)
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    assert not block_manager.append_slots(prompt)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_blocks == after_blocks
+
+    # Add block_size number of new tokens and append slot.
+    for i in range(block_size):
+        token_id = i + 5
+        prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    assert block_manager.can_append_slots(seq_group)
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    assert not block_manager.append_slots(prompt)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_blocks - after_blocks == 1
+
+
+def test_append_slot_cow():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size=block_size,
+                                        num_cpu_blocks=num_cpu_blocks,
+                                        num_gpu_blocks=num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate prompt to gpu block. There is one slot left in the block.
+    prompt = Sequence(seq_id=1,
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [1, 2, 3],
+                      },
+                      block_size=block_size)
+
+    # Fork the sequence, such that a COW will be required when we append a new
+    # token id.
+    child = prompt.fork(new_seq_id=2)
+
+    # Allocate space for the sequence group.
+    seq_group = SequenceGroup(request_id="1",
+                              seqs=[prompt, child],
+                              arrival_time=time.time(),
+                              sampling_params=SamplingParams())
+    block_manager.allocate(seq_group)
+
+    # Fork and append a new token id. We expect a COW to be scheduled.
+    token_id = 4
+    child.append_token_id(token_id, {token_id: Logprob(0.0)})
+    block_manager.fork(prompt, child)
+
+    assert block_manager.can_append_slots(seq_group)
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+
+    cows = block_manager.append_slots(child)
+    assert cows
+    dict_cows = defaultdict(list)
+    for src_block, dst_block in cows:
+        dict_cows[src_block].append(dst_block)
+    for src_block, dst_blocks in dict_cows.items():
+        assert src_block not in dst_blocks
+
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_blocks - after_blocks == 1
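+
+    # Copy-on-write accounting: parent and child share the single prompt
+    # block, so appending a token to the child is expected to allocate one
+    # fresh block and schedule one (src, dst) copy pair, which is why the
+    # free-block count drops by exactly one.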
+
+
+def test_fork():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    prompt, seq_group = create_dummy_prompt("1",
+                                            block_size - 1,
+                                            block_size=block_size)
+    block_manager.allocate(seq_group)
+
+    # Fork prompt and copy block tables.
+    child = prompt.fork(2)
+    block_manager.fork(prompt, child)
+    assert block_manager.get_block_table(
+        prompt) == block_manager.get_block_table(child)
+    token_id = 4
+    # Append token to child. Block is shared so copy on write occurs.
+    child.append_token_id(token_id, {token_id: Logprob(0.0)})
+    block_manager.append_slots(child)
+    assert block_manager.get_block_table(
+        prompt) != block_manager.get_block_table(child)
+
+
+def test_swap():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
+    prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    prompt.status = SequenceStatus.RUNNING
+    prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap seq group from GPU -> CPU.
+    gpu_blocks = block_manager.get_block_table(prompt)
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    assert [x[0] for x in mapping] == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    prompt.status = SequenceStatus.SWAPPED
+
+    # Swap seq group from CPU -> GPU.
+    cpu_blocks = block_manager.get_block_table(prompt)
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    assert [x[0] for x in mapping] == cpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
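+
+    # Swap accounting: swapping out moves every GPU block of the sequence
+    # to a CPU block, so the CPU free count drops and the GPU free count
+    # rises by len(gpu_blocks); swapping back in mirrors the same
+    # bookkeeping in the other direction.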
+
+
+def test_swap_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+        "1",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+    decoder_prompt.status = SequenceStatus.WAITING
+    encoder_prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    decoder_prompt.status = SequenceStatus.RUNNING
+    decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap encoder/decoder seq group from GPU -> CPU.
+    decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
+    gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    assert [x[0] for x in mapping] == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    decoder_prompt.status = SequenceStatus.SWAPPED
+
+    # Swap encoder/decoder seq group from CPU -> GPU.
+    decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
+    cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    assert [x[0] for x in mapping] == cpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
+
+
+def test_free():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    prompt, seq_group = create_dummy_prompt("1", block_size)
+    block_manager.allocate(seq_group)
+
+    # Free allocated seq.
+    prompt_blocks = len(block_manager.get_block_table(prompt))
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    block_manager.free(prompt)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert after_blocks == before_blocks + prompt_blocks
+
+    # Block table for freed seq is deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(prompt)
+
+
+def test_free_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+        "1",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+    block_manager.allocate(seq_group)
+
+    # Free allocated seq.
+    decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
+    encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
+    prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    block_manager.free(decoder_prompt)
+    block_manager.free_cross(seq_group)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert after_blocks == before_blocks + prompt_blocks
+
+    # Block table for the freed decoder seq is deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(decoder_prompt)
+
+    # Block table for the freed encoder seq is deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(encoder_prompt)
+
+
+def test_reset():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same seq group on all available gpu blocks.
+    original_blocks = block_manager.get_num_free_gpu_blocks()
+    for i in range(num_gpu_blocks):
+        _, seq_group = create_dummy_prompt(str(i), block_size)
+        block_manager.allocate(seq_group)
+    assert block_manager.get_num_free_gpu_blocks() == 0
+
+    # Resetting block manager frees all allocated blocks.
+    block_manager.reset()
+    assert block_manager.get_num_free_gpu_blocks() == original_blocks
+
+
+def test_reset_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same seq group on all available gpu blocks.
+    original_blocks = block_manager.get_num_free_gpu_blocks()
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            f"{i}",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        block_manager.allocate(seq_group)
+    assert block_manager.get_num_free_gpu_blocks() == 0
+
+    # Resetting block manager frees all allocated blocks.
+    block_manager.reset()
+    assert block_manager.get_num_free_gpu_blocks() == original_blocks
+
+
+def test_sliding_window_multi_seq():
+    """
+    Tests that memory allocation and deallocation are handled
+    correctly with multiple sequences that exceed the sliding
+    window's capacity.
+    """
+    block_size = 1
+    num_cpu_blocks = 8
+    num_gpu_blocks = 8
+    sliding_window = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        sliding_window=sliding_window,
+                                        watermark=0)
+
+    assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
+
+    parent = Sequence(seq_id=1,
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [0, 1, 2],
+                      },
+                      block_size=block_size)
+    seq_group = SequenceGroup(request_id="1",
+                              seqs=[parent],
+                              arrival_time=time.time(),
+                              sampling_params=SamplingParams(),
+                              lora_request=None)
+    block_manager.allocate(seq_group)
+
+    # assert the number of blocks allocated is correct
+    # the parent seq has len 3, but since sliding_window is 2,
+    # we will use at most 2 blocks
+    assert block_manager.get_num_free_gpu_blocks(
+    ) == num_gpu_blocks - sliding_window
+
+    # Fork prompt and copy block tables.
+    child = parent.fork(2)
+    block_manager.fork(parent, child)
+
+    # assert the number of blocks allocated is correct
+    # forking does not increase memory consumption
+    assert block_manager.get_num_free_gpu_blocks(
+    ) == num_gpu_blocks - sliding_window
+
+    # assert both parent and child share all blocks
+    assert block_manager.get_block_table(
+        parent) == block_manager.get_block_table(child)
+
+    token_id = 4
+    # Append token to child. Block is shared so copy on write occurs.
+    child.append_token_id(token_id, {token_id: Logprob(0.0)})
+    block_manager.append_slots(child)
+
+    # assert the number of blocks allocated is correct
+    # we now use one more block. Each seq uses 2 blocks,
+    # but only one can be shared
+    assert block_manager.get_num_free_gpu_blocks(
+    ) == num_gpu_blocks - sliding_window - 1
+
+    token_id = 5
+    parent.append_token_id(token_id, {token_id: Logprob(0.0)})
+    block_manager.append_slots(parent)
+
+    # assert the number of blocks allocated is correct
+    # no change, because both sequences are still just sharing one block
+    assert block_manager.get_num_free_gpu_blocks(
+    ) == num_gpu_blocks - sliding_window - 1
+
+    block_table_parent = block_manager.get_block_table(parent)
+    block_table_child = block_manager.get_block_table(child)
+
+    assert block_table_parent != block_table_child
+
+    # assert both blocks are sharing the second-last block
+    assert block_table_parent[-2] == block_table_child[-2]
+
+    # now let's clean up...
+    block_manager.free(parent)
+
+    # assert the number of blocks allocated is correct
+    # We have freed one seq, reducing the ref count of two blocks by one.
+    # One of the two was only used by the parent seq, so this is now free.
+    # The child seq still consumes sliding_window blocks
+    assert block_manager.get_num_free_gpu_blocks(
+    ) == num_gpu_blocks - sliding_window
+
+    # free all blocks
+    block_manager.free(child)
+
+    # assert all blocks are free now
+    assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

+ 583 - 0
tests/core/test_chunked_prefill_scheduler.py

@@ -0,0 +1,583 @@
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest  # noqa
+
+from aphrodite.common.config import CacheConfig, SchedulerConfig
+from aphrodite.common.sequence import Logprob, SequenceGroup
+from aphrodite.processing.interfaces import AllocStatus
+from aphrodite.processing.scheduler import Scheduler
+
+from .utils import create_dummy_prompt
+
+
+def get_sequence_groups(scheduler_output):
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
+def append_new_token(seq_group, token_id: int):
+    for seq in seq_group.get_seqs():
+        seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+def schedule_and_update_computed_tokens(scheduler):
+    metas, out = scheduler.schedule()
+    for s, meta in zip(out.scheduled_seq_groups, metas):
+        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+    return metas, out
+
+
+def test_simple():
+    """Verify basic scheduling works."""
+    block_size = 4
+    num_seq_group = 4
+    max_model_len = 16
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       num_seq_group,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Schedule seq groups prompts.
+    num_tokens = block_size * num_seq_group
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert out.num_batched_tokens == num_tokens
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == num_seq_group
+    for s in running:
+        append_new_token(s, 1)
+
+    # Schedule seq groups generation.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert out.num_batched_tokens == num_seq_group
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == num_seq_group
+
+
+def test_chunk():
+    """Verify prefills are chunked properly."""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Verify the second request is chunked.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 60
+    # Verify it is chunked.
+    assert seq_group_meta[1].token_chunk_size == 4
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+    # Only the first seq group has a new token appended.
+    append_new_token(running[0], 1)
+
+    # One chunked prefill, and one decoding.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # The first one is prefill. Scheduler guarantees ordering.
+    assert seq_group_meta[0].token_chunk_size == 56
+    # The second one is a decode.
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 57
+
+
+def test_complex():
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Verify the second request is chunked.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 60
+    # Verify it is chunked.
+    assert seq_group_meta[1].token_chunk_size == 4
+    assert not running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+    # Only the first seq group has a new token appended.
+    append_new_token(running[0], 1)
+
+    # Add 2 more requests.
+    for i in range(2, 4):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 3
+    # The first one is the first prefill chunk of the newly added request.
+    assert seq_group_meta[0].token_chunk_size == 7
+    # The second one is the continuation of the earlier chunked prefill.
+    assert seq_group_meta[1].token_chunk_size == 56
+    # The last one is a decode.
+    assert seq_group_meta[2].token_chunk_size == 1
+    # Two of them are in chunked prefill.
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+    # The first 2 requests are now in the decoding phase.
+    append_new_token(running[0], 1)
+    assert not running[0].is_prefill()
+    append_new_token(running[1], 1)
+    assert not running[1].is_prefill()
+    # The third request is still in prefill stage.
+    assert running[2].is_prefill()
+
+
+def test_maximal_decoding():
+    """Verify decoding requests are prioritized."""
+    block_size = 4
+    max_seqs = 2
+    max_model_len = 2
+    max_num_batched_tokens = 2
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=2)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # The first prefill is scheduled.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 2
+    assert not running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 2
+    # Only the first seq group has a new token appended.
+    append_new_token(running[0], 1)
+
+    # Create one more seq_group.
+    _, seq_group = create_dummy_prompt("3", prompt_length=2)
+    scheduler.add_seq_group(seq_group)
+    running.append(seq_group)
+    assert seq_group.is_prefill()
+    # The first request's decode and the second request's first prefill
+    # chunk are scheduled.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 2
+    assert seq_group_meta[0].token_chunk_size == 1
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert not running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert running[2].is_prefill()
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 2
+    append_new_token(running[0], 1)
+
+    # Decoding + running prefill is prioritized.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 2
+    assert seq_group_meta[0].token_chunk_size == 1
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert not running[0].is_prefill()
+    assert not running[1].is_prefill()
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 2
+    append_new_token(running[0], 1)
+    append_new_token(running[1], 1)
+
+    # Only decoding is prioritized.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 2
+    assert seq_group_meta[0].token_chunk_size == 1
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert not running[0].is_prefill()
+    assert not running[1].is_prefill()
+    assert out.num_prefill_groups == 0
+    assert out.num_batched_tokens == 2
+    append_new_token(running[0], 1)
+    append_new_token(running[1], 1)
+
+    # After aborting the decoding request, the FCFS new prefill is prioritized.
+    scheduler.abort_seq_group(running[0].request_id)
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 2
+    assert seq_group_meta[0].token_chunk_size == 1
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert not running[1].is_prefill()
+    assert running[2].is_prefill()
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 2
+
+
+def test_prompt_limit():
+    """Verify max_num_batched_tokens < max_model_len is possible."""
+    block_size = 4
+    max_seqs = 32
+    max_model_len = 64
+    max_num_batched_tokens = 32
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    _, seq_group = create_dummy_prompt("1", prompt_length=48)
+    scheduler.add_seq_group(seq_group)
+    running.append(seq_group)
+    assert seq_group.is_prefill()
+
+    # A prompt longer than max_num_batched_tokens should still be scheduled.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 32
+    assert running[0].is_prefill()
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 32
+
+
+def test_prompt_limit_exceed():
+    block_size = 4
+    max_seqs = 64
+    max_model_len = 32
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    _, seq_group = create_dummy_prompt("2", prompt_length=48)
+    scheduler.add_seq_group(seq_group)
+    running.append(seq_group)
+    assert seq_group.is_prefill()
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.ignored_seq_groups) == 1
+    assert out.ignored_seq_groups[0] == seq_group
+
+
+def test_swap():
+    """Verify swapping works with chunked prefill requests"""
+    block_size = 4
+    max_seqs = 30
+    max_model_len = 200
+    max_num_batched_tokens = 30
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    scheduler.add_seq_group(seq_group)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    # The request is chunked.
+    # Only the first prefill chunk is scheduled now.
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert seq_group.is_prefill()
+    assert out.num_batched_tokens == max_num_batched_tokens
+
+    # The request should be swapped out.
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "1"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+
+    # The running prefill is now swapped out.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 0
+    assert out.num_batched_tokens == 0
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
+
+    # Add 1 more task. Swap should be prioritized over new prefill.
+    _, seq_group = create_dummy_prompt("2", prompt_length=60)
+    scheduler.add_seq_group(seq_group)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    # The request is swapped back in and resumes its chunked prefill.
+    assert out.num_batched_tokens == 30
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
+
+
+def test_running_prefill_prioritized_over_swap():
+    block_size = 4
+    max_seqs = 30
+    max_model_len = 200
+    max_num_batched_tokens = 30
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    scheduler.add_seq_group(seq_group)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    # The request is chunked.
+    # Only the first prefill chunk is scheduled now.
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert seq_group.is_prefill()
+    assert out.num_batched_tokens == max_num_batched_tokens
+
+    # The request should be swapped out.
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "1"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+
+    # The running prefill is now swapped out.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 0
+    assert out.num_batched_tokens == 0
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
+
+    # Add 1 more task. Swap-in is not possible, so the new prefill runs.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
+
+    _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
+    scheduler.add_seq_group(seq_group2)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    # The new prefill is scheduled with a 30-token chunk; nothing is
+    # swapped in.
+    assert out.num_batched_tokens == 30
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
+    assert out.scheduled_seq_groups[0].seq_group == seq_group2
+
+    # Now although swap is possible, running prefill is prioritized.
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    # The running prefill continues with its remaining 30 tokens; no swap-in.
+    assert out.num_batched_tokens == 30
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
+    assert not seq_group2.is_prefill()
+    assert out.scheduled_seq_groups[0].seq_group == seq_group2
+    append_new_token(seq_group2, 1)
+
+    # Decoding is prioritized.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    # Only the decode of the running request is scheduled (1 token).
+    assert out.num_batched_tokens == 1
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
+    assert not seq_group2.is_prefill()
+    assert out.scheduled_seq_groups[0].seq_group == seq_group2
+    append_new_token(seq_group2, 1)
+
+    # Since we abort the sequence group, we can finally swap.
+    scheduler.abort_seq_group(seq_group2.request_id)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_batched_tokens == 30
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
+
+
+def test_chunked_prefill_preempt():
+    """Verify preempt works with chunked prefill requests"""
+    block_size = 4
+    max_seqs = 30
+    max_model_len = 200
+    max_num_batched_tokens = 30
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    _, seq_group = create_dummy_prompt("1", prompt_length=60)
+    scheduler.add_seq_group(seq_group)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    # The request is chunked.
+    # Only the first prefill chunk is scheduled now.
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert seq_group.is_prefill()
+    assert out.num_batched_tokens == max_num_batched_tokens
+
+    # The request should be preempted.
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group1(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "1"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group1)
+
+    # The running prefill is now preempted.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 0
+    assert out.num_batched_tokens == 0
+    assert out.blocks_to_swap_out == []
+    assert out.blocks_to_swap_in == []
+
+    # Make sure we can reschedule preempted request.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert seq_group.is_prefill()
+    assert out.num_batched_tokens == max_num_batched_tokens
+    assert seq_group.get_num_uncomputed_tokens() == 30
+
+    # We should be able to run prefill twice as it is chunked.
+    def cannot_append_second_group2(seq_group, num_lookahead_slots):
+        return True
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group2)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert not seq_group.is_prefill()
+    assert out.num_batched_tokens == max_num_batched_tokens
+
+
+def test_chunked_prefill_max_seqs():
+    block_size = 4
+    max_seqs = 2
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       is_attention_free=False)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto",
+                               is_attention_free=False)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    _, seq_group = create_dummy_prompt("1", prompt_length=65)
+    scheduler.add_seq_group(seq_group)
+    running.append(seq_group)
+    # The first prefill is chunked.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
+    assert len(get_sequence_groups(out)) == 1
+
+    # Add new requests.
+    for i in range(4):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=65)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Make sure only 2 requests are scheduled.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert out.num_batched_tokens == max_num_batched_tokens
+    assert len(get_sequence_groups(out)) == 2
+    assert not running[0].is_prefill()
+    assert running[1].is_prefill()
+    append_new_token(running[0], 1)
+
+    # Although we have enough token budget, we can only schedule max_seqs.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 2
+    assert seq_group_meta[1].token_chunk_size == 1
+    assert out.num_batched_tokens == 3
+    assert len(get_sequence_groups(out)) == max_seqs
+    assert not running[0].is_prefill()
+    assert not running[1].is_prefill()
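For readers skimming the chunked-prefill tests above, the token-budget split exercised in test_chunk works out as follows. This is a standalone arithmetic sketch for illustration, not code from the test suite.

    # Standalone sketch of the budget split asserted in test_chunk above.
    budget = 64
    first_prompt, second_prompt = 60, 60

    # Step 1: the first prompt fits whole; the second gets the leftover budget.
    first_chunk = min(first_prompt, budget)            # 60
    second_chunk = budget - first_chunk                # 4
    assert (first_chunk, second_chunk) == (60, 4)

    # Step 2: the first request decodes one token while the second request
    # continues its prefill with the remaining prompt tokens.
    remaining_second = second_prompt - second_chunk    # 56
    decode_tokens = 1
    assert remaining_second + decode_tokens == 57      # matches num_batched_tokens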

+ 852 - 0
tests/core/test_scheduler.py

@@ -0,0 +1,852 @@
+import time
+from collections import deque
+from typing import List, Set, Tuple
+from unittest.mock import MagicMock
+
+import pytest  # noqa
+
+from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
+from aphrodite.common.sequence import SequenceGroup, SequenceStatus
+from aphrodite.lora.request import LoRARequest
+from aphrodite.processing.interfaces import AllocStatus
+from aphrodite.processing.scheduler import Scheduler, SchedulingBudget
+
+from .utils import (append_new_token, append_new_token_seq_group,
+                    create_dummy_prompt, get_sequence_groups,
+                    schedule_and_update_computed_tokens)
+
+
+def test_scheduler_add_seq_group():
+    block_size = 4
+    scheduler_config = SchedulerConfig(100, 64, 1)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 4
+    cache_config.num_gpu_blocks = 4
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq group to scheduler.
+    num_seq_group = 4
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i), block_size)
+        scheduler.add_seq_group(seq_group)
+        assert scheduler.get_num_unfinished_seq_groups() == i + 1
+
+
+def test_scheduler_abort_seq_group():
+    block_size = 4
+    scheduler_config = SchedulerConfig(100, 64, 1)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 4
+    cache_config.num_gpu_blocks = 4
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add multiple seq groups to scheduler.
+    num_seq_group = 4
+    request_ids: Set[str] = set()
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i), block_size)
+        scheduler.add_seq_group(seq_group)
+        request_ids.add(str(i))
+
+    # Abort all added seq groups.
+    assert scheduler.get_num_unfinished_seq_groups() == num_seq_group
+    scheduler.abort_seq_group(request_ids)
+    assert scheduler.get_num_unfinished_seq_groups() == 0
+
+
+def test_scheduler_schedule_simple():
+    block_size = 4
+    num_seq_group = 4
+    max_model_len = 16
+    scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Schedule seq groups prompts.
+    num_tokens = block_size * num_seq_group
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert out.num_batched_tokens == num_tokens
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == num_seq_group
+    append_new_token(out, 1)
+
+    # Schedule seq groups generation.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert out.num_batched_tokens == num_seq_group
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == num_seq_group
+    append_new_token(out, 1)
+
+
+def test_scheduler_prefill_prioritized():
+    """Verify running batched tokens are not applied to prefill requests."""
+    block_size = 4
+    max_model_len = 30
+    max_batched_num_tokens = 30
+    scheduler_config = SchedulerConfig(max_batched_num_tokens, 2,
+                                       max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 2
+    cache_config.num_gpu_blocks = 2
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    _, seq_group_a = create_dummy_prompt("1", 1)
+    scheduler.add_seq_group(seq_group_a)
+
+    # Schedule seq groups prompts.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert get_sequence_groups(out) == [seq_group_a]
+
+    # Add a new prefill request B.
+    _, seq_group_b = create_dummy_prompt("2", 30)
+    scheduler.add_seq_group(seq_group_b)
+
+    # Verify prefill requests are prioritized: when a new prefill request is
+    # waiting, it is scheduled before the running decode.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert get_sequence_groups(out) == [seq_group_b]
+
+
+def test_scheduler_schedule_preempt_abort():
+    block_size = 4
+    max_model_len = 16
+    scheduler_config = SchedulerConfig(64, 2, max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 2
+    cache_config.num_gpu_blocks = 2
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    seq_a, seq_group_a = create_dummy_prompt("1", block_size)
+    seq_b, seq_group_b = create_dummy_prompt("2", block_size)
+    scheduler.add_seq_group(seq_group_a)
+    scheduler.add_seq_group(seq_group_b)
+
+    # Schedule seq groups prompts.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
+    assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 2
+    assert scheduler.get_num_unfinished_seq_groups() == 2
+
+    # Append "generated" tokens, allowing the sequence to mark prompt tokens as
+    # processed.
+    append_new_token(out, 1)
+
+    # Schedule seq groups generation and preempt seq group b.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert get_sequence_groups(out) == [seq_group_a]
+    assert out.num_batched_tokens == 1
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert scheduler.get_num_unfinished_seq_groups() == 2
+    assert out.preempted == 1
+
+    # Abort seq group a. Re-schedule seq group b prompt with recomputation.
+    scheduler.abort_seq_group("1")
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert get_sequence_groups(out) == [seq_group_b]
+    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert scheduler.get_num_unfinished_seq_groups() == 1
+
+
+def test_scheduler_max_seqs():
+    block_size = 4
+    num_seq_group = 4
+    max_seq_group = 2
+    max_model_len = 16
+    scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    all_seq_groups: List[SequenceGroup] = []
+    # Add seq groups to scheduler.
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
+        all_seq_groups.append(seq_group)
+
+    # Append 1 seq group
+    scheduler.add_seq_group(all_seq_groups[0])
+
+    # Schedule seq groups prompts.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
+    append_new_token(out, 1)
+
+    # Schedule seq groups generation.
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
+    append_new_token(out, 1)
+
+    # Append 2 more seq group
+    scheduler.add_seq_group(all_seq_groups[1])
+    scheduler.add_seq_group(all_seq_groups[2])
+
+    # Schedule seq groups prompts.
+    # Only 1 seq group should be scheduled since max_seq_group is 2
+    # and one is prompting.
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
+
+
+def test_scheduler_delay_factor():
+    block_size = 4
+    scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # schedule first prompt
+    seq_group_meta, seq_group = create_dummy_prompt("0",
+                                                    prompt_length=block_size)
+    scheduler.add_seq_group(seq_group)
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert out.num_prefill_groups > 0
+    assert seq_group_meta[0].request_id == '0'
+    append_new_token(out, 1)
+
+    # wait for a second before scheduling next prompt
+    time.sleep(1)
+    seq_group_meta, seq_group = create_dummy_prompt("1",
+                                                    prompt_length=block_size)
+    scheduler.add_seq_group(seq_group)
+
+    # second prompt should *not* be scheduled
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert out.num_prefill_groups == 0
+    assert seq_group_meta[0].request_id == '0'
+    append_new_token(out, 1)
+
+    # wait for more than 0.5 second and try again
+    time.sleep(0.6)
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert out.num_prefill_groups > 0
+    assert seq_group_meta[0].request_id == '1'
+    append_new_token(out, 1)
+
+
+def test_swapped_out_prioritized():
+    scheduler = initialize_scheduler(max_num_seqs=6)
+    # 3 requests with best_of=2 -> 6 sequences in total.
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
+        scheduler.add_seq_group(seq_group)
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # prefill scheduled now.
+    assert len(out.scheduled_seq_groups) == 3
+    append_new_token(out, 1)
+
+    # The last request should be swapped out.
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "2"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(out.scheduled_seq_groups) == 2
+    assert out.num_batched_tokens == 2
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
+    append_new_token(out, 1)
+
+    # Add 1 more task. Swap should be prioritized over prefill.
+    _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
+    scheduler.add_seq_group(seq_group)
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    append_new_token(out, 1)
+    assert len(out.scheduled_seq_groups) == 3
+    # 3 decodes. It is swapped in.
+    assert out.num_batched_tokens == 3
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
+
+
+def initialize_scheduler(*,
+                         max_num_seqs=1000,
+                         max_token_budget=1000,
+                         max_model_len=1000,
+                         lora_config=None):
+    block_size = 4
+    scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs,
+                                       max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, lora_config)
+    return scheduler
+
+
+def create_token_budget(token_budget: int = 10000,
+                        max_num_seqs: int = 10000) -> SchedulingBudget:
+    return SchedulingBudget(
+        token_budget=token_budget,
+        max_num_seqs=max_num_seqs,
+    )
+
+
+def add_token_budget(budget: SchedulingBudget,
+                     num_batched_tokens: int = 0,
+                     num_curr_seqs: int = 0):
+    mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1]
+    budget.add_num_batched_tokens(mock_seq_group.request_id,
+                                  num_batched_tokens)
+    budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs)
+
+
+def test_prefill_schedule_max_prompt_len():
+    """
+    Test that a prompt longer than max_model_len is ignored.
+    """
+    scheduler = initialize_scheduler(max_model_len=30)
+    _, seq_group = create_dummy_prompt("0", prompt_length=60)
+    scheduler.add_seq_group(seq_group)
+    budget = create_token_budget()
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 1
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(remaining_waiting) == 0
+
+
+def test_prefill_schedule_token_budget():
+    """
+    Test that the token budget is respected.
+    """
+    scheduler = initialize_scheduler()
+    budget = create_token_budget(token_budget=0)
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+
+    # 0 token budget == nothing is scheduled.
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(remaining_waiting) == 2
+
+    # 60 token budget == 1 request scheduled.
+    budget = create_token_budget(token_budget=60)
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 1
+    assert budget.num_batched_tokens == 60
+    assert budget.num_curr_seqs == 1
+    assert len(remaining_waiting) == 1
+
+    # Test that already-consumed batched tokens are respected.
+    scheduler = initialize_scheduler()
+    budget = create_token_budget(token_budget=60)
+    add_token_budget(budget, 30, 0)
+    _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+    # Cannot schedule a prompt that doesn't fit the budget.
+    scheduler.add_seq_group(seq_group)
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 30
+    assert budget.num_curr_seqs == 0
+    assert len(remaining_waiting) == 1
+    budget = create_token_budget(token_budget=90)
+    add_token_budget(budget, 30, 0)
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.seq_groups) == 1
+    assert budget.num_batched_tokens == 90
+    assert budget.num_curr_seqs == 1
+    assert len(remaining_waiting) == 0
+
+
+def test_prefill_schedule_max_seqs():
+    """
+    Test that the max number of sequences is respected.
+    """
+    scheduler = initialize_scheduler()
+    budget = create_token_budget(max_num_seqs=2)
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 2
+    assert budget.num_batched_tokens == 120
+    assert budget.num_curr_seqs == 2
+    assert len(remaining_waiting) == 1
+
+    # Verify that num_curr_seqs is respected.
+    scheduler.waiting = deque()
+    budget = create_token_budget(max_num_seqs=2)
+    add_token_budget(budget, 0, 2)
+    _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+    scheduler.add_seq_group(seq_group)
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 2
+    assert len(remaining_waiting) == 1
+
+
+def test_prefill_schedule_max_lora():
+    """
+    Test that max_loras is respected and LoRA requests are prioritized.
+    """
+    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
+    scheduler = initialize_scheduler(lora_config=lora_config)
+    budget = create_token_budget(token_budget=120)
+    curr_loras: Set[int] = set()
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           lora_request=LoRARequest(
+                                               lora_name=str(i),
+                                               lora_int_id=i + 1,
+                                               lora_path="abc"))
+        scheduler.add_seq_group(seq_group)
+    # Add two more requests to verify lora is prioritized.
+    # 0: Lora, 1: Lora, 2: regular, 3: regular
+    # In the first iteration, index 0, 2 is scheduled.
+    # If a request is not scheduled because it hits max lora, it is
+    # prioritized. Verify that.
+    for i in range(2, 4):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+    # Schedule 2 requests (0 and 2)
+    output = scheduler._schedule_prefills(budget, curr_loras)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 2
+    assert budget.num_batched_tokens == 120
+    assert budget.num_curr_seqs == 2
+    assert len(remaining_waiting) == 2
+    assert len(curr_loras) == 1
+    # The second lora request is scheduled next as FCFS policy.
+    # Reset curr_loras so that it can be scheduled.
+    curr_loras = set()
+    budget = create_token_budget(token_budget=60)
+    output = scheduler._schedule_prefills(budget, curr_loras)
+    remaining_waiting = scheduler.waiting
+    assert len(output.seq_groups) == 1
+    assert output.seq_groups[0].seq_group.request_id == "1"
+    assert len(remaining_waiting) == 1
+    assert len(curr_loras) == 1
+    assert budget.num_batched_tokens == 60
+
+
+def test_prefill_schedule_no_block_manager_capacity():
+    """
+    Test that a sequence cannot be scheduled when the block manager
+    has no capacity.
+    """
+    scheduler = initialize_scheduler()
+    budget = create_token_budget()
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+    scheduler.block_manager.can_allocate = MagicMock()
+    scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 0
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(remaining_waiting) == 3
+
+    scheduler = initialize_scheduler()
+    budget = create_token_budget()
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler.add_seq_group(seq_group)
+    scheduler.block_manager.can_allocate = MagicMock()
+    scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
+    output = scheduler._schedule_prefills(budget, None)
+    remaining_waiting = scheduler.waiting
+    assert len(output.ignored_seq_groups) == 3
+    assert len(output.seq_groups) == 0
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(remaining_waiting) == 0
+
+
+def test_decode_schedule_preempted():
+    """
+    Test that decodes which cannot be scheduled are preempted.
+    """
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._add_seq_group_to_running(seq_group)
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "1"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+
+    # Request 1 cannot be scheduled, so the lowest-priority request (2)
+    # is preempted; request 1 is preempted as well.
+    budget = create_token_budget()
+    output = scheduler._schedule_running(budget, curr_loras)
+    remaining_running = scheduler.running
+    assert len(remaining_running) == 0
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+    assert output.decode_seq_groups[0].seq_group.request_id == "0"
+    assert len(output.preempted) == 2
+    # Verify budgets are updated.
+    assert budget.num_batched_tokens == 1
+    # NOTE: When enable_chunk is False, num_seqs budget is not updated.
+    # assert budget.num_curr_seqs == 1
+    # Both should be preempted, not swapped.
+    assert output.blocks_to_swap_out == []
+    # Nothing is copied.
+    assert output.blocks_to_copy == []
+
+
+def test_decode_swap_beam_search():
+    """
+    Test that requests with best_of > 1 swap out blocks.
+    """
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    budget = create_token_budget()
+    for i in range(3):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        scheduler._add_seq_group_to_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        budget.add_num_seqs(seq_group.request_id,
+                            seq_group.get_max_num_running_seqs())
+        budget.add_num_batched_tokens(
+            seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))
+
+    # The last request should be swapped out.
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "2"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+    scheduler.block_manager.swap_out = MagicMock()
+    expected_swap_mapping = [("5", "7")]
+    scheduler.block_manager.swap_out.return_value = expected_swap_mapping
+
+    output = scheduler._schedule_running(budget, curr_loras)
+    remaining_running = scheduler.running
+    assert len(remaining_running) == 0
+    assert len(output.decode_seq_groups) == 2
+    assert len(output.prefill_seq_groups) == 0
+    assert output.decode_seq_groups[0].seq_group.request_id == "0"
+    assert output.decode_seq_groups[1].seq_group.request_id == "1"
+    assert len(output.preempted) == 0
+    assert len(output.swapped_out) == 1
+    # Budget should reflect the swapped-out request.
+    assert budget.num_batched_tokens == 2
+    # Since the swapped-out group has 2 sequences, 2 are subtracted.
+    assert budget.num_curr_seqs == 4
+    # The swapped-out blocks are recorded in the swap mapping.
+    assert output.blocks_to_swap_out == expected_swap_mapping
+    # Nothing is copied.
+    assert output.blocks_to_copy == []
+
+
+def test_schedule_decode_blocks_to_copy_update():
+    """
+    Verify blocks_to_copy is updated.
+    """
+    scheduler = initialize_scheduler()
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    curr_loras = None
+    scheduler._allocate_and_set_running(seq_group)
+    append_new_token_seq_group(60, seq_group, 1)
+    scheduler._add_seq_group_to_running(seq_group)
+
+    # Mock append_slots to return a block copy (source, dest) mapping.
+    scheduler.block_manager.append_slots = MagicMock()
+    scheduler.block_manager.append_slots.return_value = [(2, 3)]
+
+    budget = create_token_budget()
+    output = scheduler._schedule_running(budget, curr_loras)
+    remaining_running = scheduler.running
+    assert len(remaining_running) == 0
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+    assert len(output.preempted) == 0
+    assert len(output.swapped_out) == 0
+    # Nothing is preempted.
+    assert output.blocks_to_swap_out == []
+    # Since append_slots returns the source -> dest mapping, it should
+    # be applied.
+    assert output.blocks_to_copy == [(2, 3)]
+
+
+def test_schedule_swapped_simple():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    scheduler._allocate_and_set_running(seq_group)
+    append_new_token_seq_group(60, seq_group, 1)
+    scheduler._swap_out(seq_group, blocks_to_swap_out)
+    scheduler._add_seq_group_to_swapped(seq_group)
+
+    budget = create_token_budget()
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 0
+    assert budget.num_batched_tokens == 1
+    assert budget.num_curr_seqs == 2
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+    # swap in is the reverse of swap out
+    blocks_to_swap_in_reverse = []
+    for swapin, swapout in output.blocks_to_swap_in:
+        blocks_to_swap_in_reverse.append((swapout, swapin))
+    assert blocks_to_swap_out == blocks_to_swap_in_reverse
+
+
+def test_schedule_swapped_max_token_budget():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        scheduler._add_seq_group_to_swapped(seq_group)
+
+    budget = create_token_budget(token_budget=1)
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 1
+    assert budget.num_batched_tokens == 1
+    assert budget.num_curr_seqs == 2
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+
+    # Verify num_batched_tokens is respected.
+    budget = create_token_budget(token_budget=1)
+    add_token_budget(budget, 1, 0)
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 1
+    assert budget.num_batched_tokens == 1
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
+def test_schedule_swapped_max_seqs():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    for i in range(4):
+        _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        scheduler._add_seq_group_to_swapped(seq_group)
+
+    budget = create_token_budget(max_num_seqs=2)
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 2
+    assert budget.num_batched_tokens == 2
+    assert budget.num_curr_seqs == 2
+    assert len(output.decode_seq_groups) == 2
+    assert len(output.prefill_seq_groups) == 0
+
+    # Verify num_curr_seqs is respected.
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 2
+    assert budget.num_batched_tokens == 2
+    assert budget.num_curr_seqs == 2
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
+def test_schedule_swapped_max_loras():
+    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
+    scheduler = initialize_scheduler(lora_config=lora_config)
+    curr_loras: Set[int] = set()
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           lora_request=LoRARequest(
+                                               lora_name=str(i),
+                                               lora_int_id=i + 1,
+                                               lora_path="abc"))
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        scheduler._add_seq_group_to_swapped(seq_group)
+
+    budget = create_token_budget()
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 1
+    assert budget.num_batched_tokens == 1
+    assert budget.num_curr_seqs == 1
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+    assert len(curr_loras) == 1
+
+
+def test_schedule_swapped_cannot_swap_in():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        scheduler._add_seq_group_to_swapped(seq_group)
+
+    # Mock the block manager so that swap-in is not possible yet.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
+    # Since we cannot swap in, none of the requests are swapped in.
+    budget = create_token_budget()
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 2
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
+def test_infeasible_swap():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        scheduler._add_seq_group_to_swapped(seq_group)
+
+    # Mock the block manager so that swap-in is never possible.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
+    # Since we cannot swap in, none of the requests are swapped in.
+    budget = create_token_budget()
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 0
+    assert len(output.infeasible_seq_groups) == 2
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
+def test_schedule_swapped_blocks_to_copy():
+    scheduler = initialize_scheduler()
+    curr_loras = None
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    scheduler._allocate_and_set_running(seq_group)
+    append_new_token_seq_group(60, seq_group, 1)
+    blocks_to_swap_out: List[Tuple[int, int]] = []
+    scheduler._swap_out(seq_group, blocks_to_swap_out)
+    scheduler._add_seq_group_to_swapped(seq_group)
+
+    # Mock append_slots to return a block copy (source, dest) mapping.
+    scheduler.block_manager.append_slots = MagicMock()
+    scheduler.block_manager.append_slots.return_value = [(2, 3)]
+
+    budget = create_token_budget()
+    output = scheduler._schedule_swapped(budget, curr_loras)
+    remaining_swapped = scheduler.swapped
+    assert len(remaining_swapped) == 0
+    assert len(output.decode_seq_groups) == 1
+    assert len(output.prefill_seq_groups) == 0
+    assert output.blocks_to_copy == [(2, 3)]
+
+
+def test_scheduling_budget():
+    TOKEN_BUDGET = 4
+    MAX_SEQS = 4
+    budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
+    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
+    assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
+    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
+    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
+    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
+    assert budget.remaining_token_budget() == TOKEN_BUDGET
+
+    # Verify add/subtract num batched tokens.
+    _, seq_group = create_dummy_prompt("1", 3)
+    budget.add_num_batched_tokens(seq_group.request_id, 2)
+    assert budget.remaining_token_budget() == 2
+    assert budget.num_batched_tokens == 2
+    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
+    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
+    # Verify adding another seq group is no-op.
+    budget.add_num_batched_tokens(seq_group.request_id, 2)
+    assert budget.remaining_token_budget() == 2
+    assert budget.num_batched_tokens == 2
+    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
+    assert budget.remaining_token_budget() == 4
+    assert budget.num_batched_tokens == 0
+    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
+    assert budget.remaining_token_budget() == 4
+    assert budget.num_batched_tokens == 0
+
+    # Verify add/subtract max seqs.
+    _, seq_group = create_dummy_prompt("1", 3)
+    budget.add_num_seqs(seq_group.request_id, 2)
+    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
+    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
+    assert budget.num_curr_seqs == 2
+    # Verify adding another seq group is no-op.
+    budget.add_num_seqs(seq_group.request_id, 2)
+    assert budget.num_curr_seqs == 2
+    budget.subtract_num_seqs(seq_group.request_id, 2)
+    assert budget.num_curr_seqs == 0
+    budget.subtract_num_seqs(seq_group.request_id, 2)
+    assert budget.num_curr_seqs == 0

+ 99 - 0
tests/core/test_scheduler_encoder_decoder.py

@@ -0,0 +1,99 @@
+from typing import List
+
+import pytest  # noqa
+
+from aphrodite.common.config import CacheConfig, SchedulerConfig
+from aphrodite.common.sequence import SequenceGroup
+from aphrodite.processing.scheduler import Scheduler
+
+from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
+                    get_sequence_groups, schedule_and_update_computed_tokens)
+
+
+def test_scheduler_schedule_simple_encoder_decoder():
+    '''
+    Test basic scheduler functionality in the context
+    of an encoder/decoder model. Focus on testing
+    enc/dec-specific functionality, since tests already
+    exist for decoder-only functionality.
+
+    Test behavior:
+    * Construct Scheduler
+    * Construct dummy encoder/decoder sequence groups
+    * Add dummy seq groups to scheduler backlog
+    * Schedule the next seq group & validate:
+        * Cross-attn block tables
+        * Updated states of seq groups
+        * Number of batched tokens
+        * Number of blocks to copy/swap-in/swap-out
+        * Number of scheduled seq groups
+    * Repeat for both prefill- and decode-phase
+    * Abort scheduled seq groups
+    * Assert that aborted seq groups no longer appear in
+      cross-attention block table
+    '''
+
+    block_size = 4
+    num_seq_group = 4
+    max_model_len = 16
+    scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 16  # enc and dec prompts per seq_group
+    cache_config.num_gpu_blocks = 16  # enc and dec prompts per seq_group
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    req_id_list = []
+    for i in range(num_seq_group):
+        req_id = str(i)
+        req_id_list.append(req_id)
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            req_id, block_size, block_size, block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Schedule seq groups prefill.
+    num_tokens = block_size * num_seq_group
+    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
+    # - Verify that sequence group cross-attention block tables are
+    #   registered with the block manager
+    assert all([(req_id in scheduler.block_manager.cross_block_tables)
+                for req_id in req_id_list])
+    # - Validate sequence-group status
+    assert set(get_sequence_groups(out)) == set(running)
+    # - Validate number of batched tokens
+    assert out.num_batched_tokens == num_tokens
+    # - Validate there are no remaining blocks to swap
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    # - Validate all seq groups were scheduled
+    assert len(seq_group_meta_list) == num_seq_group
+    append_new_token(out, 1)
+
+    # Schedule seq groups decode.
+    seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
+    # - Verify that sequence group metadata includes encoder attention
+    #   and cross-attention metadata
+    assert all([
+        not ((seq_group_meta.encoder_seq_data is None) or
+             (seq_group_meta.cross_block_table is None))
+        for seq_group_meta in seq_group_meta_list
+    ])
+    # - Validate sequence-group status
+    assert set(get_sequence_groups(out)) == set(running)
+    # - Validate there is one batched token per seq group
+    assert out.num_batched_tokens == num_seq_group
+    # - Validate there are no remaining blocks to swap
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    # - Validate that all seq groups were scheduled
+    assert len(seq_group_meta_list) == num_seq_group
+    append_new_token(out, 1)
+
+    # Abort sequences
+    for req_id in req_id_list:
+        scheduler.abort_seq_group(req_id)
+        # - Verify that sequence group cross-attention block tables are
+        #   NO LONGER registered with the block manager
+        assert req_id not in scheduler.block_manager.cross_block_tables

+ 211 - 0
tests/core/utils.py

@@ -0,0 +1,211 @@
+import time
+from typing import List, Optional
+from typing import Sequence as GenericSequence
+from typing import Tuple
+
+from aphrodite import SamplingParams
+from aphrodite.common.sequence import Logprob, Sequence, SequenceGroup
+from aphrodite.lora.request import LoRARequest
+
+
+def create_dummy_prompt(
+    request_id: str,
+    prompt_length: int,
+    block_size: Optional[int] = None,
+    lora_request: Optional[LoRARequest] = None,
+    use_beam_search: bool = False,
+    best_of: int = 1,
+    prompt_tokens: Optional[List[int]] = None,
+) -> Tuple[Sequence, SequenceGroup]:
+    if not block_size:
+        block_size = prompt_length
+
+    if prompt_tokens is None:
+        # Create a dummy prompt sequence with tokens 0...prompt_length-1
+        # and the prompt string "0 1 ... prompt_length-1".
+        prompt_tokens = list(range(prompt_length))
+    prompt_str = " ".join([str(t) for t in prompt_tokens])
+    prompt = Sequence(int(request_id),
+                      inputs={
+                          "prompt": prompt_str,
+                          "prompt_token_ids": prompt_tokens,
+                      },
+                      block_size=block_size)
+    seq_group = SequenceGroup(request_id=request_id,
+                              seqs=[prompt],
+                              arrival_time=time.time(),
+                              sampling_params=SamplingParams(
+                                  use_beam_search=use_beam_search,
+                                  best_of=best_of),
+                              lora_request=lora_request)
+
+    return prompt, seq_group
+
+
+def create_dummy_prompt_encoder_decoder(
+    request_id: str,
+    decoder_prompt_length: int,
+    encoder_prompt_length: int,
+    block_size: Optional[int] = None,
+    lora_request: Optional[LoRARequest] = None,
+    use_beam_search: bool = False,
+    best_of: int = 1,
+) -> Tuple[Sequence, Sequence, SequenceGroup]:
+    if not block_size:
+        block_size = decoder_prompt_length
+
+    # Create dummy decoder/encoder prompt sequences. The decoder prompt uses
+    # tokens 0...decoder_prompt_length-1 and the encoder prompt uses the
+    # reversed range 0...encoder_prompt_length-1. Note that the prompt
+    # strings don't actually match the tokens.
+    decoder_prompt_tokens = list(range(decoder_prompt_length))
+    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
+    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
+    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
+
+    inputs = {
+        "prompt": decoder_prompt_str,
+        "prompt_token_ids": decoder_prompt_tokens,
+        "encoder_prompt": encoder_prompt_str,
+        "encoder_prompt_token_ids": encoder_prompt_tokens,
+        "multi_modal_data": None,
+    }
+
+    decoder_prompt = Sequence(int(request_id),
+                              inputs=inputs,
+                              block_size=block_size,
+                              from_decoder_prompt=True)
+
+    encoder_prompt = Sequence(int(request_id),
+                              inputs=inputs,
+                              block_size=block_size,
+                              from_decoder_prompt=False)
+    seq_group = SequenceGroup(request_id=request_id,
+                              seqs=[decoder_prompt],
+                              sampling_params=SamplingParams(
+                                  use_beam_search=use_beam_search,
+                                  best_of=best_of),
+                              arrival_time=time.time(),
+                              lora_request=lora_request,
+                              encoder_seq=encoder_prompt)
+
+    return decoder_prompt, encoder_prompt, seq_group
+
+
+def create_seq_group(
+        seq_prompt_len: int = 1024,
+        seq_output_lens: GenericSequence[int] = (128, ),
+        request_id: str = '0',
+        seq_id_start: int = 0,
+        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
+
+    assert len(seq_output_lens) > 0
+
+    if sampling_params is None:
+        sampling_params = SamplingParams()
+
+    prompt_token_ids = [0] * seq_prompt_len
+
+    seqs: List[Sequence] = []
+    for seq_id_offset, output_len in enumerate(seq_output_lens):
+        seq = Sequence(
+            seq_id=seq_id_start + seq_id_offset,
+            inputs={"prompt_token_ids": prompt_token_ids},
+            block_size=16,
+        )
+
+        for i in range(output_len):
+            seq.append_token_id(
+                token_id=i,
+                logprobs={i: Logprob(0.0)},
+            )
+        seqs.append(seq)
+
+    seq_group = SequenceGroup(
+        request_id=request_id,
+        seqs=seqs,
+        sampling_params=sampling_params,
+        arrival_time=time.time(),
+    )
+
+    return seq_group
+
+
+def create_seq_group_encoder_decoder(
+        seq_prompt_len: int = 1024,
+        seq_output_lens: GenericSequence[int] = (128, ),
+        request_id: str = '0',
+        seq_id_start: int = 0,
+        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
+
+    assert len(seq_output_lens) > 0
+
+    if sampling_params is None:
+        sampling_params = SamplingParams()
+
+    prompt_token_ids = [0] * seq_prompt_len
+
+    inputs = {
+        "prompt": "",
+        "prompt_token_ids": prompt_token_ids,
+        "encoder_prompt": "",
+        "encoder_prompt_token_ids": prompt_token_ids,
+        "multi_modal_data": None,
+    }
+
+    seqs = []
+    for seq_id_offset, output_len in enumerate(seq_output_lens):
+        # Construct decoder input sequences
+        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
+                       inputs=inputs,
+                       block_size=16,
+                       from_decoder_prompt=True)
+
+        for i in range(output_len):
+            seq.append_token_id(
+                token_id=i,
+                logprobs={i: Logprob(0.0)},
+            )
+        seqs.append(seq)
+
+    # Encoder input sequence
+    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
+                           inputs=inputs,
+                           block_size=16,
+                           from_decoder_prompt=False)
+
+    return SequenceGroup(request_id=request_id,
+                         seqs=seqs,
+                         sampling_params=sampling_params,
+                         arrival_time=time.time(),
+                         encoder_seq=encoder_seq)
+
+
+def round_up_to_next_block(seq_len: int, block_size: int) -> int:
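+    # Despite the name, this returns the number of blocks needed to hold
+    # seq_len tokens, i.e. ceil(seq_len / block_size).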
+    return (seq_len + block_size - 1) // block_size
+
+
+# Helper functions for scheduler tests
+
+
+def get_sequence_groups(scheduler_output):
+    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
+def append_new_token(out, token_id: int):
+    seq_groups = get_sequence_groups(out)
+    for seq_group in seq_groups:
+        for seq in seq_group.get_seqs():
+            seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+def schedule_and_update_computed_tokens(scheduler):
+    metas, out = scheduler.schedule()
+    for s, meta in zip(out.scheduled_seq_groups, metas):
+        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+    return metas, out
+
+
+def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
+    seq_group.update_num_computed_tokens(token_chunk_size)
+    for seq in seq_group.get_seqs():
+        seq.append_token_id(token_id, {token_id: Logprob(token_id)})

+ 0 - 0
tests/distributed/__init__.py


+ 81 - 0
tests/distributed/test_basic_distributed_correctness.py

@@ -0,0 +1,81 @@
+"""Compare the outputs of HF and distributed Aphrodite when using
+greedy sampling.
+
+Run:
+```sh
+cd $APHRODITE_PATH/tests
+
+pytest distributed/test_basic_distributed_correctness.py
+```
+"""
+import os
+
+import pytest
+
+from aphrodite.common.utils import cuda_device_count_stateless
+
+from ..models.utils import check_outputs_equal
+from ..utils import fork_new_process_for_each_test
+
+TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "model, distributed_executor_backend, attention_backend, test_suite", [
+        ("facebook/opt-125m", "ray", "", "L4"),
+        ("facebook/opt-125m", "mp", "", "L4"),
+        ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
+        ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"),
+        ("facebook/opt-125m", "ray", "", "A100"),
+        ("facebook/opt-125m", "mp", "", "A100"),
+        ("facebook/opt-125m", "mp", "FLASHINFER", "A100"),
+        ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
+    ])
+@fork_new_process_for_each_test
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    distributed_executor_backend: str,
+    attention_backend: str,
+    test_suite: str,
+) -> None:
+
+    if test_suite != TARGET_TEST_SUITE:
+        pytest.skip(f"Skip test for {test_suite}")
+
+    if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
+        # Test Ray ADAG.
+        os.environ['APHRODITE_USE_RAY_SPMD_WORKER'] = "1"
+        os.environ['APHRODITE_USE_RAY_COMPILED_DAG'] = "1"
+
+    if attention_backend:
+        os.environ["APHRODITE_ATTENTION_BACKEND"] = attention_backend
+
+    dtype = "half"
+    max_tokens = 5
+
+    # NOTE: take care of the order. run Aphrodite first, and then run HF.
+    # Aphrodite needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+    with aphrodite_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=2,
+                     distributed_executor_backend=distributed_executor_backend
+                     ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )

+ 102 - 0
tests/distributed/test_basic_distributed_correctness_enc_dec.py

@@ -0,0 +1,102 @@
+"""For encoder/decoder models only:
+Compare the outputs of HF and distributed Aphrodite when using greedy sampling.
+
+Run:
+```sh
+cd $APHRODITE_PATH/tests
+
+pytest distributed/test_basic_distributed_correctness_enc_dec.py
+```
+"""
+
+import pytest
+
+from aphrodite.common.utils import cuda_device_count_stateless
+
+from ..conftest import DecoderPromptType
+from ..models.utils import check_logprobs_close
+from ..utils import fork_new_process_for_each_test
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model, distributed_executor_backend", [
+    ("facebook/bart-large-cnn", "ray"),
+    ("facebook/bart-large-cnn", "mp"),
+])
+@fork_new_process_for_each_test
+def test_models(
+    model: str,
+    distributed_executor_backend: str,
+    hf_runner,
+    aphrodite_runner,
+    example_encoder_decoder_prompts,
+) -> None:
+    '''
+    Test Aphrodite BART inference on more than one GPU, comparing
+    outputs against HF as a baseline.
+
+    Fork a new process for each test, to prevent CUDA from
+    being re-initialized by successive tests within the same
+    process.
+
+    Arguments:
+
+    * model: the HF ID of the specific BART variant under test
+    * distributed_executor_backend
+    * hf_runner: HuggingFace (HF) test model runner
+    * aphrodite_runner: Aphrodite test model runner
+    * example_encoder_decoder_prompts: test fixture which provides a
+      dictionary of dummy prompts
+    '''
+
+    dtype = "float"
+    max_tokens = 64
+    num_logprobs = 5
+
+    # Example inputs with non-trivial (i.e. not None/empty) encoder &
+    # decoder prompts.
+    test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]
+
+    # NOTE: take care of the order. run Aphrodite first, and then run HF.
+    # Aphrodite needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend=distributed_executor_backend,
+            enforce_eager=True,
+    ) as aphrodite_model:
+        aphrodite_outputs = (
+            aphrodite_model.generate_encoder_decoder_greedy_logprobs(
+                test_prompts, max_tokens, num_logprobs))
+
+    # Configuration settings for HF baseline
+    hf_kwargs = {
+        "top_k": None,
+        "num_beams": 1,
+        "repetition_penalty": 1.0,
+        "top_p": 1.0,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "no_repeat_ngram_size": None,
+        "min_length": 0
+    }
+
+    with hf_runner(model, dtype=dtype,
+                   is_encoder_decoder_model=True) as hf_model:
+        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+            test_prompts,
+            max_tokens,
+            num_logprobs,
+            **hf_kwargs,
+        ))
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )

+ 70 - 0
tests/distributed/test_chunked_prefill_distributed.py

@@ -0,0 +1,70 @@
+"""Compare the outputs of HF and distributed Aphrodite when using greedy
+sampling.
+
+Run:
+```sh
+pytest test_chunked_prefill_distributed.py
+```
+"""
+
+import pytest
+
+from aphrodite.common.utils import cuda_device_count_stateless
+
+from ..models.utils import check_outputs_equal
+from ..utils import fork_new_process_for_each_test
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model, distributed_executor_backend", [
+    ("facebook/opt-125m", "ray"),
+    ("meta-llama/Llama-2-7b-hf", "ray"),
+    ("facebook/opt-125m", "mp"),
+    ("meta-llama/Llama-2-7b-hf", "mp"),
+])
+@fork_new_process_for_each_test
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    example_prompts,
+    model: str,
+    distributed_executor_backend: str,
+) -> None:
+
+    dtype = "half"
+    max_tokens = 5
+    chunked_prefill_token_size = 16
+
+    # Add a chunked prefill config.
+    max_num_seqs = min(chunked_prefill_token_size, 256)
+    assert chunked_prefill_token_size != -1
+    enable_chunked_prefill = True
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    # NOTE: take care of the order. run Aphrodite first, and then run HF.
+    # Aphrodite needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    with aphrodite_runner(
+            model,
+            dtype=dtype,
+            tensor_parallel_size=2,
+            max_num_seqs=max_num_seqs,
+            enable_chunked_prefill=enable_chunked_prefill,
+            max_num_batched_tokens=max_num_batched_tokens,
+            distributed_executor_backend=distributed_executor_backend,
+    ) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=aphrodite_outputs,
+        name_0="hf",
+        name_1="aphrodite",
+    )

+ 200 - 0
tests/distributed/test_comm_ops.py

@@ -0,0 +1,200 @@
+"""Test the communication operators.
+
+Run `pytest tests/distributed/test_comm_ops.py`.
+"""
+import os
+
+import pytest
+import ray
+import torch
+
+from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
+                                   tensor_model_parallel_all_gather,
+                                   tensor_model_parallel_all_reduce)
+
+from ..utils import init_test_distributed_environment, multi_process_parallel
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
+                           distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    t = all_tensors[rank % tp_size]
+    t = tensor_model_parallel_all_reduce(t)
+    torch.testing.assert_close(t, expected)
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
+                           distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+    num_dimensions = 3
+    tensor_size = list(range(2, num_dimensions + 2))
+    total_size = 1
+    for s in tensor_size:
+        total_size *= s
+    for all_gather_dimension in range(num_dimensions):
+        all_tensors = [
+            torch.arange(total_size, dtype=torch.float32,
+                         device="cuda").reshape(tensor_size) * (r + 1)
+            for r in range(tp_size)
+        ]
+        expected = torch.cat(all_tensors, dim=all_gather_dimension)
+        t = all_tensors[rank % tp_size]
+        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
+        torch.testing.assert_close(t, expected)
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if (rank % tp_size) == 0:
+        broadcast_tensor_dict(test_dict, src=0)
+    else:
+        recv_dict = broadcast_tensor_dict(src=0)
+        assert len(recv_dict) == len(test_dict)
+        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        torch.testing.assert_close(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        torch.testing.assert_close(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        torch.testing.assert_close(test_tensor, recv_tensor)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel(tp_size, test_target):
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    send_recv_test_worker, send_recv_tensor_dict_test_worker,
+    all_reduce_test_worker, all_gather_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_pipeline_parallel(
+        tp_size, pp_size, test_target):
+    multi_process_parallel(tp_size, pp_size, test_target)

+ 115 - 0
tests/distributed/test_custom_all_reduce.py

@@ -0,0 +1,115 @@
+import os
+import random
+
+import pytest
+import ray
+import torch
+import torch.distributed as dist
+
+from aphrodite.distributed.communication_op import (  # noqa
+    tensor_model_parallel_all_reduce)
+from aphrodite.distributed.parallel_state import (
+    get_tensor_model_parallel_group, get_tp_group, graph_capture)
+
+from ..utils import (ensure_model_parallel_initialized,
+                     init_test_distributed_environment, multi_process_parallel)
+
+random.seed(42)
+test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
+for i, v in enumerate(test_sizes):
+    test_sizes[i] -= v % 8
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+    ensure_model_parallel_initialized(tp_size, pp_size)
+    group = get_tensor_model_parallel_group().device_group
+
+    # A small all_reduce for warmup.
+    # This is needed because device communicators might be created lazily
+    # (e.g. NCCL). This will ensure that the communicator is initialized
+    # before any communication happens, so that this group can be used for
+    # graph capture immediately.
+    data = torch.zeros(1)
+    data = data.to(device=device)
+    torch.distributed.all_reduce(data, group=group)
+    torch.cuda.synchronize()
+    del data
+
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
+
+    for sz in test_sizes:
+        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            with graph_capture() as graph_capture_context:
+                # use integers so result matches NCCL exactly
+                inp1 = torch.randint(1,
+                                     16, (sz, ),
+                                     dtype=dtype,
+                                     device=torch.cuda.current_device())
+                inp2 = torch.randint(1,
+                                     16, (sz, ),
+                                     dtype=dtype,
+                                     device=torch.cuda.current_device())
+                torch.cuda.synchronize()
+                graph = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(graph,
+                                      stream=graph_capture_context.stream):
+                    for i in range(num_communication):
+                        out1 = tensor_model_parallel_all_reduce(inp1)
+                        # the input buffer is immediately modified to test
+                        # synchronization
+                        dist.all_reduce(inp1, group=group)
+                        out2 = tensor_model_parallel_all_reduce(inp2)
+                        dist.all_reduce(inp2, group=group)
+            graph.replay()
+            torch.testing.assert_close(out1, inp1)
+            torch.testing.assert_close(out2, inp2)
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
+    sz = 1024
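+    # ca_comm is the custom all-reduce communicator attached to the TP group.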
+    fa = get_tp_group().ca_comm
+    inp = torch.ones(sz, dtype=torch.float32, device=device)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
+
+    inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
+
+
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
+@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
+def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
+    world_size = tp_size * pipeline_parallel_size
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)

+ 6 - 0
tests/distributed/test_distributed_oot.py

@@ -0,0 +1,6 @@
+from ..endpoints.openai.test_oot_registration import (
+    run_and_test_dummy_opt_api_server)
+
+
+def test_distributed_oot(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)

+ 57 - 0
tests/distributed/test_multimodal_broadcast.py

@@ -0,0 +1,57 @@
+"""Compare the outputs of HF and distributed Aphrodite when using greedy
+sampling.
+
+Run:
+```sh
+pytest -s -v test_multimodal_broadcast.py
+```
+"""
+
+import pytest
+
+from aphrodite.common.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model, distributed_executor_backend", [
+    ("llava-hf/llava-1.5-7b-hf", "ray"),
+    ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
+    ("facebook/chameleon-7b", "ray"),
+    ("llava-hf/llava-1.5-7b-hf", "mp"),
+    ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
+    ("facebook/chameleon-7b", "mp"),
+])
+@fork_new_process_for_each_test
+def test_models(hf_runner, aphrodite_runner, image_assets, model: str,
+                distributed_executor_backend: str) -> None:
+
+    dtype = "half"
+    max_tokens = 5
+    num_logprobs = 5
+    tensor_parallel_size = 2
+
+    if model.startswith("llava-hf/llava-1.5"):
+        from ..models.test_llava import models, run_test
+    elif model.startswith("llava-hf/llava-v1.6"):
+        from ..models.test_llava_next import models, run_test
+    elif model.startswith("facebook/chameleon"):
+        from ..models.test_chameleon import models, run_test
+    else:
+        raise NotImplementedError(f"Unsupported model: {model}")
+
+    run_test(
+        hf_runner,
+        aphrodite_runner,
+        image_assets,
+        model=models[0],
+        # So that the LLaVA-NeXT processor may return a nested list
+        size_factors=[0.25, 0.5, 1.0],
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+    )

+ 92 - 0
tests/distributed/test_pipeline_parallel.py

@@ -0,0 +1,92 @@
+"""
+WARNING: This test runs in both single-node (4 GPUs) and multi-node
+ (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
+ important to set the distributed backend to "mp" to avoid Ray scheduling
+ all workers in a node other than the head node, which can cause the test
+ to fail.
+"""
+import os
+
+import pytest
+from loguru import logger
+
+from ..utils import compare_two_settings, fork_new_process_for_each_test
+
+APHRODITE_MULTI_NODE = os.getenv("APHRODITE_MULTI_NODE", "0") == "1"
+
+
+@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
+                          "MODEL_NAME, DIST_BACKEND"),
+                         [
+                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+                         ])
+@fork_new_process_for_each_test
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
+                    DIST_BACKEND):
+    if APHRODITE_MULTI_NODE and DIST_BACKEND == "mp":
+        pytest.skip("Skipping multi-node pipeline parallel test for "
+                    "multiprocessing distributed backend")
+
+    pp_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--tensor-parallel-size",
+        str(TP_SIZE),
+        "--distributed-executor-backend",
+        DIST_BACKEND,
+    ]
+
+    # compare without pipeline parallelism
+    # NOTE: use mp backend for TP
+    # PP tests might involve multiple nodes, and ray might
+    #  schedule all workers in a node other than the head node,
+    #  which can cause the test to fail.
+    tp_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--tensor-parallel-size",
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
+        "--distributed-executor-backend",
+        "mp",
+    ]
+    if CHUNKED_PREFILL:
+        pp_args.append("--enable-chunked-prefill")
+        tp_args.append("--enable-chunked-prefill")
+    if EAGER_MODE:
+        pp_args.append("--enforce-eager")
+        tp_args.append("--enforce-eager")
+    pp_env = None
+    if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
+            and CHUNKED_PREFILL):
+        # Test Ray ADAG for a subset of the tests
+        pp_env = {
+            "APHRODITE_USE_RAY_COMPILED_DAG": "1",
+            "APHRODITE_USE_RAY_SPMD_WORKER": "1",
+            "APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
+        }
+        # Temporary workaround: when zeromq + SPMD is used, the process does
+        # not terminate properly because of an aDAG issue.
+        pp_args.append("--disable-frontend-multiprocessing")
+        tp_args.append("--disable-frontend-multiprocessing")
+
+    try:
+        compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
+    except Exception:
+        if pp_env is None:
+            raise
+        else:
+            # Ray ADAG tests are flaky, so we don't want to fail the test
+            logger.exception("Ray ADAG tests failed")

+ 34 - 0
tests/distributed/test_pipeline_partition.py

@@ -0,0 +1,34 @@
+import os
+
+import pytest
+
+from aphrodite.distributed.utils import get_pp_indices
+
+
+def test_custom_layer_partition():
+
+    def _verify(partition_str, num_layers, pp_size, goldens):
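+        # Each entry in `goldens` is the expected (start_layer, end_layer)
+        # range for the corresponding pipeline-parallel rank.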
+        bak = os.environ.get("APHRODITE_PP_LAYER_PARTITION", None)
+        os.environ["APHRODITE_PP_LAYER_PARTITION"] = partition_str
+        for pp_rank, golden in enumerate(goldens):
+            assert get_pp_indices(num_layers, pp_rank, pp_size) == golden
+        if bak is not None:
+            os.environ["APHRODITE_PP_LAYER_PARTITION"] = bak
+
+    # Even partition
+    _verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Balanced partition
+    _verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)])
+    # Put the remainder somewhere
+    _verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)])
+    # Invalid partition strings
+    with pytest.raises(ValueError):
+        _verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    with pytest.raises(ValueError):
+        _verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Wrong number of partitions
+    with pytest.raises(ValueError):
+        _verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Wrong number of layers
+    with pytest.raises(ValueError):
+        _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])

+ 30 - 0
tests/distributed/test_pp_cudagraph.py

@@ -0,0 +1,30 @@
+import os
+
+import pytest
+
+from ..utils import compare_two_settings, fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
+    (2, "JackFram/llama-160m"),
+])
+@pytest.mark.parametrize("ATTN_BACKEND", [
+    "FLASH_ATTN",
+    "FLASHINFER",
+])
+@fork_new_process_for_each_test
+def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
+    cudagraph_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--pipeline-parallel-size",
+        str(PP_SIZE),
+        "--distributed-executor-backend",
+        "mp",
+    ]
+    os.environ["APHRODITE_ATTENTION_BACKEND"] = ATTN_BACKEND
+
+    eager_args = cudagraph_args + ["--enforce-eager"]
+
+    compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)

+ 243 - 0
tests/distributed/test_pynccl.py

@@ -0,0 +1,243 @@
+import multiprocessing
+import os
+from typing import Dict, List
+
+import pytest
+import torch
+import torch.distributed
+
+from aphrodite.common.utils import update_environment_variables
+from aphrodite.distributed.communication_op import (  # noqa
+    tensor_model_parallel_all_reduce)
+from aphrodite.distributed.device_communicators.pynccl import (
+    PyNcclCommunicator)
+from aphrodite.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary)
+from aphrodite.distributed.parallel_state import (
+    ensure_model_parallel_initialized, get_world_group, graph_capture,
+    init_distributed_environment)
+
+
+def distributed_run(fn, world_size):
+    number_of_processes = world_size
+    processes: List[multiprocessing.Process] = []
+    for i in range(number_of_processes):
+        env: Dict[str, str] = {}
+        env['RANK'] = str(i)
+        env['LOCAL_RANK'] = str(i)
+        env['WORLD_SIZE'] = str(number_of_processes)
+        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
+        env['MASTER_ADDR'] = 'localhost'
+        env['MASTER_PORT'] = '12345'
+        p = multiprocessing.Process(target=fn, args=(env, ))
+        processes.append(p)
+        p.start()
+
+    for p in processes:
+        p.join()
+
+    for p in processes:
+        assert p.exitcode == 0
+
+
+def worker_fn_wrapper(fn):
+    # `multiprocessing.Process` cannot accept environment variables directly
+    # so we need to pass the environment variables as arguments
+    # and update the environment variables in the function
+    def wrapped_fn(env):
+        update_environment_variables(env)
+        local_rank = os.environ['LOCAL_RANK']
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        init_distributed_environment()
+        fn()
+
+    return wrapped_fn
+
+
+@worker_fn_wrapper
+def worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+    tensor = torch.ones(16, 1024, 1024,
+                        dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_reduce(tensor)
+    result = tensor.mean().cpu().item()
+    assert result == pynccl_comm.world_size
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl():
+    distributed_run(worker_fn, 2)
+
+
+@worker_fn_wrapper
+def multiple_allreduce_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
+        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.all_reduce(tensor)
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_allreduce():
+    # this tests pynccl for multiple tp groups, in a standalone way
+    # i.e. call `pynccl_comm.all_reduce` directly
+    distributed_run(multiple_allreduce_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def multiple_allreduce_with_aphrodite_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    ensure_model_parallel_initialized(2, 2)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with graph_capture():
+        # two tp groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_allreduce_with_aphrodite():
+    # this tests pynccl for multiple tp groups, together with aphrodite
+    # i.e. call `tensor_model_parallel_all_reduce`
+    distributed_run(multiple_allreduce_with_aphrodite_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def worker_fn_with_cudagraph():
+    with torch.no_grad():
+        graph = torch.cuda.CUDAGraph()
+        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                         device=get_world_group().device)
+        # run something in the default stream to initialize torch engine
+        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
+        torch.cuda.synchronize()
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
+            # operation during the graph capture is recorded but not executed
+            # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
+            pynccl_comm.all_reduce(a)
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**0
+        graph.replay()
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**1
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_with_cudagraph():
+    distributed_run(worker_fn_with_cudagraph, 2)
+
+
+@worker_fn_wrapper
+def send_recv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+    if pynccl_comm.rank == 0:
+        tensor = torch.ones(16, 1024, 1024,
+                            dtype=torch.float32).cuda(pynccl_comm.rank)
+    else:
+        tensor = torch.empty(16, 1024, 1024,
+                             dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
+    result = tensor.mean().cpu().item()
+    assert result == 1
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_send_recv():
+    distributed_run(send_recv_worker_fn, 2)
+
+
+@worker_fn_wrapper
+def multiple_send_recv_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
+        torch.distributed.new_group(ranks=[1, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
+    if torch.distributed.get_rank() == 0:
+        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    elif torch.distributed.get_rank() == 1:
+        tensor = 2 * torch.ones(
+            16, 1024, 1024, dtype=torch.float32, device=device)
+    else:
+        tensor = torch.empty(16,
+                             1024,
+                             1024,
+                             dtype=torch.float32,
+                             device=device)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
+    result = tensor.mean().cpu().item()
+    if torch.distributed.get_rank() in [0, 2]:
+        assert result == 1
+    else:
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_send_recv():
+    distributed_run(multiple_send_recv_worker_fn, 4)
+
+
+def test_ncclGetUniqueId():
+    lib = NCCLLibrary()
+    unique_id = lib.ncclGetUniqueId()
+    # `list(unique_id.internal)` is something like this:
+    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
+    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+    # as long as the function doesn't raise an exception, we're good
+    assert unique_id is not None

+ 13 - 0
tests/distributed/test_same_node.py

@@ -0,0 +1,13 @@
+import os
+
+import torch
+
+from aphrodite.distributed.parallel_state import in_the_same_node_as
+
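+# This file is meant to be launched as a standalone script with one process
+# per rank (e.g. via torchrun); each rank then checks whether all ranks are
+# detected as running on the same host.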
+torch.distributed.init_process_group(backend="gloo")
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))
+
+expected = os.environ.get("APHRODITE_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"
+print("Same node test passed!")

+ 35 - 28
tests/distributed/test_shm_broadcast.py

@@ -1,12 +1,22 @@
 import multiprocessing
 import random
 import time
+from typing import List
 
+import numpy as np
 import torch.distributed as dist
 
 from aphrodite.common.utils import update_environment_variables
 from aphrodite.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+    MessageQueue)
+
+
+def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
+    np.random.seed(seed)
+    sizes = np.random.randint(1, 10_000, n)
+    # on average, each array will have 5k elements
+    # with int64, each array will be ~40 KB
+    return [np.random.randint(1, 100, i) for i in sizes]
 
 
 def distributed_run(fn, world_size):
@@ -46,37 +56,34 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
     writer_rank = 2
-    broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024, 2, writer_rank)
+    broadcaster = MessageQueue.create_from_process_group(
+        dist.group.WORLD, 40 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
-        time.sleep(random.random())
-        broadcaster.broadcast_object(0)
-        time.sleep(random.random())
-        broadcaster.broadcast_object({})
-        time.sleep(random.random())
-        broadcaster.broadcast_object([])
+        seed = random.randint(0, 1000)
+        dist.broadcast_object_list([seed], writer_rank)
     else:
-        time.sleep(random.random())
-        a = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        b = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        c = broadcaster.broadcast_object(None)
-        assert a == 0
-        assert b == {}
-        assert c == []
+        recv = [None]
+        dist.broadcast_object_list(recv, writer_rank)
+        seed = recv[0]  # type: ignore
+    dist.barrier()
+    # in case we find a race condition
+    # print the seed so that we can reproduce the error
+    print(f"Rank {dist.get_rank()} got seed {seed}")
+    # test broadcasting with about 400MB of data
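+    # (10_000 arrays at roughly 40 KB each; see get_arrays above)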
+    N = 10_000
+    if dist.get_rank() == writer_rank:
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            broadcaster.broadcast_object(x)
+            time.sleep(random.random() / 1000)
+    else:
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            y = broadcaster.broadcast_object(None)
+            assert np.array_equal(x, y)
+            time.sleep(random.random() / 1000)
     dist.barrier()
 
 
 def test_shm_broadcast():
     distributed_run(worker_fn, 4)
-
-
-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]

+ 35 - 0
tests/distributed/test_utils.py

@@ -0,0 +1,35 @@
+import os
+
+import ray
+
+from aphrodite.common.utils import (cuda_device_count_stateless,
+                                    update_environment_variables)
+
+CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+
+
+@ray.remote
+class _CUDADeviceCountStatelessTestActor:
+
+    def get_count(self):
+        return cuda_device_count_stateless()
+
+    def set_cuda_visible_devices(self, cuda_visible_devices: str):
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+    def get_cuda_visible_devices(self):
+        return CUDA_VISIBLE_DEVICES
+
+
+def test_cuda_device_count_stateless():
+    """Test that cuda_device_count_stateless changes return value if
+    CUDA_VISIBLE_DEVICES is changed."""
+    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
+        num_gpus=2).remote()
+    assert sorted(ray.get(
+        actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
+    assert ray.get(actor.get_count.remote()) == 2
+    ray.get(actor.set_cuda_visible_devices.remote("0"))
+    assert ray.get(actor.get_count.remote()) == 1
+    ray.get(actor.set_cuda_visible_devices.remote(""))
+    assert ray.get(actor.get_count.remote()) == 0

+ 0 - 0
tests/endpoints/__init__.py


+ 89 - 0
tests/endpoints/conftest.py

@@ -0,0 +1,89 @@
+import pytest
+
+
+@pytest.fixture
+def sample_prompts():
+    return [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+
+@pytest.fixture
+def sample_token_ids():
+    return [
+        [0],
+        [0, 1],
+        [0, 2, 1],
+        [0, 3, 1, 2],
+    ]
+
+
+@pytest.fixture
+def sample_regex():
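+    # matches a dotted-quad IPv4 address (each octet 0-255)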
+    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
+            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
+
+
+@pytest.fixture
+def sample_json_schema():
+    return {
+        "type": "object",
+        "properties": {
+            "name": {
+                "type": "string"
+            },
+            "age": {
+                "type": "integer"
+            },
+            "skills": {
+                "type": "array",
+                "items": {
+                    "type": "string",
+                    "maxLength": 10
+                },
+                "minItems": 3
+            },
+            "work_history": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "company": {
+                            "type": "string"
+                        },
+                        "duration": {
+                            "type": "number"
+                        },
+                        "position": {
+                            "type": "string"
+                        }
+                    },
+                    "required": ["company", "position"]
+                }
+            }
+        },
+        "required": ["name", "age", "skills", "work_history"]
+    }
+
+
+@pytest.fixture
+def sample_guided_choice():
+    return [
+        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
+        "Ruby", "Swift", "Kotlin"
+    ]
+
+
+@pytest.fixture
+def sample_sql_statements():
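+    # a minimal Lark grammar for single-table SELECT statements,
+    # used by the guided-grammar tests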
+    return ("""
+start: select_statement
+select_statement: "SELECT" column "from" table "where" condition
+column: "col_1" | "col_2"
+table: "table_1" | "table_2"
+condition: column "=" number
+number: "1" | "2"
+""")

+ 0 - 0
tests/endpoints/llm/__init__.py


+ 142 - 0
tests/endpoints/llm/test_encode.py

@@ -0,0 +1,142 @@
+import weakref
+from typing import List
+
+import pytest
+
+from aphrodite import LLM, EmbeddingRequestOutput, PoolingParams
+
+from ...conftest import cleanup
+
+MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+TOKEN_IDS = [
+    # Using ID={0, 1, 2, 3} results in NaN values,
+    # so we add this offset of 1000
+    [1000],
+    [1000, 1001],
+    [1000, 1002, 1001],
+    [1000, 1003, 1001, 1002],
+]
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              max_num_batched_tokens=32768,
+              tensor_parallel_size=1,
+              gpu_memory_utilization=0.75,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
+                         o2: List[EmbeddingRequestOutput]):
+    assert [o.outputs for o in o1] == [o.outputs for o in o2]
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt', PROMPTS)
+def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
+
+    v2_output = llm.encode(prompt, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
+def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
+                                                    prompt_token_ids):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
+        v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
+                               pooling_params=pooling_params)
+
+    v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
+                           pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
+
+    v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.encode(
+        [{
+            "prompt": p
+        } for p in PROMPTS],
+        pooling_params=pooling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
+        v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
+                               pooling_params=pooling_params)
+
+    v2_output = llm.encode(
+        [{
+            "prompt_token_ids": p
+        } for p in TOKEN_IDS],
+        pooling_params=pooling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_multiple_pooling_params(llm: LLM):
+    pooling_params = [
+        PoolingParams(),
+        PoolingParams(),
+        PoolingParams(),
+        PoolingParams(),
+    ]
+
+    # Multiple PoolingParams should be matched with each prompt
+    outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
+    assert len(PROMPTS) == len(outputs)
+
+    # An exception is raised if the number of params does not match the number of prompts
+    with pytest.raises(ValueError):
+        outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
+
+    # Single PoolingParams should be applied to every prompt
+    single_pooling_params = PoolingParams()
+    outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
+    assert len(PROMPTS) == len(outputs)
+
+    # pooling_params is None, default params should be applied
+    outputs = llm.encode(PROMPTS, pooling_params=None)
+    assert len(PROMPTS) == len(outputs)

+ 161 - 0
tests/endpoints/llm/test_generate.py

@@ -0,0 +1,161 @@
+import weakref
+from typing import List
+
+import pytest
+
+from aphrodite import LLM, RequestOutput, SamplingParams
+
+from ...conftest import cleanup
+
+MODEL_NAME = "facebook/opt-125m"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+TOKEN_IDS = [
+    [0],
+    [0, 1],
+    [0, 2, 1],
+    [0, 3, 1, 2],
+]
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              max_num_batched_tokens=4096,
+              tensor_parallel_size=1,
+              gpu_memory_utilization=0.10,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
+    assert [o.outputs for o in o1] == [o.outputs for o in o2]
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt', PROMPTS)
+def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.generate(prompts=prompt,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate(prompt, sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.generate({"prompt": prompt},
+                             sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
+def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
+                                                    prompt_token_ids):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
+        v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
+                             sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.generate(prompts=PROMPTS,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.generate(
+        [{
+            "prompt": p
+        } for p in PROMPTS],
+        sampling_params=sampling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
+        v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate(
+        [{
+            "prompt_token_ids": p
+        } for p in TOKEN_IDS],
+        sampling_params=sampling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
+@pytest.mark.skip_global_cleanup
+def test_multiple_sampling_params(llm: LLM):
+    sampling_params = [
+        SamplingParams(temperature=0.01, top_p=0.95),
+        SamplingParams(temperature=0.3, top_p=0.95),
+        SamplingParams(temperature=0.7, top_p=0.95),
+        SamplingParams(temperature=0.99, top_p=0.95),
+    ]
+
+    # Multiple SamplingParams should be matched with each prompt
+    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
+    assert len(PROMPTS) == len(outputs)
+
+    # An exception is raised if the number of params does not match the number of prompts
+    with pytest.raises(ValueError):
+        outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])
+
+    # Single SamplingParams should be applied to every prompt
+    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
+    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
+    assert len(PROMPTS) == len(outputs)
+
+    # sampling_params is None, default params should be applied
+    outputs = llm.generate(PROMPTS, sampling_params=None)
+    assert len(PROMPTS) == len(outputs)
+
+
+def test_chat():
+
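+    # basic end-to-end check of the chat API; this builds its own LLM
+    # (Meta-Llama-3-8B-Instruct) instead of using the module-level fixture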
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1

+ 67 - 0
tests/endpoints/llm/test_generate_multiple_loras.py

@@ -0,0 +1,67 @@
+import weakref
+
+import pytest
+# download a LoRA adapter to use in the LoRA request tests
+from huggingface_hub import snapshot_download
+
+from aphrodite import LLM
+from aphrodite.lora.request import LoRARequest
+
+from ...conftest import cleanup
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              tensor_parallel_size=1,
+              max_model_len=8192,
+              enable_lora=True,
+              max_loras=4,
+              max_lora_rank=64,
+              max_num_seqs=128,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+@pytest.fixture(scope="module")
+def zephyr_lora_files():
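+    # snapshot_download caches the adapter locally, so the Hub is only hit once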
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.mark.skip_global_cleanup
+def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
+    lora_request = [
+        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
+        for idx in range(len(PROMPTS))
+    ]
+    # Multiple LoRARequests should be matched one-to-one with the prompts
+    outputs = llm.generate(PROMPTS, lora_request=lora_request)
+    assert len(PROMPTS) == len(outputs)
+
+    # An exception is raised if the number of LoRA requests does not match the number of prompts
+    with pytest.raises(ValueError):
+        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])
+
+    # Single LoRARequest should be applied to every prompt
+    single_lora_request = lora_request[0]
+    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
+    assert len(PROMPTS) == len(outputs)

+ 142 - 0
tests/endpoints/llm/test_guided_generate.py

@@ -0,0 +1,142 @@
+import json
+import re
+import weakref
+
+import jsonschema
+import pytest
+
+from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.endpoints.llm import LLM
+
+from ...conftest import cleanup
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME, max_model_len=1024)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+        del llm
+    cleanup()
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_regex(sample_regex, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+    )
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_regex=sample_regex))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+        assert re.fullmatch(sample_regex, generated_text) is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_json_completion(sample_json_schema, llm):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+    )
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example JSON for an employee profile "
+            f"that fits this schema: {sample_json_schema}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_json=sample_json_schema))
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json, schema=sample_json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_choice_completion(sample_guided_choice, llm):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+    )
+    outputs = llm.generate(
+        prompts="The best language for type-safe systems programming is ",
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_choice=sample_guided_choice))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+        assert generated_text in sample_guided_choice
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_grammar(sample_sql_statements, llm):
+
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+    )
+    outputs = llm.generate(
+        prompts=("Generate a sql state that select col_1 from "
+                 "table_1 where it is equals to 1"),
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_grammar=sample_sql_statements))
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        # use Lark to parse the output, and make sure it's a valid parse tree
+        from lark import Lark
+        parser = Lark(sample_sql_statements)
+        parser.parse(generated_text)
+
+        # remove spaces for comparison because the grammar
+        # does not include whitespace
+        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
+            " ", "")
+
+        assert generated_text.strip() == ground_truth
+
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

+ 0 - 0
tests/endpoints/openai/__init__.py


+ 355 - 0
tests/endpoints/openai/test_audio.py

@@ -0,0 +1,355 @@
+import math
+import sys
+import time
+from typing import Dict, List, Optional, Tuple, Union, cast
+from unittest.mock import patch
+
+import librosa
+import numpy as np
+import openai
+import pytest
+import requests
+import torch
+
+from aphrodite import ModelRegistry
+from aphrodite.common.config import MultiModalConfig
+from aphrodite.common.utils import get_open_port
+from aphrodite.inputs import INPUT_REGISTRY
+from aphrodite.inputs.data import LLMInputs
+from aphrodite.inputs.registry import InputContext
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
+from aphrodite.modeling.models.opt import OPTForCausalLM
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal.base import MultiModalInputs
+from aphrodite.multimodal.image import (cached_get_tokenizer,
+                                        repeat_and_pad_image_tokens)
+from aphrodite.multimodal.utils import encode_audio_base64, fetch_audio
+
+from ...utils import APHRODITE_PATH
+
+chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
+assert chatml_jinja_path.exists()
+
+MODEL_NAME = "facebook/opt-125m"
+TEST_AUDIO_URLS = [
+    "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
+]
+
+
+def server_function(port):
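+    # register a fake audio-capable model on top of OPT so the audio chat
+    # endpoints can be exercised without a real multimodal checkpoint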
+
+    def fake_input_mapper(ctx: InputContext, data: object):
+        assert isinstance(data, tuple)
+        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
+
+        # Resample it to 1 sample per second
+        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
+        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
+
+    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or "audio" not in multi_modal_data:
+            return llm_inputs
+
+        audio, sr = multi_modal_data.get("audio")
+        audio_duration = math.ceil(len(audio) / sr)
+
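+        # pad the prompt with one placeholder token ("_", id 62)
+        # for each second of audio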
+        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+            cached_get_tokenizer(ctx.model_config.tokenizer),
+            llm_inputs.get("prompt"),
+            llm_inputs["prompt_token_ids"],
+            image_token_id=62,  # "_"
+            repeat_count=audio_duration)
+
+        return LLMInputs(prompt_token_ids=new_token_ids,
+                         prompt=new_prompt,
+                         multi_modal_data=multi_modal_data)
+
+    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
+    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+        "audio", lambda *_, **__: 100)
+    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
+    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
+
+        def __init__(self, *args, multimodal_config: MultiModalConfig,
+                     **kwargs):
+            assert multimodal_config is not None
+            super().__init__(*args, **kwargs)
+
+        def forward(
+            self,
+            *args,
+            processed_audio: Optional[torch.Tensor] = None,
+            **kwargs,
+        ) -> torch.Tensor:
+            return super().forward(*args, **kwargs)
+
+    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
+
+    with patch(
+            "aphrodite.endpoints.chat_utils._mm_token_str",
+            lambda *_, **__: "_"), patch(
+                "aphrodite.modeling.models.ModelRegistry.is_multimodal_model"
+            ) as mock:
+        mock.return_value = True
+        sys.argv = ["placeholder.py"] + \
+            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
+            "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
+            f"--port {port} --chat-template {chatml_jinja_path} "
+            "--disable-frontend-multiprocessing").split()
+        import runpy
+        runpy.run_module('aphrodite.endpoints.openai.api_server',
+                         run_name='__main__')
+
+
+@pytest.fixture(scope="module")
+def client():
+    port = get_open_port()
+    ctx = torch.multiprocessing.get_context("spawn")
+    server = ctx.Process(target=server_function, args=(port, ))
+    server.start()
+    MAX_SERVER_START_WAIT_S = 60
+    client = openai.AsyncOpenAI(
+        base_url=f"http://localhost:{port}/v1",
+        api_key="token-abc123",
+    )
+    # run health check
+    health_url = f"http://localhost:{port}/health"
+    start = time.time()
+    while True:
+        try:
+            if requests.get(health_url).status_code == 200:
+                break
+        except Exception as err:
+            result = server.exitcode
+            if result is not None:
+                raise RuntimeError("Server exited unexpectedly.") from err
+
+        # sleep/timeout outside the except block so a non-200 response
+        # does not busy-loop or wait forever
+        time.sleep(0.5)
+        if time.time() - start > MAX_SERVER_START_WAIT_S:
+            raise RuntimeError("Server failed to start in time.")
+
+    try:
+        yield client
+    finally:
+        server.kill()
+
+
+@pytest.fixture(scope="session")
+def base64_encoded_audio() -> Dict[str, str]:
+    return {
+        audio_url: encode_audio_base64(*fetch_audio(audio_url))
+        for audio_url in TEST_AUDIO_URLS
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
+                                         model_name: str, audio_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's happening in this audio?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+async def test_single_chat_session_audio_base64encoded(
+        client: openai.AsyncOpenAI, model_name: str, audio_url: str,
+        base64_encoded_audio: Dict[str, str]):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url":
+                    f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's happening in this audio?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
+                                    model_name: str, audio_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's happening in this audio?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+    )
+    output = chat_completion.choices[0].message.content
+    stop_reason = chat_completion.choices[0].finish_reason
+
+    # test streaming
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+    )
+    chunks: List[str] = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.role:
+            assert delta.role == "assistant"
+        if delta.content:
+            chunks.append(delta.content)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # the finish reason should only be returned in the last chunk
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == stop_reason
+    assert delta.content
+    assert "".join(chunks) == output
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
+async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
+                                 audio_url: str):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio_url
+                }
+            },
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's happening in this audio?"
+            },
+        ],
+    }]
+
+    with pytest.raises(openai.BadRequestError):  # test multi-audio input
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+        )
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0

+ 52 - 0
tests/endpoints/openai/test_basic.py

@@ -0,0 +1,52 @@
+from http import HTTPStatus
+
+import openai
+import pytest
+import requests
+
+from aphrodite.version import __version__ as APHRODITE_VERSION
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+async def test_show_version(client: openai.AsyncOpenAI):
+    base_url = str(client.base_url)[:-3].strip("/")
+
+    response = requests.get(base_url + "/version")
+    response.raise_for_status()
+
+    assert response.json() == {"version": APHRODITE_VERSION}
+
+
+@pytest.mark.asyncio
+async def test_check_health(client: openai.AsyncOpenAI):
+    base_url = str(client.base_url)[:-3].strip("/")
+
+    response = requests.get(base_url + "/health")
+
+    assert response.status_code == HTTPStatus.OK

+ 842 - 0
tests/endpoints/openai/test_chat.py

@@ -0,0 +1,842 @@
+# imports for guided decoding tests
+import json
+import re
+from typing import List
+
+import jsonschema
+import openai  # use the official client for correctness check
+import pytest
+import torch
+from openai import BadRequestError
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+# technically this needs Mistral-7B-v0.1 as base, but we're not testing
+# generation quality here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=False)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=0)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) == 5
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # Default max_logprobs is 20, so this should raise an error
+    with pytest.raises((openai.BadRequestError, openai.APIError)):
+        stream = await client.chat.completions.create(model=model_name,
+                                                      messages=messages,
+                                                      max_tokens=10,
+                                                      logprobs=True,
+                                                      top_logprobs=21,
+                                                      stream=True)
+        async for chunk in stream:
+            ...
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(model=model_name,
+                                             messages=messages,
+                                             max_tokens=10,
+                                             logprobs=True,
+                                             top_logprobs=30,
+                                             stream=False)
+
+    # the server should still work afterwards
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           stream=False)
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(client: openai.AsyncOpenAI,
+                                   model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert chat_completion.id is not None
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=37, total_tokens=47)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+    )
+    output = chat_completion.choices[0].message.content
+    stop_reason = chat_completion.choices[0].finish_reason
+
+    # test streaming
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+    )
+    chunks: List[str] = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.role:
+            assert delta.role == "assistant"
+        if delta.content:
+            chunks.append(delta.content)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # the finish reason should only be returned in the last chunk
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == stop_reason
+    assert delta.content
+    assert "".join(chunks) == output
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+)
+async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
+                                              model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "What is the capital of France?"
+    }]
+
+    # Test stream=True, stream_options={"include_usage": False}
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+        stream_options={"include_usage": False})
+    async for chunk in stream:
+        assert chunk.usage is None
+
+    # Test stream=True, stream_options={"include_usage": True,
+    #                                   "continuous_usage_stats": False}
+    stream = await client.chat.completions.create(model=model_name,
+                                                  messages=messages,
+                                                  max_tokens=10,
+                                                  temperature=0.0,
+                                                  stream=True,
+                                                  stream_options={
+                                                      "include_usage":
+                                                      True,
+                                                      "continuous_usage_stats":
+                                                      False
+                                                  })
+
+    async for chunk in stream:
+        if chunk.choices[0].finish_reason is None:
+            assert chunk.usage is None
+        else:
+            assert chunk.usage is None
+            final_chunk = await stream.__anext__()
+            assert final_chunk.usage is not None
+            assert final_chunk.usage.prompt_tokens > 0
+            assert final_chunk.usage.completion_tokens > 0
+            assert final_chunk.usage.total_tokens == (
+                final_chunk.usage.prompt_tokens +
+                final_chunk.usage.completion_tokens)
+            assert final_chunk.choices == []
+
+    # Test stream=False, stream_options={"include_usage": None}
+    with pytest.raises(BadRequestError):
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": None})
+
+    # Test stream=False, stream_options={"include_usage": True}
+    with pytest.raises(BadRequestError):
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+            stream=False,
+            stream_options={"include_usage": True})
+
+    # Test stream=True, stream_options={"include_usage": True,
+    #                           "continuous_usage_stats": True}
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+        stream_options={
+            "include_usage": True,
+            "continuous_usage_stats": True
+        },
+    )
+    async for chunk in stream:
+        assert chunk.usage.prompt_tokens >= 0
+        assert chunk.usage.completion_tokens >= 0
+        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
+                                            chunk.usage.completion_tokens)
+
+
+# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
+# (i.e. using the same ordering as in the Completions API tests), the test
+# will fail on the second `guided_decoding_backend` even when I swap their order
+# (ref: https://github.com/aphrodite-project/aphrodite/pull/5526#issuecomment-2173772256)
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat(client: openai.AsyncOpenAI,
+                                  guided_decoding_backend: str,
+                                  sample_guided_choice):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=sample_guided_choice,
+                        guided_decoding_backend=guided_decoding_backend))
+    choice1 = chat_completion.choices[0].message.content
+    assert choice1 in sample_guided_choice
+
+    messages.append({"role": "assistant", "content": choice1})
+    messages.append({
+        "role": "user",
+        "content": "I disagree, pick another one"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=sample_guided_choice,
+                        guided_decoding_backend=guided_decoding_backend))
+    choice2 = chat_completion.choices[0].message.content
+    assert choice2 in sample_guided_choice
+    assert choice1 != choice2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_json_chat(client: openai.AsyncOpenAI,
+                                guided_decoding_backend: str,
+                                sample_json_schema):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {sample_json_schema}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        extra_body=dict(guided_json=sample_json_schema,
+                        guided_decoding_backend=guided_decoding_backend))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json1 = json.loads(message.content)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)
+
+    messages.append({"role": "assistant", "content": message.content})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        extra_body=dict(guided_json=sample_json_schema,
+                        guided_decoding_backend=guided_decoding_backend))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json2 = json.loads(message.content)
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_chat(client: openai.AsyncOpenAI,
+                                 guided_decoding_backend: str, sample_regex):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example IP address with this regex: {sample_regex}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=sample_regex,
+                        guided_decoding_backend=guided_decoding_backend))
+    ip1 = chat_completion.choices[0].message.content
+    assert ip1 is not None
+    assert re.fullmatch(sample_regex, ip1) is not None
+
+    messages.append({"role": "assistant", "content": ip1})
+    messages.append({"role": "user", "content": "Give me a different one"})
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=sample_regex,
+                        guided_decoding_backend=guided_decoding_backend))
+    ip2 = chat_completion.choices[0].message.content
+    assert ip2 is not None
+    assert re.fullmatch(sample_regex, ip2) is not None
+    assert ip1 != ip2
+
+
+@pytest.mark.asyncio
+async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.chat.completions.create(model=MODEL_NAME,
+                                                 messages=messages,
+                                                 extra_body=dict(guided_regex={
+                                                     1: "Python",
+                                                     2: "C++"
+                                                 }))
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
+                                           guided_decoding_backend: str,
+                                           sample_guided_choice):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        logprobs=True,
+        top_logprobs=5,
+        extra_body=dict(guided_choice=sample_guided_choice,
+                        guided_decoding_backend=guided_decoding_backend))
+
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.content is not None
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
+
+    # -9999.0 is the minimum logprob returned by OpenAI
+    for item in top_logprobs:
+        assert item.logprob >= -9999.0, f"Failed (top_logprobs={top_logprobs})"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_named_tool_use(client: openai.AsyncOpenAI,
+                              guided_decoding_backend: str,
+                              sample_json_schema):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {sample_json_schema}"
+    }]
+
+    # non-streaming
+
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        tools=[{
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name",
+                "description": "This is a dummy function",
+                "parameters": sample_json_schema
+            }
+        }],
+        tool_choice={
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name"
+            }
+        })
+    message = chat_completion.choices[0].message
+    assert len(message.content) == 0
+    json_string = message.tool_calls[0].function.arguments
+    json1 = json.loads(json_string)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)
+
+    messages.append({"role": "assistant", "content": json_string})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+
+    # streaming
+
+    stream = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=1000,
+        tools=[{
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name",
+                "description": "This is a dummy function",
+                "parameters": sample_json_schema
+            }
+        }],
+        tool_choice={
+            "type": "function",
+            "function": {
+                "name": "dummy_function_name"
+            }
+        },
+        stream=True)
+
+    output = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.role:
+            assert delta.role == "assistant"
+        assert delta.content is None or len(delta.content) == 0
+        if delta.tool_calls:
+            output.append(delta.tool_calls[0].function.arguments)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # the finish reason should only be returned in the last chunk
+    assert finish_reason_count == 1
+    json2 = json.loads("".join(output))
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
+async def test_required_tool_use_not_yet_supported(
+        client: openai.AsyncOpenAI, guided_decoding_backend: str,
+        sample_json_schema):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {sample_json_schema}"
+    }]
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": sample_json_schema
+                }
+            }],
+            tool_choice="required")
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": sample_json_schema
+                }
+            }],
+            tool_choice="auto")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
+async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
+                                                  guided_decoding_backend: str,
+                                                  sample_json_schema):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example JSON for an employee profile that "
+        f"fits this schema: {sample_json_schema}"
+    }]
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(model=MODEL_NAME,
+                                             messages=messages,
+                                             max_tokens=1000,
+                                             tool_choice={
+                                                 "type": "function",
+                                                 "function": {
+                                                     "name":
+                                                     "dummy_function_name"
+                                                 }
+                                             })
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=1000,
+            tools=[{
+                "type": "function",
+                "function": {
+                    "name": "dummy_function_name",
+                    "description": "This is a dummy function",
+                    "parameters": sample_json_schema
+                }
+            }],
+            tool_choice={
+                "type": "function",
+                "function": {
+                    "name": "nondefined_function_name"
+                }
+            })
+
+
+@pytest.mark.asyncio
+async def test_response_format_json_object(client: openai.AsyncOpenAI):
+    for _ in range(2):
+        resp = await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role":
+                "user",
+                "content": ('what is 1+1? please respond with a JSON object, '
+                            'the format is {"result": 2}')
+            }],
+            response_format={"type": "json_object"})
+
+        content = resp.choices[0].message.content
+        assert content is not None
+
+        loaded = json.loads(content)
+        assert loaded == {"result": 2}, loaded
+
+
+@pytest.mark.asyncio
+async def test_extra_fields(client: openai.AsyncOpenAI):
+    with pytest.raises(BadRequestError) as exc_info:
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant.",
+                "extra_field": "0",
+            }],  # type: ignore
+            temperature=0,
+            seed=0)
+
+    assert "extra_forbidden" in exc_info.value.message
+
+
+@pytest.mark.asyncio
+async def test_complex_message_content(client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": [{
+                "type":
+                "text",
+                "text":
+                "what is 1+1? please provide the result without any other text."
+            }]
+        }],
+        temperature=0,
+        seed=0)
+    content = resp.choices[0].message.content
+    assert content == "2"
+
+
+@pytest.mark.asyncio
+async def test_custom_role(client: openai.AsyncOpenAI):
+    # Not sure how the model handles custom roles, so we just check that
+    # both string and complex message content are handled in the same way.
+
+    resp1 = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "my-custom-role",
+            "content": "what is 1+1?",
+        }],  # type: ignore
+        temperature=0,
+        seed=0)
+
+    resp2 = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "my-custom-role",
+            "content": [{
+                "type": "text",
+                "text": "what is 1+1?"
+            }]
+        }],  # type: ignore
+        temperature=0,
+        seed=0)
+
+    content1 = resp1.choices[0].message.content
+    content2 = resp2.choices[0].message.content
+    assert content1 == content2
+
+
+@pytest.mark.asyncio
+async def test_long_seed(client: openai.AsyncOpenAI):
+    for seed in [
+            torch.iinfo(torch.long).min - 1,
+            torch.iinfo(torch.long).max + 1
+    ]:
+        with pytest.raises(BadRequestError) as exc_info:
+            await client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[{
+                    "role": "system",
+                    "content": "You are a helpful assistant.",
+                }],
+                temperature=0,
+                seed=seed)
+
+        assert ("greater_than_equal" in exc_info.value.message
+                or "less_than_equal" in exc_info.value.message)

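For reference, test_long_seed above probes seeds just outside the signed 64-bit range; a minimal offline sketch (assuming only that the server validates `seed` against torch's int64 bounds, as the greater_than_equal / less_than_equal validation messages suggest) of the values being exercised:

import torch

# torch.long is int64, so these are the bounds the server appears to
# validate `seed` against.
int64 = torch.iinfo(torch.long)
assert int64.min == -(2**63)
assert int64.max == 2**63 - 1

# The test sends seeds one step outside this range and expects a 400 error.
out_of_range_seeds = [int64.min - 1, int64.max + 1]
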
+ 832 - 0
tests/endpoints/openai/test_completion.py

@@ -0,0 +1,832 @@
+# imports for guided decoding tests
+import json
+import re
+import shutil
+from tempfile import TemporaryDirectory
+from typing import Dict, List
+
+import jsonschema
+import openai  # use the official client for correctness check
+import pytest
+# downloading lora to test lora requests
+from huggingface_hub import snapshot_download
+from openai import BadRequestError
+from transformers import AutoTokenizer
+
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+# technically these adapters use a different base model,
+# but we're not testing generation quality here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+PA_NAME = "swapnilbp/llama_tweet_ptune"
+# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
+# need to change to match the prompt adapter
+PA_NUM_VIRTUAL_TOKENS = 8
+
+
+@pytest.fixture(scope="module")
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="module")
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["aphrodite1", "aphrodite2", "aphrodite3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
+
+
+@pytest.fixture(scope="module")
+def zephyr_pa_files():
+    return snapshot_download(repo_id=PA_NAME)
+
+
+@pytest.fixture(scope="module")
+def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
+                        zephyr_pa_files):
+    return [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--max-num-seqs",
+        "128",
+        "--enforce-eager",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        # pa config
+        "--enable-prompt-adapter",
+        "--prompt-adapters",
+        f"zephyr-pa={zephyr_pa_files}",
+        f"zephyr-pa2={zephyr_pa_files}",
+        "--max-prompt-adapters",
+        "2",
+        "--max-prompt-adapter-token",
+        "128",
+    ]
+
+
+@pytest.fixture(scope="module",
+                params=["", "--disable-frontend-multiprocessing"])
+def client(default_server_args, request):
+    if request.param:
+        default_server_args.append(request.param)
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
+        yield remote_server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras, then test prompt adapters
+    "model_name,num_virtual_tokens",
+    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
+     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
+     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
+)
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
+                                 num_virtual_tokens: int):
+    completion = await client.completions.create(model=model_name,
+                                                 prompt="Hello, my name is",
+                                                 max_tokens=5,
+                                                 temperature=0.0)
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 1
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
+    assert completion.usage == openai.types.CompletionUsage(
+        completion_tokens=5,
+        prompt_tokens=6 + num_virtual_tokens,
+        total_tokens=11 + num_virtual_tokens)
+
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model="zephyr-lora2",
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should appear in tokenized prompt
+    assert completion.choices[0].text.startswith(
+        "<unk><unk>aphrodite1aphrodite2aphrodite3")
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should not appear in tokenized prompt
+    assert "aphrodite" not in completion.choices[0].text
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras, then test prompt adapters
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
+)
+async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=None,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora and 1 pa hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=0,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=5,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
+                                            model_name: str):
+
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        await client.completions.create(
+            model=model_name,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            # Aphrodite has higher default max_logprobs (20 instead of 5)
+            # to support both Completion API and Chat Completion API
+            logprobs=21,
+        )
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        stream = await client.completions.create(
+            model=model_name,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            # Aphrodite has higher default max_logprobs (20 instead of 5)
+            # to support both Completion API and Chat Completion API
+            logprobs=30,
+            stream=True,
+        )
+        async for chunk in stream:
+            ...
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str, prompt_logprobs: int):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.chat.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
+                                                         (MODEL_NAME, 0),
+                                                         (MODEL_NAME, 1),
+                                                         (MODEL_NAME, None)])
+async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
+                                          model_name: str,
+                                          prompt_logprobs: int):
+    params: Dict = {
+        "prompt": ["A robot may not injure another robot", "My name is"],
+        "model": model_name,
+    }
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError) as err_info:
+            await client.completions.create(**params)
+        expected_err_string = (
+            "Error code: 400 - {'object': 'error', 'message': "
+            "'Prompt_logprobs set to invalid negative value: -1',"
+            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
+        assert str(err_info.value) == expected_err_string
+    else:
+        completion = await client.completions.create(**params)
+        if prompt_logprobs and prompt_logprobs > 0:
+            assert completion.choices[0].prompt_logprobs is not None
+            assert len(completion.choices[0].prompt_logprobs) > 0
+
+            assert completion.choices[1].prompt_logprobs is not None
+            assert len(completion.choices[1].prompt_logprobs) > 0
+
+        else:
+            assert completion.choices[0].prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_completion_streaming(client: openai.AsyncOpenAI,
+                                    model_name: str):
+    prompt = "What is an LLM?"
+
+    single_completion = await client.completions.create(
+        model=model_name,
+        prompt=prompt,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    single_output = single_completion.choices[0].text
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True)
+    chunks: List[str] = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        chunks.append(chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # the finish reason should only be returned in the last chunk
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
+    assert "".join(chunks) == single_output
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_completion_stream_options(client: openai.AsyncOpenAI,
+                                         model_name: str):
+    prompt = "What is the capital of France?"
+
+    # Test stream=True, stream_options=
+    #     {"include_usage": False, "continuous_usage_stats": False}
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True,
+                                             stream_options={
+                                                 "include_usage": False,
+                                                 "continuous_usage_stats":
+                                                 False,
+                                             })
+
+    async for chunk in stream:
+        assert chunk.usage is None
+
+    # Test stream=True, stream_options=
+    #     {"include_usage": False, "continuous_usage_stats": True}
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True,
+                                             stream_options={
+                                                 "include_usage": False,
+                                                 "continuous_usage_stats":
+                                                 True,
+                                             })
+    async for chunk in stream:
+        assert chunk.usage is None
+
+    # Test stream=True, stream_options=
+    #     {"include_usage": True, "continuous_usage_stats": False}
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True,
+                                             stream_options={
+                                                 "include_usage": True,
+                                                 "continuous_usage_stats":
+                                                 False,
+                                             })
+    async for chunk in stream:
+        if chunk.choices[0].finish_reason is None:
+            assert chunk.usage is None
+        else:
+            assert chunk.usage is None
+            final_chunk = await stream.__anext__()
+            assert final_chunk.usage is not None
+            assert final_chunk.usage.prompt_tokens > 0
+            assert final_chunk.usage.completion_tokens > 0
+            assert final_chunk.usage.total_tokens == (
+                final_chunk.usage.prompt_tokens +
+                final_chunk.usage.completion_tokens)
+            assert final_chunk.choices == []
+
+    # Test stream=True, stream_options=
+    #     {"include_usage": True, "continuous_usage_stats": True}
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True,
+                                             stream_options={
+                                                 "include_usage": True,
+                                                 "continuous_usage_stats":
+                                                 True,
+                                             })
+    async for chunk in stream:
+        assert chunk.usage is not None
+        assert chunk.usage.prompt_tokens > 0
+        assert chunk.usage.completion_tokens > 0
+        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
+                                            chunk.usage.completion_tokens)
+        if chunk.choices[0].finish_reason is not None:
+            final_chunk = await stream.__anext__()
+            assert final_chunk.usage is not None
+            assert final_chunk.usage.prompt_tokens > 0
+            assert final_chunk.usage.completion_tokens > 0
+            assert final_chunk.usage.total_tokens == (
+                final_chunk.usage.prompt_tokens +
+                final_chunk.usage.completion_tokens)
+            assert final_chunk.choices == []
+
+    # Test stream=False, stream_options=
+    #     {"include_usage": None}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(model=model_name,
+                                        prompt=prompt,
+                                        max_tokens=5,
+                                        temperature=0.0,
+                                        stream=False,
+                                        stream_options={"include_usage": None})
+
+    # Test stream=False, stream_options=
+    #    {"include_usage": True}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(model=model_name,
+                                        prompt=prompt,
+                                        max_tokens=5,
+                                        temperature=0.0,
+                                        stream=False,
+                                        stream_options={"include_usage": True})
+
+    # Test stream=False, stream_options=
+    #     {"continuous_usage_stats": None}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=5,
+            temperature=0.0,
+            stream=False,
+            stream_options={"continuous_usage_stats": None})
+
+    # Test stream=False, stream_options=
+    #    {"continuous_usage_stats": True}
+    with pytest.raises(BadRequestError):
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=5,
+            temperature=0.0,
+            stream=False,
+            stream_options={"continuous_usage_stats": True})
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
+    # test both text and token IDs
+    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
+        # test simple list
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+        )
+        assert len(batch.choices) == 2
+        assert batch.choices[0].text == batch.choices[1].text
+
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in Aphrodite, but
+                # is not necessary for the official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+
+        # test streaming
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+            stream=True,
+        )
+        texts = [""] * 2
+        async for chunk in batch:
+            assert len(chunk.choices) == 1
+            choice = chunk.choices[0]
+            texts[choice.index] += choice.text
+        assert texts[0] == texts[1]
+
+
+@pytest.mark.asyncio
+async def test_logits_bias(client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+        seed=42,
+    )
+    assert len(completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
+@pytest.mark.asyncio
+async def test_allowed_token_ids(client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 1
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    allowed_ids = [21555, 21557, 21558]
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        seed=42,
+        extra_body=dict(allowed_token_ids=allowed_ids),
+        logprobs=1,
+    )
+    response_tokens = completion.choices[0].logprobs.tokens
+    assert len(response_tokens) == 1
+    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_json_completion(client: openai.AsyncOpenAI,
+                                      guided_decoding_backend: str,
+                                      sample_json_schema):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=f"Give an example JSON for an employee profile "
+        f"that fits this schema: {sample_json_schema}",
+        n=3,
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_json=sample_json_schema,
+                        guided_decoding_backend=guided_decoding_backend))
+
+    assert completion.id is not None
+    assert len(completion.choices) == 3
+    for i in range(3):
+        output_json = json.loads(completion.choices[i].text)
+        jsonschema.validate(instance=output_json, schema=sample_json_schema)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_completion(client: openai.AsyncOpenAI,
+                                       guided_decoding_backend: str,
+                                       sample_regex):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
+        n=3,
+        temperature=1.0,
+        max_tokens=20,
+        extra_body=dict(guided_regex=sample_regex,
+                        guided_decoding_backend=guided_decoding_backend))
+
+    assert completion.id is not None
+    assert len(completion.choices) == 3
+    for i in range(3):
+        assert re.fullmatch(sample_regex,
+                            completion.choices[i].text) is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_completion(client: openai.AsyncOpenAI,
+                                        guided_decoding_backend: str,
+                                        sample_guided_choice):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt="The best language for type-safe systems programming is ",
+        n=2,
+        temperature=1.0,
+        max_tokens=10,
+        extra_body=dict(guided_choice=sample_guided_choice,
+                        guided_decoding_backend=guided_decoding_backend))
+
+    assert completion.id is not None
+    assert len(completion.choices) == 2
+    for i in range(2):
+        assert completion.choices[i].text in sample_guided_choice
+
+
+@pytest.mark.asyncio
+async def test_guided_grammar(client: openai.AsyncOpenAI,
+                              sample_sql_statements):
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=("Generate a sql state that select col_1 from "
+                "table_1 where it is equals to 1"),
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_grammar=sample_sql_statements))
+
+    content = completion.choices[0].text
+
+    # use Lark to parse the output, and make sure it's a valid parse tree
+    from lark import Lark
+    parser = Lark(sample_sql_statements)
+    parser.parse(content)
+
+    # remove spaces for comparison b/c we removed them in the grammar
+    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
+
+    assert content.strip() == ground_truth
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+@pytest.mark.parametrize("logprobs_arg", [1, 0])
+async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
+                                       model_name: str, logprobs_arg: int):
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # test using text and token IDs
+    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+        completion = await client.completions.create(model=model_name,
+                                                     prompt=prompt,
+                                                     max_tokens=5,
+                                                     temperature=0.0,
+                                                     echo=True,
+                                                     logprobs=logprobs_arg)
+
+        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+                                                             list) else prompt
+        assert re.search(r"^" + prompt_text, completion.choices[0].text)
+        logprobs = completion.choices[0].logprobs
+        assert logprobs is not None
+        assert len(logprobs.text_offset) > 5
+        assert (len(logprobs.token_logprobs) > 5
+                and logprobs.token_logprobs[0] is None)
+        assert (len(logprobs.top_logprobs) > 5
+                and logprobs.top_logprobs[0] is None)
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
+        assert len(logprobs.tokens) > 5
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
+                                          guided_decoding_backend: str,
+                                          sample_json_schema, sample_regex):
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example JSON that fits this schema: 42",
+            extra_body=dict(guided_json=42,
+                            guided_decoding_backend=guided_decoding_backend))
+
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example string that fits this regex",
+            extra_body=dict(guided_regex=sample_regex,
+                            guided_json=sample_json_schema))

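For reference, the guided-decoding tests above pass Aphrodite-specific options through the official client's `extra_body`, which the OpenAI Python client merges into the JSON request body. A minimal sketch of that request shape, assuming a server is already running at the (hypothetical) base URL below:

import asyncio

import openai


async def main():
    # Hypothetical endpoint; in the tests, RemoteOpenAIServer provides it.
    client = openai.AsyncOpenAI(base_url="http://localhost:2242/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="Give an example IPv4 address: ",
        max_tokens=20,
        temperature=1.0,
        # Fields the OpenAI API does not define ride in extra_body and are
        # forwarded to the server as top-level request parameters.
        extra_body=dict(guided_regex=r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}",
                        guided_decoding_backend="outlines"),
    )
    print(completion.choices[0].text)


asyncio.run(main())
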
+ 136 - 0
tests/endpoints/openai/test_embedding.py

@@ -0,0 +1,136 @@
+import base64
+
+import numpy as np
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+
+
+@pytest.fixture(scope="module")
+def embedding_server():
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        "8192",
+    ]
+
+    with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def embedding_client(embedding_server):
+    return embedding_server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
+                                model_name: str):
+    input_texts = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    # test single embedding
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 9
+    assert embeddings.usage.total_tokens == 9
+
+    # test using token IDs
+    input_tokens = [1, 1, 1, 1, 1]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 5
+    assert embeddings.usage.total_tokens == 5
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
+                               model_name: str):
+    # test List[str]
+    input_texts = [
+        "The cat sat on the mat.", "A feline was resting on a rug.",
+        "Stars twinkle brightly in the night sky."
+    ]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 3
+    assert len(embeddings.data[0].embedding) == 4096
+
+    # test List[List[int]]
+    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+                    [25, 32, 64, 77]]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 4
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 17
+    assert embeddings.usage.total_tokens == 17
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+                                      model_name: str):
+    input_texts = [
+        "Hello my name is",
+        "The best thing about Aphrodite is that it supports many different models"  # noqa: E501
+    ]
+
+    responses_float = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float")
+
+    responses_base64 = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="base64")
+
+    decoded_responses_base64_data = []
+    for data in responses_base64.data:
+        decoded_responses_base64_data.append(
+            np.frombuffer(base64.b64decode(data.embedding),
+                          dtype="float").tolist())
+
+    assert responses_float.data[0].embedding == decoded_responses_base64_data[
+        0]
+    assert responses_float.data[1].embedding == decoded_responses_base64_data[
+        1]

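For reference, the base64 test above decodes each embedding with `np.frombuffer(..., dtype="float")`, i.e. float64. A small offline sketch of the encode/decode symmetry that comparison relies on (the decoding dtype is assumed to match whatever the server used when packing the bytes):

import base64

import numpy as np

# NumPy's "float" dtype is float64; decoding only round-trips cleanly if the
# same dtype was used when the vector was packed (an assumption here).
vector = [0.25, -1.5, 3.0]
packed = base64.b64encode(np.array(vector, dtype="float").tobytes())
unpacked = np.frombuffer(base64.b64decode(packed), dtype="float").tolist()
assert unpacked == vector
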
+ 50 - 0
tests/endpoints/openai/test_encoder_decoder.py

@@ -0,0 +1,50 @@
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "facebook/bart-base"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+    completion = await client.completions.create(model=model_name,
+                                                 prompt="Hello, my name is",
+                                                 max_tokens=5,
+                                                 temperature=0.0)
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 1
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
+    assert completion.usage == openai.types.CompletionUsage(
+        completion_tokens=5, prompt_tokens=2, total_tokens=7)
+
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 1

+ 72 - 0
tests/endpoints/openai/test_guided_processors.py

@@ -0,0 +1,72 @@
+# This unit test should be moved to a new
+# tests/test_guided_decoding directory.
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from aphrodite.endpoints.openai.protocol import CompletionRequest
+from aphrodite.modeling.guided_decoding import (
+    get_guided_decoding_logits_processor)
+from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
+    JSONLogitsProcessor, RegexLogitsProcessor)
+
+
+def test_guided_logits_processors(sample_regex, sample_json_schema):
+    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
+    json_LP = JSONLogitsProcessor(sample_json_schema,
+                                  tokenizer,
+                                  whitespace_pattern=None)
+
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    regex_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    json_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
+async def test_guided_logits_processor_black_box(backend: str, sample_regex,
+                                                 sample_json_schema):
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}")
+    regex_request = CompletionRequest(model='test',
+                                      prompt=token_ids,
+                                      guided_regex=sample_regex)
+    regex_lp = await get_guided_decoding_logits_processor(
+        backend, regex_request, tokenizer)
+    assert regex_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
+    json_request = CompletionRequest(model='test',
+                                     prompt=token_ids,
+                                     guided_json=sample_json_schema)
+    json_lp = await get_guided_decoding_logits_processor(
+        backend, json_request, tokenizer)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)

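For reference, both tests above treat a guided-decoding logits processor as a callable of the form `processor(token_ids, logits) -> logits` that reshapes the score distribution. A toy sketch of that call shape (not Aphrodite's implementation, only an illustration of the interface the assertions exercise):

from typing import List

import torch


class BanTokenLogitsProcessor:
    """Toy processor with the same call signature: masks a single token id."""

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # Real guided-decoding processors consult an FSM built from the
        # regex/schema; here we only demonstrate the tensor update.
        logits[self.banned_token_id] = float("-inf")
        return logits


processor = BanTokenLogitsProcessor(banned_token_id=7)
scores = torch.rand(32000)
original = scores.clone()
scores = processor([1, 2, 3], scores)
assert scores.shape == original.shape
assert not torch.allclose(scores, original)
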
+ 179 - 0
tests/endpoints/openai/test_metrics.py

@@ -0,0 +1,179 @@
+from http import HTTPStatus
+
+import openai
+import pytest
+import requests
+from prometheus_client.parser import text_string_to_metric_families
+from transformers import AutoTokenizer
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+
+@pytest.fixture(scope="module")
+def default_server_args():
+    return [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "1024",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+    ]
+
+
+@pytest.fixture(scope="module",
+                params=[
+                    "",
+                    "--enable-chunked-prefill",
+                    "--disable-frontend-multiprocessing",
+                ])
+def client(default_server_args, request):
+    if request.param:
+        default_server_args.append(request.param)
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
+        yield remote_server.get_async_client()
+
+
+_PROMPT = "Hello my name is Robert and I love magic"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
+
+_NUM_REQUESTS = 10
+_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
+_NUM_GENERATION_TOKENS_PER_REQUEST = 10
+
+# {metric_family: [(suffix, expected_value)]}
+EXPECTED_VALUES = {
+    "aphrodite:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
+    "aphrodite:time_per_output_token_seconds":
+    [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
+    "aphrodite:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
+    "aphrodite:request_prompt_tokens":
+    [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
+     ("_count", _NUM_REQUESTS)],
+    "aphrodite:request_generation_tokens":
+    [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
+     ("_count", _NUM_REQUESTS)],
+    "aphrodite:request_params_n": [("_count", _NUM_REQUESTS)],
+    "aphrodite:request_params_best_of": [("_count", _NUM_REQUESTS)],
+    "aphrodite:prompt_tokens": [("_total",
+                            _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
+    "aphrodite:generation_tokens":
+    [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
+    "aphrodite:request_success": [("_total", _NUM_REQUESTS)],
+}
+
+
+@pytest.mark.asyncio
+async def test_metrics_counts(client: openai.AsyncOpenAI):
+    base_url = str(client.base_url)[:-3].strip("/")
+
+    for _ in range(_NUM_REQUESTS):
+        # sending a request triggers the metrics to be logged.
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=_TOKENIZED_PROMPT,
+            max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
+
+    response = requests.get(base_url + "/metrics")
+    print(response.text)
+    assert response.status_code == HTTPStatus.OK
+
+    # Loop over all expected metric_families
+    for metric_family, suffix_values_list in EXPECTED_VALUES.items():
+        found_metric = False
+
+        # Check to see if the metric_family is found in the prom endpoint.
+        for family in text_string_to_metric_families(response.text):
+            if family.name == metric_family:
+                found_metric = True
+
+                # Check that each suffix is found in the prom endpoint.
+                for suffix, expected_value in suffix_values_list:
+                    metric_name_w_suffix = f"{metric_family}{suffix}"
+                    found_suffix = False
+
+                    for sample in family.samples:
+                        if sample.name == metric_name_w_suffix:
+                            found_suffix = True
+
+                            # For each suffix, make sure the value matches
+                            # what we expect.
+                            assert sample.value == expected_value, (
+                                f"{metric_name_w_suffix} expected value of "
+                                f"{expected_value} did not match found value "
+                                f"{sample.value}")
+                            break
+                    assert found_suffix, (
+                        f"Did not find {metric_name_w_suffix} in prom endpoint"
+                    )
+                break
+
+        assert found_metric, (f"Did not find {metric_family} in prom endpoint")
+
+
+EXPECTED_METRICS = [
+    "aphrodite:num_requests_running",
+    "aphrodite:num_requests_swapped",
+    "aphrodite:num_requests_waiting",
+    "aphrodite:gpu_cache_usage_perc",
+    "aphrodite:cpu_cache_usage_perc",
+    "aphrodite:time_to_first_token_seconds_sum",
+    "aphrodite:time_to_first_token_seconds_bucket",
+    "aphrodite:time_to_first_token_seconds_count",
+    "aphrodite:time_per_output_token_seconds_sum",
+    "aphrodite:time_per_output_token_seconds_bucket",
+    "aphrodite:time_per_output_token_seconds_count",
+    "aphrodite:e2e_request_latency_seconds_sum",
+    "aphrodite:e2e_request_latency_seconds_bucket",
+    "aphrodite:e2e_request_latency_seconds_count",
+    "aphrodite:request_prompt_tokens_sum",
+    "aphrodite:request_prompt_tokens_bucket",
+    "aphrodite:request_prompt_tokens_count",
+    "aphrodite:request_generation_tokens_sum",
+    "aphrodite:request_generation_tokens_bucket",
+    "aphrodite:request_generation_tokens_count",
+    "aphrodite:request_params_n_sum",
+    "aphrodite:request_params_n_bucket",
+    "aphrodite:request_params_n_count",
+    "aphrodite:request_params_best_of_sum",
+    "aphrodite:request_params_best_of_bucket",
+    "aphrodite:request_params_best_of_count",
+    "aphrodite:num_preemptions_total",
+    "aphrodite:prompt_tokens_total",
+    "aphrodite:generation_tokens_total",
+    "aphrodite:request_success_total",
+    "aphrodite:cache_config_info",
+    # labels in cache_config_info
+    "block_size",
+    "cache_dtype",
+    "cpu_offload_gb",
+    "enable_prefix_caching",
+    "gpu_memory_utilization",
+    "num_cpu_blocks",
+    "num_gpu_blocks",
+    "num_gpu_blocks_override",
+    "sliding_window",
+    "swap_space_bytes",
+]
+
+
+@pytest.mark.asyncio
+async def test_metrics_exist(client: openai.AsyncOpenAI):
+    base_url = str(client.base_url)[:-3].strip("/")
+
+    # sending a request triggers the metrics to be logged.
+    await client.completions.create(model=MODEL_NAME,
+                                    prompt="Hello, my name is",
+                                    max_tokens=5,
+                                    temperature=0.0)
+
+    response = requests.get(base_url + "/metrics")
+    assert response.status_code == HTTPStatus.OK
+
+    for metric in EXPECTED_METRICS:
+        assert metric in response.text

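For reference, test_metrics_counts walks the Prometheus text exposition returned by /metrics with prometheus_client's parser. A small offline sketch of that parsing API on made-up gauge text (the exposition below is illustrative, not actual server output):

from prometheus_client.parser import text_string_to_metric_families

# Made-up exposition text in the same shape as a gauge from /metrics.
metrics_text = """\
# HELP aphrodite:num_requests_running Number of requests currently running.
# TYPE aphrodite:num_requests_running gauge
aphrodite:num_requests_running{model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"} 1.0
"""

for family in text_string_to_metric_families(metrics_text):
    assert family.name == "aphrodite:num_requests_running"
    for sample in family.samples:
        assert sample.value == 1.0
        assert sample.labels["model_name"].startswith("TinyLlama")
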
+ 60 - 0
tests/endpoints/openai/test_models.py

@@ -0,0 +1,60 @@
+import openai  # use the official client for correctness check
+import pytest
+# downloading lora to test lora requests
+from huggingface_hub import snapshot_download
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+# technically this needs Mistral-7B-v0.1 as base, but we're not testing
+# generation quality here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+
+
+@pytest.fixture(scope="module")
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files):
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+async def test_check_models(client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    lora_models = models[1:]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+    assert lora_models[0].id == "zephyr-lora"
+    assert lora_models[1].id == "zephyr-lora2"

+ 38 - 0
tests/endpoints/openai/test_mp_api_server.py

@@ -0,0 +1,38 @@
+import pytest
+
+from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.endpoints.openai.api_server import build_async_engine_client
+from aphrodite.endpoints.openai.args import make_arg_parser
+
+
+@pytest.mark.asyncio
+async def test_mp_crash_detection():
+
+    with pytest.raises(RuntimeError) as excinfo:
+        parser = FlexibleArgumentParser(
+            description="Aphrodite's remote OpenAI server.")
+        parser = make_arg_parser(parser)
+        args = parser.parse_args([])
+        # use an invalid tensor_parallel_size to trigger the
+        # error in the server
+        args.tensor_parallel_size = 65536
+
+        async with build_async_engine_client(args):
+            pass
+    assert "The server process died before responding to the readiness probe"\
+          in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+async def test_mp_cuda_init():
+    # it should not crash, when cuda is initialized
+    # in the API server process
+    import torch
+    torch.cuda.init()
+    parser = FlexibleArgumentParser(
+        description="Aphrodite's remote OpenAI server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args([])
+
+    async with build_async_engine_client(args):
+        pass

+ 42 - 0
tests/endpoints/openai/test_oot_registration.py

@@ -0,0 +1,42 @@
+from ...utils import APHRODITE_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
+assert chatml_jinja_path.exists()
+
+
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--chat-template",
+        str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
+    ]
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
+        )
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+
+
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)

+ 83 - 0
tests/endpoints/openai/test_return_tokens_as_ids.py

@@ -0,0 +1,83 @@
+# Separate these tests out from test_completion and test_chat, because they
+# require launching a second server with a different flag. Running both servers
+# at the same time on a single node will OOM.
+
+import pytest
+
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import default_server_args  # noqa: F401
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+from .test_completion import zephyr_pa_files  # noqa: F401
+from .test_completion import MODEL_NAME
+
+
+@pytest.fixture(scope="module")
+def server_with_return_tokens_as_token_ids_flag(
+        default_server_args):  # noqa: F811
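+    # --return-tokens-as-token-ids makes the server report each token as a
+    # "token_id:<int>" string (stripped via removeprefix in the tests below).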
+    args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
+    with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+async def test_completion_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: 🎉 is [28705, 31862] for the
+        # Zephyr tokenizer
+        prompt="Say 'Hello, world! 🎉'",
+        echo=True,
+        temperature=0,
+        max_tokens=10,
+        logprobs=1)
+
+    text = completion.choices[0].text
+    token_strs = completion.choices[0].logprobs.tokens
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    # Check that the token representations are consistent between raw tokens
+    # and top_logprobs
+    # Slice off the first one, because there's no scoring associated with BOS
+    top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
+    top_logprob_keys = [
+        next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
+    ]
+    assert token_strs[1:] == top_logprob_keys
+
+    # Check that decoding the tokens gives the expected text
+    tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
+    assert text == tokenizer.decode(tokens, skip_special_tokens=True)
+
+
+@pytest.mark.asyncio
+async def test_chat_return_tokens_as_token_ids_completion(
+        server_with_return_tokens_as_token_ids_flag):
+    client = server_with_return_tokens_as_token_ids_flag.get_async_client()
+    response = await client.chat.completions.create(
+        model=MODEL_NAME,
+        # Include Unicode characters to test for dividing a single
+        # character across multiple tokens: 🎉 is [28705, 31862] for the
+        # Zephyr tokenizer
+        messages=[{
+            "role": "system",
+            "content": "You like to respond in only emojis, like 🎉"
+        }, {
+            "role": "user",
+            "content": "Please write some emojis: 🐱🐶🎉"
+        }],
+        temperature=0,
+        max_tokens=8,
+        logprobs=True)
+
+    text = response.choices[0].message.content
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+    token_ids = []
+    for logprob_content in response.choices[0].logprobs.content:
+        token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
+    assert tokenizer.decode(token_ids, skip_special_tokens=True) == text

+ 102 - 0
tests/endpoints/openai/test_run_batch.py

@@ -0,0 +1,102 @@
+import subprocess
+import sys
+import tempfile
+
+from aphrodite.endpoints.openai.protocol import BatchRequestOutput
+
+# ruff: noqa: E501
+INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
+
+{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
+
+INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
+
+INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
+
+{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}
+{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
+
+
+def test_empty_file():
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write("")
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "aphrodite.endpoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "intfloat/e5-mistral-7b-instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        assert contents.strip() == ""
+
+
+def test_completions():
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write(INPUT_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "aphrodite.endpoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "NousResearch/Meta-Llama-3-8B-Instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        for line in contents.strip().split("\n"):
+            # Ensure that the output format conforms to the openai api.
+            # Validation should throw if the schema is wrong.
+            BatchRequestOutput.model_validate_json(line)
+
+
+def test_completions_invalid_input():
+    """
+    Ensure that we fail when the input doesn't conform to the openai api.
+    """
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write(INVALID_INPUT_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "aphrodite.endpoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "NousResearch/Meta-Llama-3-8B-Instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode != 0, f"{proc=}"
+
+
+def test_embeddings():
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write(INPUT_EMBEDDING_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "aphrodite.endpoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "intfloat/e5-mistral-7b-instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        for line in contents.strip().split("\n"):
+            # Ensure that the output format conforms to the openai api.
+            # Validation should throw if the schema is wrong.
+            BatchRequestOutput.model_validate_json(line)
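
The batch runner exercised above can also be invoked directly; the module and
flags below are taken verbatim from the subprocess calls in these tests, while
the input and output paths are placeholders:

    python -m aphrodite.endpoints.openai.run_batch \
        -i batch_input.jsonl -o batch_output.jsonl \
        --model NousResearch/Meta-Llama-3-8B-Instruct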

+ 82 - 0
tests/endpoints/openai/test_serving_chat.py

@@ -0,0 +1,82 @@
+import asyncio
+from contextlib import suppress
+from dataclasses import dataclass
+from unittest.mock import MagicMock
+
+from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
+from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+    embedding_mode = False
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig()
+
+
+async def _async_serving_chat_init():
+    engine = MockEngine()
+    model_config = await engine.get_model_config()
+
+    serving_completion = OpenAIServingChat(engine,
+                                           model_config,
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE,
+                                           lora_modules=None,
+                                           prompt_adapters=None,
+                                           request_logger=None)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.chat_template == CHAT_TEMPLATE
+
+
+def test_serving_chat_should_set_correct_max_tokens():
+    mock_engine = MagicMock(spec=AsyncAphrodite)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     served_model_names=[MODEL_NAME],
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
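+    # With no max_tokens in the request, the default should fill the rest of
+    # the context: max_model_len (100) minus the tokenized prompt, i.e. 93.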
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    req.max_tokens = 10
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10

+ 47 - 0
tests/endpoints/openai/test_shutdown.py

@@ -0,0 +1,47 @@
+import json
+import os
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.mark.asyncio
+async def test_shutdown_on_engine_failure(tmp_path):
+    # Use a bad adapter to crash the engine
+    # (this test relies on that crash; update it once the bug is fixed)
+    adapter_path = tmp_path / "bad_adapter"
+    os.mkdir(adapter_path)
+    with open(adapter_path / "adapter_model_config.json", "w") as f:
+        json.dump({"not": "real"}, f)
+    with open(adapter_path / "adapter_model.safetensors", "wb") as f:
+        f.write(b"this is fake")
+
+    # dtype, max-len, etc. are set so that this can run in CI
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+        "--enable-lora",
+        "--lora-modules",
+        f"bad-adapter={tmp_path / 'bad_adapter'}",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.APIConnectionError):
+            # This crashes the engine
+            await client.completions.create(model="bad-adapter",
+                                            prompt="Hello, my name is")
+
+        # Now the server should shut down
+        return_code = remote_server.proc.wait(timeout=1)
+        assert return_code is not None

+ 152 - 0
tests/endpoints/openai/test_tokenization.py

@@ -0,0 +1,152 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def tokenizer_name(model_name: str,
+                   zephyr_lora_added_tokens_files: str):  # noqa: F811
+    return zephyr_lora_added_tokens_files if (
+        model_name == "zephyr-lora2") else model_name
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str, tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "aphrodite1 This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
+                             tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question? aphrodite1"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
+                          tokenizer_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
+
+    prompt = "This is a test prompt. aphrodite1"
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    print(f"CALLING {base_url} FOR {model_name}")
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}
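
For manual debugging, the same endpoints can be hit with curl against a running
server (host and port are placeholders; 2242 is the port used by the older
in-repo server tests removed further down):

    curl -s http://localhost:2242/tokenize \
         -H 'Content-Type: application/json' \
         -d '{"model": "HuggingFaceH4/zephyr-7b-beta", "prompt": "Hello!", "add_special_tokens": true}'

The response shape matches what the tests assert: {"tokens": [...], "count": <n>, "max_model_len": <value>}.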

+ 261 - 0
tests/endpoints/openai/test_vision.py

@@ -0,0 +1,261 @@
+from typing import Dict, List
+
+import openai
+import pytest
+
+from aphrodite.multimodal.utils import encode_image_base64, fetch_image
+
+from ...utils import APHRODITE_PATH, RemoteOpenAIServer
+
+MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
+LLAVA_CHAT_TEMPLATE = APHRODITE_PATH / "examples/template_llava.jinja"
+assert LLAVA_CHAT_TEMPLATE.exists()
+
+# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
+TEST_IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
+    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+]
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
+        "--chat-template",
+        str(LLAVA_CHAT_TEMPLATE),
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.fixture(scope="session")
+def base64_encoded_image() -> Dict[str, str]:
+    return {
+        image_url: encode_image_base64(fetch_image(image_url))
+        for image_url in TEST_IMAGE_URLS
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image(client: openai.AsyncOpenAI,
+                                         model_name: str, image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
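+    # 596 prompt tokens: LLaVA-1.5 expands the single image into 576 image
+    # placeholder tokens, with the remainder coming from the chat template.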
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=596, total_tokens=606)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_base64encoded(
+        client: openai.AsyncOpenAI, model_name: str, image_url: str,
+        base64_encoded_image: Dict[str, str]):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=596, total_tokens=606)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_chat_streaming_image(client: openai.AsyncOpenAI,
+                                    model_name: str, image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+    )
+    output = chat_completion.choices[0].message.content
+    stop_reason = chat_completion.choices[0].finish_reason
+
+    # test streaming
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+        temperature=0.0,
+        stream=True,
+    )
+    chunks: List[str] = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.role:
+            assert delta.role == "assistant"
+        if delta.content:
+            chunks.append(delta.content)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    # finish reason should only return in last block
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == stop_reason
+    assert delta.content
+    assert "".join(chunks) == output
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
+                                 image_url: str):
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    with pytest.raises(openai.BadRequestError):  # test multi-image input
+        await client.chat.completions.create(
+            model=model_name,
+            messages=messages,
+            max_tokens=10,
+            temperature=0.0,
+        )
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0

+ 0 - 30
tests/endpoints/test_llm_generate.py

@@ -1,30 +0,0 @@
-import pytest
-
-from aphrodite import LLM, SamplingParams
-
-
-def test_multiple_sampling_params():
-    llm = LLM(model='gpt2', max_num_batched_tokens=1024)
-    prompts = [
-        "Once upon a time",
-        "In a galaxy far far away",
-        "The quick brown fox jumps over the lazy dog",
-    ]
-    sampling_params = [
-        SamplingParams(temperature=0.7, min_p=0.06),
-        SamplingParams(temperature=0.8, min_p=0.07),
-        SamplingParams(temperature=0.9, min_p=0.08),
-    ]
-
-    outputs = llm.generate(prompts, sampling_params=sampling_params)
-    assert len(prompts) == len(outputs)
-
-    with pytest.raises(ValueError):
-        outputs = llm.generate(prompts, sampling_params=sampling_params[:2])
-
-        single_sampling_params = SamplingParams(temperature=0.7, min_p=0.06)
-        outputs = llm.generate(prompts, sampling_params=single_sampling_params)
-        assert len(prompts) == len(outputs)
-
-        outputs = llm.generate(prompts, sampling_params=None)
-        assert len(prompts) == len(outputs)

+ 0 - 614
tests/endpoints/test_openai_server.py

@@ -1,614 +0,0 @@
-# imports for guided decoding tests
-import json
-import os
-import re
-import subprocess
-import sys
-import time
-
-import jsonschema
-import openai  # use the official client for correctness check
-import pytest
-import ray
-import requests
-from huggingface_hub import snapshot_download
-
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
-
-MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
-MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-LORA_NAME = "typeof/zephyr-7b-beta-lora"
-
-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
-             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
-
-TEST_CHOICE = [
-    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
-    "Swift", "Kotlin"
-]
-
-pytestmark = pytest.mark.asyncio
-
-
-@ray.remote(num_gpus=1)
-class ServerRunner:
-
-    def __init__(self, args):
-        env = os.environ.copy()
-        env["PYTHONUNBUFFERED"] = "1"
-        self.proc = subprocess.Popen(
-            ["python3", "-m", "aphrodite.endpoints.openai.api_server"] + args,
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_server()
-
-    def ready(self):
-        return True
-
-    def _wait_for_server(self):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get(
-                        "http://localhost:2242/health").status_code == 200:
-                    break
-            except Exception as err:
-                if self.proc.poll() is not None:
-                    raise RuntimeError("Server exited unexpectedly.") from err
-
-                time.sleep(0.5)
-                if time.time() - start > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError(
-                        "Server failed to start in time.") from err
-
-    def __del__(self):
-        if hasattr(self, "proc"):
-            self.proc.terminate()
-
-
-@pytest.fixture(scope="session")
-def zephyr_lora_files():
-    return snapshot_download(repo_id=LORA_NAME)
-
-
-@pytest.fixture(scope="session")
-def server(zephyr_lora_files):
-    ray.init()
-    server_runner = ServerRunner.remote([
-        "--model",
-        MODEL_NAME,
-        "--dtype",
-        "bfloat16",  # use half precision for speed and memory savings in CI env
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-        # lora config below
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
-        "--max-lora-rank",
-        "64",
-        "--max-cpu-loras",
-        "2",
-        "--max-num-seqs",
-        "128"
-    ])
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-
-
-@pytest.fixture(scope="session")
-def client():
-    client = openai.AsyncOpenAI(
-        base_url="http://localhost:2242/v1",
-        api_key="",
-    )
-    yield client
-
-
-async def test_check_models(server, client: openai.AsyncOpenAI):
-    models = await client.models.list()
-    models = models.data
-    served_model = models[0]
-    lora_models = models[1:]
-    assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
-    assert lora_models[0].id == "zephyr-lora"
-    assert lora_models[1].id == "zephyr-lora2"
-
-
-@pytest.mark.parametrize(
-    # first test base model, then test loras
-    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
-)
-async def test_single_completion(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    completion = await client.completions.create(model=model_name,
-                                                 prompt="Hello, my name is",
-                                                 max_tokens=5,
-                                                 temperature=0.0)
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
-
-    # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_single_chat_session(server, client: openai.AsyncOpenAI,
-                                   model_name: str):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
-    # test single completion
-    chat_completion = await client.chat.completions.create(model=model_name,
-                                                           messages=messages,
-                                                           max_tokens=10,
-                                                           logprobs=True,
-                                                           top_logprobs=10)
-    assert chat_completion.id is not None
-    assert chat_completion.choices is not None and len(
-        chat_completion.choices) == 1
-    assert chat_completion.choices[0].message is not None
-    assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 10
-    assert message.role == "assistant"
-    messages.append({"role": "assistant", "content": message.content})
-
-    # test multi-turn dialogue
-    messages.append({"role": "user", "content": "express your result in json"})
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=10,
-    )
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 0
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_completion_streaming(server, client: openai.AsyncOpenAI,
-                                    model_name: str):
-    prompt = "What is an LLM?"
-
-    single_completion = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-    )
-    single_output = single_completion.choices[0].text
-    single_usage = single_completion.usage
-
-    stream = await client.completions.create(model=model_name,
-                                             prompt=prompt,
-                                             max_tokens=5,
-                                             temperature=0.0,
-                                             stream=True)
-    chunks = []
-    async for chunk in stream:
-        chunks.append(chunk.choices[0].text)
-    assert chunk.choices[0].finish_reason == "length"
-    assert chunk.usage == single_usage
-    assert "".join(chunks) == single_output
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_chat_streaming(server, client: openai.AsyncOpenAI,
-                              model_name: str):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
-    # test single completion
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-        temperature=0.0,
-    )
-    output = chat_completion.choices[0].message.content
-    stop_reason = chat_completion.choices[0].finish_reason
-
-    # test streaming
-    stream = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-        temperature=0.0,
-        stream=True,
-    )
-    chunks = []
-    async for chunk in stream:
-        delta = chunk.choices[0].delta
-        if delta.role:
-            assert delta.role == "assistant"
-        if delta.content:
-            chunks.append(delta.content)
-    assert chunk.choices[0].finish_reason == stop_reason
-    assert "".join(chunks) == output
-
-
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_batch_completions(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
-
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in Aphrodite, but not
-            # necessary for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
-
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
-
-
-async def test_logits_bias(server, client: openai.AsyncOpenAI):
-    prompt = "Hello, my name is"
-    max_tokens = 5
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
-
-    # Test exclusive selection
-    token_id = 1000
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-        logit_bias={str(token_id): 100},
-        seed=42,
-    )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    response_tokens = tokenizer(completion.choices[0].text,
-                                add_special_tokens=False)["input_ids"]
-    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
-                                add_special_tokens=False)["input_ids"]
-    assert all([
-        response == expected
-        for response, expected in zip(response_tokens, expected_tokens)
-    ])
-
-    # Test ban
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-    )
-    response_tokens = tokenizer(completion.choices[0].text,
-                                add_special_tokens=False)["input_ids"]
-    first_response = completion.choices[0].text
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-        logit_bias={str(token): -100
-                    for token in response_tokens},
-    )
-    assert first_response != completion.choices[0].text
-
-
-async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=
-        "Give an example JSON for an employee profile that fits this schema:"
-        f" {TEST_SCHEMA}",
-        n=3,
-        temperature=1.0,
-        max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
-    for i in range(3):
-        assert completion.choices[i].text is not None
-        output_json = json.loads(completion.choices[i].text)
-        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
-
-
-async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "Give an example JSON for an employee profile that " + \
-                    f"fits this schema: {TEST_SCHEMA}"
-    }]
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
-    message = chat_completion.choices[0].message
-    assert message.content is not None
-    json1 = json.loads(message.content)
-    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
-
-    messages.append({"role": "assistant", "content": message.content})
-    messages.append({
-        "role":
-        "user",
-        "content":
-        "Give me another one with a different name and age"
-    })
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
-    message = chat_completion.choices[0].message
-    assert message.content is not None
-    json2 = json.loads(message.content)
-    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
-    assert json1["name"] != json2["name"]
-    assert json1["age"] != json2["age"]
-
-
-async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
-        n=3,
-        temperature=1.0,
-        max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
-    for i in range(3):
-        assert completion.choices[i].text is not None
-        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
-
-
-async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role":
-        "user",
-        "content":
-        f"Give an example IP address with this regex: {TEST_REGEX}"
-    }]
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
-    ip1 = chat_completion.choices[0].message.content
-    assert ip1 is not None
-    assert re.fullmatch(TEST_REGEX, ip1) is not None
-
-    messages.append({"role": "assistant", "content": ip1})
-    messages.append({"role": "user", "content": "Give me a different one"})
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
-    ip2 = chat_completion.choices[0].message.content
-    assert ip2 is not None
-    assert re.fullmatch(TEST_REGEX, ip2) is not None
-    assert ip1 != ip2
-
-
-async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt="The best language for type-safe systems programming is ",
-        n=2,
-        temperature=1.0,
-        max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
-
-    assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 2
-    for i in range(2):
-        assert completion.choices[i].text in TEST_CHOICE
-
-
-async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role":
-        "user",
-        "content":
-        "The best language for type-safe systems programming is "
-    }]
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
-    choice1 = chat_completion.choices[0].message.content
-    assert choice1 in TEST_CHOICE
-
-    messages.append({"role": "assistant", "content": choice1})
-    messages.append({
-        "role": "user",
-        "content": "I disagree, pick another one"
-    })
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
-    choice2 = chat_completion.choices[0].message.content
-    assert choice2 in TEST_CHOICE
-    assert choice1 != choice2
-
-
-async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
-    with pytest.raises(openai.BadRequestError):
-        _ = await client.completions.create(
-            model=MODEL_NAME,
-            prompt="Give an example JSON that fits this schema: 42",
-            extra_body=dict(guided_json=42))
-
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role":
-        "user",
-        "content":
-        "The best language for type-safe systems programming is "
-    }]
-    with pytest.raises(openai.BadRequestError):
-        _ = await client.chat.completions.create(model=MODEL_NAME,
-                                                 messages=messages,
-                                                 extra_body=dict(guided_regex={
-                                                     1: "Python",
-                                                     2: "C++"
-                                                 }))
-
-    with pytest.raises(openai.BadRequestError):
-        _ = await client.completions.create(
-            model=MODEL_NAME,
-            prompt="Give an example string that fits this regex",
-            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
-
-
-async def test_embeddings(server, client: openai.AsyncOpenAI):
-    # model is ignored by the endpoint, but needed by the openai client.
-    model = "all-mpnet-base-v2"
-    text = "I'm the text to extract meaning from"
-    embedding_promise = await client.embeddings.create(input=[text],
-                                                       model=model)
-    response_data = embedding_promise.data[0]
-    embedding = response_data.embedding
-
-    assert isinstance(response_data, openai.types.Embedding)
-    assert isinstance(embedding, list)
-    assert (len(embedding) > 1)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])

+ 0 - 75
tests/endpoints/test_outlines.py

@@ -1,75 +0,0 @@
-# This unit test should be moved to a new
-# tests/test_guided_decoding directory.
-
-import torch
-from transformers import AutoTokenizer
-
-from aphrodite.modeling.outlines_logits_processors import (
-    JSONLogitsProcessor, RegexLogitsProcessor)
-
-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
-             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
-
-
-def test_guided_logits_processors():
-    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
-    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
-    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
-    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
-
-    regex_LP.init_state()
-    token_ids = tokenizer.encode(
-        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
-    tensor = torch.rand(32000)
-    original_tensor = torch.clone(tensor)
-    regex_LP(token_ids, tensor)
-    assert tensor.shape == original_tensor.shape
-    assert not torch.allclose(tensor, original_tensor)
-
-    json_LP.init_state()
-    token_ids = tokenizer.encode(
-        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
-    tensor = torch.rand(32000)
-    original_tensor = torch.clone(tensor)
-    json_LP(token_ids, tensor)
-    assert tensor.shape == original_tensor.shape
-    assert not torch.allclose(tensor, original_tensor)

+ 0 - 0
tests/engine/__init__.py


+ 0 - 0
tests/engine/output_processor/__init__.py


+ 272 - 0
tests/engine/output_processor/test_multi_step.py

@@ -0,0 +1,272 @@
+import random
+from unittest.mock import MagicMock
+
+import pytest
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
+                                       SequenceOutput, SequenceStatus)
+from aphrodite.common.utils import Counter
+from aphrodite.engine.output_processor.multi_step import (
+    MultiStepOutputProcessor)
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.processing.scheduler import Scheduler
+from aphrodite.transformers_utils.detokenizer import Detokenizer
+
+from ...core.utils import create_seq_group
+
+
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [1, 12])
+@pytest.mark.skip_global_cleanup
+def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
+    """Verify multi-step decoding appends token ids correctly.
+
+    We append token ids and verify all the token ids were appended correctly.
+    Note that ignore_eos=True.
+    """
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=[scheduler],
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=1024,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(max_tokens=seq_output_len +
+                                       num_new_tokens,
+                                       ignore_eos=True),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+
+    outputs = [
+        CompletionSequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
+    output_processor.process_outputs(seq_group, outputs)
+    assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
+@pytest.mark.parametrize("max_tokens", [128 + 3])
+@pytest.mark.skip_global_cleanup
+def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
+                             seq_output_len: int, max_tokens: int):
+    """Verify tokens after max_tokens are dropped and not appended to the
+    sequence.
+    """
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=[scheduler],
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(max_tokens=max_tokens, ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+
+    outputs = [
+        CompletionSequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence not to exceed max_tokens in length.
+    assert seq.get_len() == seq_prompt_len + max_tokens
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+                               seq_output_len: int, seed: int):
+    """Verify the eos token id is included in the sequence, but subsequent
+    tokens are dropped (not appended to sequence).
+    """
+    random.seed(seed)
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    eos_token_id = 100
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=[scheduler],
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(
+            # Ensure enough space.
+            max_tokens=seq_output_len + num_new_tokens, ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+    assert eos_token_id not in new_token_ids
+    eos_index = random.randint(0, len(new_token_ids) - 1)
+    new_token_ids[eos_index] = eos_token_id
+
+    outputs = [
+        CompletionSequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence to not go beyond provided eos.
+    assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:eos_index + 1]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+                              seq_output_len: int, seed: int):
+    """When sampling parameters dictate that we should ignore the eos token id,
+    ensure all token ids are appended even if the eos token id is emitted.
+    """
+    random.seed(seed)
+    detokenizer = MagicMock(spec=Detokenizer)
+    scheduler = MagicMock(spec=Scheduler)
+    stop_checker = MagicMock(spec=StopChecker)
+    seq_counter = Counter()
+
+    eos_token_id = 100
+
+    output_processor = MultiStepOutputProcessor(
+        detokenizer=detokenizer,
+        scheduler=[scheduler],
+        seq_counter=seq_counter,
+        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+        stop_checker=stop_checker,
+    )
+
+    seq_group = create_seq_group(
+        seq_prompt_len=seq_prompt_len,
+        seq_output_lens=[seq_output_len],
+        sampling_params=SamplingParams(
+            # Ensure enough space.
+            max_tokens=seq_output_len + num_new_tokens,
+            ignore_eos=True,
+        ),
+    )
+
+    seq = seq_group.get_seqs()[0]
+    seq.status = SequenceStatus.RUNNING
+
+    new_token_ids = list(range(num_new_tokens))
+    assert eos_token_id not in new_token_ids
+    eos_index = random.randint(0, len(new_token_ids) - 1)
+    new_token_ids[eos_index] = eos_token_id
+
+    outputs = [
+        CompletionSequenceGroupOutput(
+            samples=[
+                SequenceOutput(
+                    parent_seq_id=seq.seq_id,
+                    output_token=output_token,
+                    logprobs={output_token: Logprob(0.0)},
+                )
+            ],
+            prompt_logprobs=None,
+        ) for output_token in new_token_ids
+    ]
+
+    assert seq.get_len() == seq_prompt_len + seq_output_len
+    output_processor.process_outputs(seq_group, outputs)
+
+    # Expect the processed sequence to go beyond eos.
+    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
+
+    # Expect the correct tokens were appended.
+    expected_appended_tokens = new_token_ids[:num_new_tokens]
+    assert seq.get_token_ids(
+    )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+def mock_tokenizer(eos_token_id=1000):
+    tokenizer = MagicMock(spec=PreTrainedTokenizer)
+    tokenizer.eos_token_id = eos_token_id
+    return tokenizer
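Note: the four tests above each rebuild the same list of single-sample step outputs inline. A helper along these lines could factor that pattern out; it is only a sketch using the constructors already imported in this file, not part of the diff.

```python
from typing import List

from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SequenceOutput)


def make_step_outputs(
        seq_id: int,
        token_ids: List[int]) -> List[CompletionSequenceGroupOutput]:
    """Wrap each sampled token id in a single-sample per-step output."""
    return [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq_id,
                    output_token=token,
                    logprobs={token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for token in token_ids
    ]
```

Each test could then build its `outputs` list as `make_step_outputs(seq.seq_id, new_token_ids)`.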

+ 85 - 0
tests/engine/output_processor/test_stop_checker.py

@@ -0,0 +1,85 @@
+from unittest.mock import MagicMock
+
+import pytest
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import Logprob, Sequence, SequenceStatus
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+
+
+def sequence_with_eos(text: str, eos_token: str,
+                      eos_token_id: int) -> Sequence:
+    """
+    Create a Sequence that ends with an EOS token.
+    """
+    seq = Sequence(
+        seq_id=0,
+        inputs={"prompt_token_ids": []},
+        block_size=16,
+        eos_token_id=eos_token_id,
+    )
+    seq.output_text = text + eos_token
+
+    offset = eos_token_id + 1
+    for i in range(offset, len(text) + offset):
+        seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
+    seq.append_token_id(token_id=eos_token_id,
+                        logprobs={eos_token_id: Logprob(0.0)})
+
+    seq.status = SequenceStatus.RUNNING
+
+    return seq
+
+
+@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
+    ("This text ends with EOS token", "</s>", 2),
+])
+@pytest.mark.parametrize("ignore_eos", [True, False])
+@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
+@pytest.mark.skip_global_cleanup
+def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
+                           ignore_eos: bool, include_stop_str_in_output: bool):
+    """
+    Test the behavior of the StopChecker's maybe_stop_sequence method
+    when an EOS token is encountered.
+
+    This test covers:
+    - When the EOS token should stop the sequence and be removed from the output
+    - When the EOS token should stop the sequence and be included in the output
+    - When the EOS token should be ignored, and the sequence continues
+    """
+
+    tokenizer = MagicMock(spec=PreTrainedTokenizer)
+    get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
+    stop_checker = StopChecker(max_model_len=1024,
+                               get_tokenizer_for_seq=get_tokenizer_for_seq)
+
+    seq = sequence_with_eos(
+        text=text_wo_eos,
+        eos_token=eos_token,
+        eos_token_id=eos_token_id,
+    )
+    new_char_count = len(eos_token)
+
+    # Note that `stop` and `stop_token_ids` are not specified
+    sampling_params = SamplingParams(
+        min_tokens=1,
+        ignore_eos=ignore_eos,
+        include_stop_str_in_output=include_stop_str_in_output)
+
+    stop_checker.maybe_stop_sequence(
+        seq=seq,
+        new_char_count=new_char_count,
+        sampling_params=sampling_params,
+    )
+
+    if ignore_eos:
+        assert seq.status == SequenceStatus.RUNNING
+        assert seq.output_text == text_wo_eos + eos_token
+    elif include_stop_str_in_output:
+        assert seq.status == SequenceStatus.FINISHED_STOPPED
+        assert seq.output_text == text_wo_eos + eos_token
+    else:
+        assert seq.status == SequenceStatus.FINISHED_STOPPED
+        assert seq.output_text == text_wo_eos

+ 24 - 0
tests/engine/test_args.py

@@ -0,0 +1,24 @@
+import pytest
+
+from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.engine.args_tools import EngineArgs
+
+
+@pytest.mark.parametrize(("arg", "expected"), [
+    (None, None),
+    ("image=16", {
+        "image": 16
+    }),
+    ("image=16,video=2", {
+        "image": 16,
+        "video": 2
+    }),
+])
+def test_limit_mm_per_prompt_parser(arg, expected):
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    if arg is None:
+        args = parser.parse_args([])
+    else:
+        args = parser.parse_args(["--limit-mm-per-prompt", arg])
+
+    assert args.limit_mm_per_prompt == expected
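For context, the parsed flag is meant to feed the engine configuration. A minimal sketch, assuming `EngineArgs.from_cli_args` exists as the usual companion to `add_cli_args` (it is not shown in this diff):

```python
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.engine.args_tools import EngineArgs

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--limit-mm-per-prompt", "image=16,video=2"])
assert args.limit_mm_per_prompt == {"image": 16, "video": 2}

# Assumed companion helper to add_cli_args; not shown in this diff.
engine_args = EngineArgs.from_cli_args(args)
```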

+ 34 - 0
tests/engine/test_computed_prefix_blocks.py

@@ -0,0 +1,34 @@
+import pytest
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.args_tools import EngineArgs
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+def test_computed_prefix_blocks(model: str, block_size: int):
+    # This test checks if we are able to run the engine to completion
+    # without triggering asserts.
+    # We are in a scenario where all blocks from the second request's prompt
+    # are full and already computed when the second request arrives.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+    prompt2 = (
+        " Please recommend to me some resources where I can learn not only to "
+        "handle technical difficulties of building a car, but also "
+        "decoration.")
+
+    engine_args = EngineArgs(model=model,
+                             block_size=block_size,
+                             enable_prefix_caching=True)
+
+    engine = AphroditeEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams()
+
+    engine.add_request("0", prompt + prompt2, sampling_params)
+    engine.step()
+    engine.add_request("1", prompt, sampling_params)
+    engine.step()

+ 90 - 0
tests/engine/test_custom_executor.py

@@ -0,0 +1,90 @@
+import asyncio
+import os
+
+import pytest
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
+from aphrodite.engine.async_aphrodite import AphroditeEngine, AsyncAphrodite
+from aphrodite.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
+
+
+class Mock:
+    ...
+
+
+class CustomGPUExecutor(GPUExecutor):
+
+    def execute_model(self, *args, **kwargs):
+        # Drop a marker file to show that this was run
+        with open(".marker", "w"):
+            ...
+        return super().execute_model(*args, **kwargs)
+
+
+class CustomGPUExecutorAsync(GPUExecutorAsync):
+
+    async def execute_model_async(self, *args, **kwargs):
+        with open(".marker", "w"):
+            ...
+        return await super().execute_model_async(*args, **kwargs)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_type_checking(model):
+    with pytest.raises(ValueError):
+        engine_args = EngineArgs(model=model,
+                                 distributed_executor_backend=Mock)
+        AphroditeEngine.from_engine_args(engine_args)
+    with pytest.raises(ValueError):
+        engine_args = AsyncEngineArgs(model=model,
+                                      distributed_executor_backend=Mock)
+        AsyncAphrodite.from_engine_args(engine_args)
+    with pytest.raises(TypeError):
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        AsyncAphrodite.from_engine_args(engine_args)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = EngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        engine = AphroditeEngine.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_async(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutorAsync)
+        engine = AsyncAphrodite.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        async def t():
+            stream = await engine.add_request("0", "foo", sampling_params)
+            async for x in stream:
+                ...
+
+        asyncio.run(t())
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)

+ 32 - 0
tests/engine/test_detokenization.py

@@ -0,0 +1,32 @@
+import pytest
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.endpoints.llm import LLM
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_detokenization(model: str):
+    # This test checks if the engine generates completions both with and
+    # without optional detokenization, that detokenization includes text
+    # and no-detokenization doesn't, and that both completions have the same
+    # token_ids.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+
+    llm = LLM(model=model)
+    sampling_params = SamplingParams(max_tokens=10,
+                                     temperature=0.0,
+                                     detokenize=False)
+
+    outputs_no_detokenization = llm.generate(prompt,
+                                             sampling_params)[0].outputs[0]
+    sampling_params.detokenize = True
+    outputs_with_detokenization = llm.generate(prompt,
+                                               sampling_params)[0].outputs[0]
+
+    assert outputs_no_detokenization.text == ''
+    assert outputs_with_detokenization.text != ''
+    assert outputs_no_detokenization.token_ids == \
+        outputs_with_detokenization.token_ids

+ 0 - 55
tests/engine/test_detokenize.py

@@ -1,55 +0,0 @@
-import pytest
-from transformers import AutoTokenizer
-
-from aphrodite.transformers_utils.tokenizer import detokenize_incrementally
-
-TRUTH = [
-    "Tell me your favorite story.",
-    "Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the same performance as Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN, which parallelizes computations during training and maintains constant computational and memory complexity during inference, leading to the first non-transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks."  # noqa: E501
-    "トランスフォーマーは、ほぼすべての自然言語処理に革命をもたらしました",
-]
-
-TOKENIZERS = [
-    "EleutherAI/gpt-j-6b",
-    "EleutherAI/pythia-70m-deduped",
-    "meta-llama/llama-2-7b-hf",
-    "/mistralai/Mistral-7B-v0.1",
-]
-
-
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
-
-
-@pytest.mark.parametrize("truth", TRUTH)
-@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", (True, False))
-def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
-    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
-    if skip_special_tokens:
-        all_input_ids = (
-            [tokenizer_id.bos_token_id] if tokenizer.bos_token_id is not None
-            else []) + all_input_ids + [tokenizer.eos_token_id]  # type: ignore
-    decoded_text = _run_incremental_decode(
-        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
-
-    assert decoded_text == truth

+ 177 - 0
tests/engine/test_multiproc_workers.py

@@ -0,0 +1,177 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from time import sleep
+from typing import Any, List, Tuple
+
+import pytest
+
+from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
+                                                       ResultHandler,
+                                                       WorkerMonitor)
+
+
+class DummyWorker:
+    """Dummy version of aphrodite.task_handler.worker.Worker"""
+
+    def __init__(self, rank: int):
+        self.rank = rank
+
+    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
+        sleep(0.05)
+
+        if isinstance(worker_input, Exception):
+            # simulate error case
+            raise worker_input
+
+        return self.rank, worker_input
+
+
+def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
+    result_handler = ResultHandler()
+    workers = [
+        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
+        for rank in range(8)
+    ]
+
+    worker_monitor = WorkerMonitor(workers, result_handler)
+    assert not worker_monitor.is_alive()
+
+    result_handler.start()
+    worker_monitor.start()
+    assert worker_monitor.is_alive()
+
+    return workers, worker_monitor
+
+
+def test_local_workers() -> None:
+    """Test workers with sync task submission"""
+
+    workers, worker_monitor = _start_workers()
+
+    def execute_workers(worker_input: str) -> None:
+        worker_outputs = [
+            worker.execute_method("worker_method", worker_input)
+            for worker in workers
+        ]
+
+        for rank, output in enumerate(worker_outputs):
+            assert output.get() == (rank, worker_input)
+
+    executor = ThreadPoolExecutor(max_workers=4)
+
+    # Test concurrent submission from different threads
+    futures = [
+        executor.submit(partial(execute_workers, f"thread {thread_num}"))
+        for thread_num in range(4)
+    ]
+
+    for future in futures:
+        future.result()
+
+    # Test error case
+    exception = ValueError("fake error")
+    result = workers[0].execute_method("worker_method", exception)
+    try:
+        result.get()
+        pytest.fail("task should have failed")
+    except Exception as e:
+        assert isinstance(e, ValueError)
+        assert str(e) == "fake error"
+
+    # Test cleanup when a worker fails
+    assert worker_monitor.is_alive()
+    workers[3].process.kill()
+
+    # Other workers should get shut down here
+    worker_monitor.join(2)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = workers[0].execute_method("worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
+
+
+def test_local_workers_clean_shutdown() -> None:
+    """Test clean shutdown"""
+
+    workers, worker_monitor = _start_workers()
+
+    assert worker_monitor.is_alive()
+    assert all(worker.process.is_alive() for worker in workers)
+
+    # Clean shutdown
+    worker_monitor.close()
+
+    worker_monitor.join(5)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = workers[0].execute_method("worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
+
+
+@pytest.mark.asyncio
+async def test_local_workers_async() -> None:
+    """Test local workers with async task submission"""
+
+    workers, worker_monitor = _start_workers()
+
+    async def execute_workers(worker_input: str) -> None:
+        worker_coros = [
+            worker.execute_method_async("worker_method", worker_input)
+            for worker in workers
+        ]
+
+        results = await asyncio.gather(*worker_coros)
+        for rank, result in enumerate(results):
+            assert result == (rank, worker_input)
+
+    tasks = [
+        asyncio.create_task(execute_workers(f"task {task_num}"))
+        for task_num in range(4)
+    ]
+
+    for task in tasks:
+        await task
+
+    # Test error case
+    exception = ValueError("fake error")
+    try:
+        _result = await workers[0].execute_method_async(
+            "worker_method", exception)
+        pytest.fail("task should have failed")
+    except Exception as e:
+        assert isinstance(e, ValueError)
+        assert str(e) == "fake error"
+
+    # Test cleanup when a worker fails
+    assert worker_monitor.is_alive()
+    workers[3].process.kill()
+
+    # Other workers should get shut down here
+    worker_monitor.join(2)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = await workers[0].execute_method_async(
+            "worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
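A condensed sketch of the worker lifecycle these tests exercise, reusing the `DummyWorker` class defined above and only the helpers imported in this file; the worker count and payload are arbitrary.

```python
from functools import partial

from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                       ResultHandler,
                                                       WorkerMonitor)


def run_demo() -> None:
    # DummyWorker is the class defined in the test file above.
    result_handler = ResultHandler()
    workers = [
        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
        for rank in range(2)
    ]
    worker_monitor = WorkerMonitor(workers, result_handler)
    result_handler.start()
    worker_monitor.start()

    # Fan a task out to every worker and block on the returned futures.
    futures = [w.execute_method("worker_method", "ping") for w in workers]
    assert [f.get() for f in futures] == [(0, "ping"), (1, "ping")]

    # Clean shutdown, mirroring test_local_workers_clean_shutdown above.
    worker_monitor.close()
    worker_monitor.join(5)
```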

+ 23 - 0
tests/engine/test_skip_tokenizer_init.py

@@ -0,0 +1,23 @@
+import pytest
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.endpoints.llm import LLM
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(model=model, skip_tokenizer_init=True)
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+    with pytest.raises(ValueError) as err:
+        llm.generate("abc", sampling_params)
+    assert "prompts must be None if" in str(err.value)
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids

+ 62 - 0
tests/engine/test_stop_reason.py

@@ -0,0 +1,62 @@
+"""Test the different finish_reason="stop" situations during generation:
+    1. One of the provided stop strings
+    2. One of the provided stop tokens
+    3. The EOS token
+
+Run `pytest tests/engine/test_stop_reason.py`.
+"""
+
+import pytest
+import transformers
+
+from aphrodite import SamplingParams
+
+MODEL = "facebook/opt-350m"
+STOP_STR = "."
+SEED = 42
+MAX_TOKENS = 1024
+
+
+@pytest.fixture
+def aphrodite_model(aphrodite_runner):
+    with aphrodite_runner(MODEL) as aphrodite_model:
+        yield aphrodite_model
+
+
+def test_stop_reason(aphrodite_model, example_prompts):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
+    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
+    llm = aphrodite_model.model
+
+    # test stop token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               ignore_eos=True,
+                               seed=SEED,
+                               max_tokens=MAX_TOKENS,
+                               stop_token_ids=[stop_token_id]))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == stop_token_id
+
+    # test stop string
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               ignore_eos=True,
+                               seed=SEED,
+                               max_tokens=MAX_TOKENS,
+                               stop="."))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == STOP_STR
+
+    # test EOS token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "length" or (
+            output.finish_reason == "stop" and output.stop_reason is None)

+ 112 - 0
tests/engine/test_stop_string.py

@@ -0,0 +1,112 @@
+from typing import Any, List, Optional
+
+import pytest
+
+from aphrodite import AphroditeEngine, CompletionOutput, SamplingParams
+
+MODEL = "meta-llama/llama-2-7b-hf"
+MAX_TOKENS = 200
+
+
+@pytest.fixture(scope="session")
+def aphrodite_model(aphrodite_runner):
+    with aphrodite_runner(MODEL) as aphrodite_model:
+        yield aphrodite_model
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_basic(aphrodite_model):
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer organization",
+                   expected_reason=".")
+
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop=["."],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organization.",
+                   expected_reason=".")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_multi_tokens(aphrodite_model):
+    _test_stopping(
+        aphrodite_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=False,
+        expected_output="VLLM is a 100% volunteer organization. We are a ",
+        expected_reason="group of peo")
+
+    _test_stopping(
+        aphrodite_model.model.llm_engine,
+        stop=["group of peo", "short"],
+        include_in_output=True,
+        expected_output=
+        "VLLM is a 100% volunteer organization. We are a group of peo",
+        expected_reason="group of peo")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_partial_token(aphrodite_model):
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop=["gani"],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer or",
+                   expected_reason="gani")
+
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop=["gani"],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organi",
+                   expected_reason="gani")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_token_id(aphrodite_model):
+    # token id 13013 => " organization"
+
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop_token_ids=[13013],
+                   include_in_output=False,
+                   expected_output="VLLM is a 100% volunteer",
+                   expected_reason=13013)
+
+    _test_stopping(aphrodite_model.model.llm_engine,
+                   stop_token_ids=[13013],
+                   include_in_output=True,
+                   expected_output="VLLM is a 100% volunteer organization",
+                   expected_reason=13013)
+
+
+def _test_stopping(llm_engine: AphroditeEngine,
+                   expected_output: str,
+                   expected_reason: Any,
+                   stop: Optional[List[str]] = None,
+                   stop_token_ids: Optional[List[int]] = None,
+                   include_in_output: bool = False) -> None:
+    llm_engine.add_request(
+        "id", "A story about vLLM:\n",
+        SamplingParams(
+            temperature=0.0,
+            max_tokens=MAX_TOKENS,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            include_stop_str_in_output=include_in_output,
+        ), None)
+
+    output: Optional[CompletionOutput] = None
+    output_text = ""
+    stop_reason = None
+    while llm_engine.has_unfinished_requests():
+        (request_output, ) = llm_engine.step()
+        (output, ) = request_output.outputs
+
+        # Ensure we don't backtrack
+        assert output.text.startswith(output_text)
+        output_text = output.text
+        stop_reason = output.stop_reason
+
+    assert output is not None
+    assert output_text == expected_output
+    assert stop_reason == expected_reason

+ 0 - 0
tests/kernels/__init__.py


+ 18 - 0
tests/kernels/allclose_default.py

@@ -0,0 +1,18 @@
+import torch
+
+# Reference default values of atol and rtol are from
+# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
+default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
+default_rtol = {
+    torch.float16: 1e-3,
+    torch.bfloat16: 1.6e-2,
+    torch.float: 1.3e-6
+}
+
+
+def get_default_atol(output) -> float:
+    return default_atol[output.dtype]
+
+
+def get_default_rtol(output) -> float:
+    return default_rtol[output.dtype]
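A minimal sketch of how a kernel test is expected to consume these defaults; the compared tensors are placeholders and the import path is an assumption.

```python
import torch

# Import path assumed; adjust to how the test suite exposes this module.
from tests.kernels.allclose_default import get_default_atol, get_default_rtol

out = torch.randn(128, dtype=torch.float16)
ref = out.clone()

# Pick tolerances from the output dtype instead of hard-coding them.
assert torch.allclose(out, ref,
                      atol=get_default_atol(out),
                      rtol=get_default_rtol(out))
```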

+ 7 - 37
tests/kernels/conftest.py

@@ -1,44 +1,14 @@
-from typing import List, Tuple
-
 import pytest
-import torch
-
 
-def create_kv_caches(
-    num_blocks: int,
-    block_size: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+from aphrodite.common.utils import (create_kv_caches_with_random,
+                                    create_kv_caches_with_random_flash)
 
-    scale = head_size**-0.5
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.empty(size=key_cache_shape,
-                                dtype=dtype,
-                                device=device)
-        key_cache.uniform_(-scale, scale)
-        key_caches.append(key_cache)
 
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.empty(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device=device)
-        value_cache.uniform_(-scale, scale)
-        value_caches.append(value_cache)
-    return key_caches, value_caches
+@pytest.fixture()
+def kv_cache_factory():
+    return create_kv_caches_with_random
 
 
 @pytest.fixture()
-def kv_cache_factory():
-    return create_kv_caches
+def kv_cache_factory_flashinfer():
+    return create_kv_caches_with_random_flash

+ 83 - 0
tests/kernels/quant_utils.py

@@ -0,0 +1,83 @@
+from typing import Optional, Tuple, Union
+
+import torch
+
+from aphrodite.common.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
+
+def as_float32_tensor(x: Union[float, torch.Tensor]) -> torch.Tensor:
+    return torch.as_tensor(x, dtype=torch.float32, device='cuda')
+
+def ref_dynamic_per_token_quant(x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.Tensor] = None) \
+        -> Tuple[torch.Tensor, torch.Tensor]:
+
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
+    if scale_ub is not None:
+        assert quant_dtype == FP8_DTYPE
+
+    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
+            else torch.finfo(quant_dtype)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    # Compute scales
+    x_token_max, _ = x.abs().max(dim=-1)
+    x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
+    scales = (x_token_max / qtype_max)[:, None]
+
+    # Quant
+    if quant_dtype == torch.int8:
+        iscales = as_float32_tensor(s_1 / scales)
+        torch_out = as_float32_tensor(x) * iscales
+        torch_out = torch_out.round()
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
+    else:
+        assert quant_dtype == FP8_DTYPE
+        min_scaling_factor = s_1 / (qtype_max * s_512)
+        scales = scales.clamp(min=min_scaling_factor)
+        torch_out = as_float32_tensor(x) / scales
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
+
+    return torch_out, scales
+
+
+# The int8 version is very similar. Incorporate the int8 version, like in
+# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
+# kernel
+def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
+                    -> Tuple[torch.Tensor, torch.Tensor]:
+
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
+    one = as_float32_tensor(1.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    x_max = as_float32_tensor(x.abs().max())
+    ref_scale = x_max / fp8_max
+    ref_iscale = one / ref_scale
+    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
+    return ref_out, ref_scale.view((1, ))
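A hedged usage sketch for the two reference helpers above, checking only output dtype and scale shapes; the input shape and import path are illustrative assumptions, and a CUDA device is required.

```python
import torch

# Import path assumed; adjust to how the test suite exposes this module.
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")

# Per-token dynamic quantization: one scale per token (row).
q_per_token, token_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE)
assert q_per_token.dtype == FP8_DTYPE
assert token_scales.shape == (16, 1)

# Per-tensor dynamic quantization: a single scale for the whole tensor.
q_per_tensor, tensor_scale = ref_dynamic_per_tensor_fp8_quant(x)
assert tensor_scale.shape == (1, )
```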

Too many files changed in this diff, so some files are not shown.