
revert previous commit

AlpinDale 1 year ago
parent commit fefbf029c9

+ 0 - 20
.github/workflows/thumbnail.yml

@@ -1,20 +0,0 @@
-name: Embed Thumbnail
-
-on:
-  push:
-    branches:
-      - main
-
-jobs:
-  update-readme:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v2
-
-      - name: Embed Thumbnail
-        uses: JamesIves/github-pages-deploy-action@4.1.1
-        with:
-          branch: main
-          folder: "./assets"
-          file: "aphrodite.png"

+ 201 - 0
LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 12 - 0
README.md

@@ -0,0 +1,12 @@
+<h1 align="center">
+Aphrodite - The Pygmalion Backend
+</h1>
+<h3 align="center">
+Work in Progress
+</h3>
+
+![aphrodite](./assets/aphrodite.png)
+
+Aphrodite is the backend service for PygmalionAI, built on top of [FastChat](https://github.com/lm-sys/FastChat), [vLLM](https://github.com/vllm-project/vllm), [SkyPilot](https://github.com/skypilot-org/skypilot), and more.
+
+Currently a work in progress, not remotely functional.

+ 10 - 0
aphrodite/__init__.py

@@ -0,0 +1,10 @@
+from aphrodite.common.outputs import ChatCompletionOutput, RequestOutput
+from aphrodite.common.sampling_params import SamplingParams
+
+__version__ = "0.0"
+
+__all__ = [
+    "SamplingParams",
+    "RequestOutput",
+    "ChatCompletionOutput",
+]
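
For illustration, a minimal sketch of the intended top-level import surface, assuming the dependencies from requirements.txt (torch, psutil, transformers) are installed and the package is on the Python path:

```python
# Illustrative only; `aphrodite` is the package introduced by this commit.
import aphrodite
from aphrodite import SamplingParams, RequestOutput, ChatCompletionOutput

print(aphrodite.__version__)              # "0.0"
params = SamplingParams(max_tokens=32)    # see aphrodite/common/sampling_params.py
```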

+ 67 - 0
aphrodite/common/block.py

@@ -0,0 +1,67 @@
+"""Token blocks."""
+from typing import List
+
+from aphrodite.common.utils import Device
+
+_BLANK_TOKEN_ID = -1
+
+
+class LogicalTokenBlock:
+    """A block that stores a contiguous chunk of tokens from left to right.
+    
+    Logical blocks are used to represent the states of the corresponding physical
+    blocks in the KV cache.
+    """
+    def __init__(
+        self,
+        block_number: int,
+        block_size: int,
+    ) -> None:
+        self.block_number = block_number
+        self.block_size = block_size
+        
+        self.token_ids = [_BLANK_TOKEN_ID] * block_size
+        self.num_tokens = 0
+
+    def is_empty(self) -> bool:
+        return self.num_tokens == 0
+
+    def get_num_empty_slots(self) -> int:
+        return self.block_size - self.num_tokens
+
+    def is_full(self) -> bool:
+        return self.num_tokens == self.block_size
+
+    def append_tokens(self, token_ids: List[int]) -> None:
+        assert len(token_ids) <= self.get_num_empty_slots()
+        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
+        self.num_tokens += len(token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.token_ids[:self.num_tokens]
+
+    def get_last_token_id(self) -> int:
+        assert self.num_tokens > 0
+        return self.token_ids[self.num_tokens - 1]
+
+
+class PhysicalTokenBlock:
+    """Represents the state of a block in the KV cache.
+    Needs to be double checked.
+    """
+    def __init__(
+        self,
+        device: Device,
+        block_number: int,
+        block_size: int,
+    ) -> None:
+        self.device = device
+        self.block_number = block_number
+        self.block_size = block_size
+
+        self.ref_count = 0
+
+    def __repr__(self) -> str:
+        return (f'PhysicalTokenBlock(device={self.device}, '
+                f'block_number={self.block_number}, '
+                f'ref_count={self.ref_count})')
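
For illustration, a minimal sketch of how a `LogicalTokenBlock` is filled; the token IDs are made up:

```python
from aphrodite.common.block import LogicalTokenBlock

# A logical block that can hold up to four token IDs.
block = LogicalTokenBlock(block_number=0, block_size=4)
block.append_tokens([101, 102])
assert block.get_num_empty_slots() == 2

block.append_tokens([103, 104])
assert block.is_full()
assert block.get_token_ids() == [101, 102, 103, 104]
assert block.get_last_token_id() == 104
```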

+ 209 - 0
aphrodite/common/config.py

@@ -0,0 +1,209 @@
+"""Configuration"""
+from typing import Optional
+
+import torch
+from transformers import AutoConfig, PretrainedConfig
+
+from aphrodite.common.logger import init_logger
+from aphrodite.common.utils import get_cpu_memory
+
+logger = init_logger(__name__)
+
+_GiB = 1 << 30
+
+class ModelConfig:
+    """Configuration for the model.
+
+    Args:
+        model: Name or path of the HF model to use.
+        download_dir: Directory to download and load the weights, defaults to
+            default HF cache directory.
+        use_np_weights: Save a numpy copy of model weights for faster loading.
+            This can increase the disk usage by up to 2x, and the model will be
+            loaded into CPU memory first.
+        use_dummy_weights: Use dummy values for model weights (for profiling).
+        dtype: Datatype for model weights and activations. The "auto" option
+            keeps FP16/BF16 checkpoints as-is and casts FP32 checkpoints to
+            BF16 when supported, otherwise FP16.
+        seed: Random seed for consistent reproducibility.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        download_dir: Optional[str],
+        use_np_weights: bool,
+        use_dummy_weights: bool,
+        dtype: str,
+        seed: int,
+    ) -> None:
+        self.model = model
+        self.download_dir = download_dir
+        self.use_np_weights = use_np_weights
+        self.use_dummy_weights = use_dummy_weights
+        self.seed = seed
+
+        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
+        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+
+    def _verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_num_attention_heads = self.hf_config.num_attention_heads
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+        if total_num_attention_heads % tensor_parallel_size != 0:
+            raise ValueError(
+                f"Total number of attention heads ({total_num_attention_heads})"
+                " must be divisible by tensor parallel size "
+                f"({tensor_parallel_size}).")
+
+        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        pipeline_parallel_size = parallel_config.pipeline_parallel_size
+        if total_num_hidden_layers % pipeline_parallel_size != 0:
+            raise ValueError(
+                f"Total number of hidden layers ({total_num_hidden_layers}) "
+                "must be divisible by pipeline parallel size "
+                f"({pipeline_parallel_size}).")
+
+    def get_hidden_size(self) -> int:
+        return self.hf_config.hidden_size
+
+    def get_head_size(self) -> int:
+        return self.hf_config.hidden_size // self.hf_config.num_attention_heads
+
+    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        total_num_attention_heads = self.hf_config.num_attention_heads
+        return total_num_attention_heads // parallel_config.tensor_parallel_size
+
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
+
+class CacheConfig:
+    """Configuration for the KV cache.
+    Args:
+        block_size: Size of a cache block in number of tokens.
+        gpu_memory_utilization: Fraction of GPU memory to use for the Aphrodite execution.
+        swap_space: Size of the CPU swap space per GPU (in GiB).
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        swap_space: int,
+    ) -> None:
+        self.block_size = block_size
+        self.gpu_memory_utilization = gpu_memory_utilization
+        self.swap_space_bytes = swap_space * _GiB
+        self._verify_args()
+
+        self.num_gpu_blocks = None
+        self.num_cpu_blocks = None
+
+    def _verify_args(self) -> None:
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError(
+                "GPU memory utilization must be less than 1.0. You passed "
+                f"{self.gpu_memory_utilization} instead.")
+
+    def _verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        total_cpu_memory = get_cpu_memory()
+        num_gpu_per_node = parallel_config.tensor_parallel_size
+        cpu_memory_usage = self.swap_space_bytes * num_gpu_per_node
+
+        msg = (
+            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
+            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
+            "allocated for the swap space.")
+        if cpu_memory_usage > 0.7 * total_cpu_memory:
+            raise ValueError("Too large swap space. " + msg)
+        elif cpu_memory_usage > 0.4 * total_cpu_memory:
+            logger.warning("Possibly too large swap space. " + msg)
+
+
+class ParallelConfig:
+    """Configuration for the distributed inference.
+    Args:
+        pipeline_parallel_size: Number of pipeline parallel groups.
+        tensor_parallel_size: Number of tensor parallel groups.
+        worker_use_ray: Whether to use Ray for model workers. Will be
+            set to `True` if either pipeline_parallel_size or
+            tensor_parallel_size is greater than 1.
+    """
+
+    def __init__(
+        self,
+        pipeline_parallel_size: int,
+        tensor_parallel_size: int,
+        worker_use_ray: bool,
+    ) -> None:
+        self.pipeline_parallel_size = pipeline_parallel_size
+        self.tensor_parallel_size = tensor_parallel_size
+        self.worker_use_ray = worker_use_ray
+
+        self.world_size = pipeline_parallel_size * tensor_parallel_size
+        if self.world_size > 1:
+            self.worker_use_ray = True
+        self._verify_args()
+
+    """TODO(alpin): Implement pipeline parallelism."""
+    def _verify_args(self) -> None:
+        if self.pipeline_parallel_size > 1:
+            raise NotImplementedError(
+                "Pipeline parallelism is not supported yet.")
+
+class SchedulerConfig:
+    """Scheduler Configuration:
+    Args:
+        max_num_batched_tokens: Maximum number of tokens to be processed in
+            a single iteration.
+        max_num_seqs: Maximum number of sequences to be processed in a single
+            iteration.
+    """
+    def __init__(
+        self,
+        max_num_batched_tokens: int,
+        max_num_seqs: int,
+    ) -> None:
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.max_num_seqs = max_num_seqs
+
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: str,
+) -> torch.dtype:
+    """Note: getattr(config, "torch_dtype", torch.float32) is incorrect
+    because config.torch_dtype can be None"""
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    dtype = dtype.lower()
+    # The dtype must either be "auto" or one of the supported string aliases.
+    if dtype != "auto" and dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Obtain torch_dtype
+    if dtype == "auto":
+        if config_dtype == torch.float32:
+            # Cast to 16-bit precision, BF16 if available
+            torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            logger.warning(f"Casting {config_dtype} to {torch_dtype}. Not recommended.")
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
+    return torch_dtype
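
As a rough illustration of the dtype resolution above, a bare `PretrainedConfig` can stand in for a real Hugging Face config so that no model download is needed; the behaviour shown assumes the membership check against the alias table as written above:

```python
import torch
from transformers import PretrainedConfig

from aphrodite.common.config import _get_and_verify_dtype

# A stand-in config with an explicit checkpoint dtype.
cfg = PretrainedConfig(torch_dtype=torch.float16)

# "auto" keeps the checkpoint dtype when it is already 16-bit.
assert _get_and_verify_dtype(cfg, "auto") is torch.float16
# An explicit string alias overrides the checkpoint dtype.
assert _get_and_verify_dtype(cfg, "float32") is torch.float32
```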

+ 49 - 0
aphrodite/common/logger.py

@@ -0,0 +1,49 @@
+"""
+Logging utility. Adapted from https://github.com/skypilot-org/skypilot/blob/master/sky/sky_logging.py
+"""
+
+import logging
+import sys
+
+_FORMAT = "%(levelname).1s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
+_DATE_FORMAT = '%m-%d %H:%M:%S'
+
+class NewLineFormatter(logging.Formatter):
+    """Adds logging prefix to newlines to align multi-line messages."""
+
+    def __init__(self, fmt, datefmt=None):
+        logging.Formatter.__init__(self, fmt, datefmt)
+
+    def format(self, record):
+        msg = logging.Formatter.format(self, record)
+        if record.message != '':
+            parts = msg.split(record.message)
+            msg = msg.replace('\n', '\r\n' + parts[0])
+        return msg
+
+
+_root_logger = logging.getLogger('aphrodite')
+_default_handler = None
+
+def _setup_logger():
+    _root_logger.setLevel(logging.DEBUG)
+    global _default_handler
+    if _default_handler is None:
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.flush = sys.stdout.flush  # type: ignore
+        _default_handler.setLevel(logging.INFO)
+        _root_logger.addHandler(_default_handler)
+    fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
+    _default_handler.setFormatter(fmt)
+    """Setting this will avoid the message
+    being propagated to the parent logger"""
+    _root_logger.propagate = False
+
+
+# The logger is initialized when the module is imported.
+# This is thread-safe as the module is only imported once,
+# guaranteed by the Python GIL.
+_setup_logger()
+
+def init_logger(name: str):
+    return logging.getLogger(name)
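
For illustration, a short sketch of the intended usage; the logger name is assumed to sit under the `aphrodite` hierarchy so that the handler installed above applies:

```python
from aphrodite.common.logger import init_logger

logger = init_logger("aphrodite.example")   # any name under "aphrodite"
logger.info("engine starting")
logger.info("first line\nsecond line")      # NewLineFormatter re-prefixes the second line
```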

+ 99 - 0
aphrodite/common/outputs.py

@@ -0,0 +1,99 @@
+"""Generation outputs by the model."""
+from typing import Dict, List, Optional
+
+from aphrodite.common.sequence import SequenceGroup, SequenceStatus
+
+
+class ChatCompletionOutput:
+    """The output data of one chat completion output of a request.
+    Args:
+        index: The index of the output in the request.
+        text: The generated output text.
+        token_ids: The token IDs of the generated output text.
+        cumulative_logprob: The cumulative log probability of the generated output text.
+        logprobs: The log probabilities of the top probability words at each position
+            if the logprobs are requested.
+        finish_reason: The reason why the sequence is finished.
+    """
+
+    def __init__(
+        self,
+        index: int,
+        text: str,
+        token_ids: List[int],
+        cumulative_logprob: float,
+        logprobs: Optional[List[Dict[int, float]]],
+        finish_reason: Optional[str] = None,
+    ) -> None:
+        self.index = index
+        self.text = text
+        self.token_ids = token_ids
+        self.cumulative_logprob = cumulative_logprob
+        self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None
+
+    def __repr__(self) -> str:
+        return (f"ChatCompletionOutput(index={self.index}, "
+                f"text={self.text!r}, "
+                f"token_ids={self.token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
+                f"logprobs={self.logprobs}, "
+                f"finish_reason={self.finish_reason})")
+
+
+class RequestOutput:
+    """The output data of a request to the LLM.
+    Args:
+        request_id: The unique ID of the request.
+        prompt: The prompt string of the request.
+        prompt_token_ids: The token IDs of the prompt.
+        outputs: The output sequences of the request.
+        finished: Whether the whole request is finished.
+    """
+
+    def __init__(
+        self,
+        request_id: str,
+        prompt: str,
+        prompt_token_ids: List[int],
+        outputs: List[ChatCompletionOutput],
+        finished: bool,
+    ) -> None:
+        self.request_id = request_id
+        self.prompt = prompt
+        self.prompt_token_ids = prompt_token_ids
+        self.outputs = outputs
+        self.finished = finished
+
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
+        n = seq_group.sampling_params.n
+        seqs = seq_group.get_seqs()
+        assert n <= len(seqs)
+        sorted_seqs = sorted(
+            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        top_n_seqs = sorted_seqs[:n]
+
+        outputs: List[ChatCompletionOutput] = []
+        for seq in top_n_seqs:
+            logprobs = seq.output_logprobs
+            if seq_group.sampling_params.logprobs is None:
+                # The sequence always carries the sampled tokens' logprobs,
+                # even when the caller did not request them.
+                logprobs = {}
+            finish_reason = SequenceStatus.get_finished_reason(seq.status)
+            output = ChatCompletionOutput(seqs.index(seq), seq.output_text,
+                                          seq.get_output_token_ids(),
+                                          seq.get_cumulative_logprob(),
+                                          logprobs, finish_reason)
+            outputs.append(output)
+
+        prompt = top_n_seqs[0].prompt
+        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
+        finished = seq_group.is_finished()
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs, finished)
+
+    def __repr__(self) -> str:
+        return (f"RequestOutput(request_id={self.request_id}, "
+                f"prompt={self.prompt!r}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"outputs={self.outputs}, "
+                f"finished={self.finished})")

+ 96 - 0
aphrodite/common/sampling_params.py

@@ -0,0 +1,96 @@
+"""Sampling parameters used for text generation."""
+from typing import List, Optional, Union
+
+class SamplingParams:
+
+    def __init__(
+        self,
+        n: int = 1,
+        best_of: Optional[int] = None,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        use_beam_search: bool = False,
+        stop: Union[str, List[str]] = [],
+        ignore_eos: bool = False,
+        max_tokens: int = 16,
+        logprobs: Optional[int] = None,
+    ) -> None:
+        self.n = n
+        self.best_of = best_of if best_of is not None else n
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.use_beam_search = use_beam_search
+        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        self.ignore_eos = ignore_eos
+        self.max_tokens = max_tokens
+        self.logprobs = logprobs
+
+        self._verify_args()
+        if self.use_beam_search:
+            self._verify_beam_search()
+        elif self.temperature == 0.0:
+            self._verify_greedy_sampling()
+
+    def _verify_args(self) -> None:
+        if self.n < 1:
+            raise ValueError(*f"n must be at least 1, got {self.n}.")
+        if self.best_of < self.n:
+            raise ValueError(f"best_of must be greater than or equal to n, "
+                            f"got n={self.n} and best_of={self.best_of}.")
+        if not -2.0 <= self.presence_penalty <= 2.0:
+            raise ValueError(f"presence_penalty must be in [-2, 2], "
+                            f"got {self.presence_penalty}.")
+        if self.temperature < 0.0:
+            raise ValueError(
+                f"temperature must be non-negative, got {self.temperature}.")
+        if not 0.0 < self.top_p <= 1.0:
+            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if self.top_k < -1 or self.top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                            f"got {self.top_k}.")
+        if self.max_tokens < 1:
+            raise ValueError(
+                f"max_tokens must be at least 1, got {self.max_tokens}.")
+        if self.logprobs is not None and self.logprobs < 0:
+            raise ValueError(
+                f"logprobs must be non-negative, got {self.logprobs}.")
+
+    def _verify_beam_search(self) -> None:
+        if self.best_of == 1:
+            raise ValueError("best_of must be greater than 1 when using "
+                            f"beam search. Got {self.best_of}.")
+        if self.temperature > 0.0:
+            raise ValueError("temperature must be 0 when using beam search.")
+        if self.top_p < 1.0:
+            raise ValueError("top_p must be 1 when using beam search.")
+        if self.top_k != -1:
+            raise ValueError("top_k must be -1 when using beam search.")
+
+    def _verify_greedy_sampling(self) -> None:
+        if self.best_of > 1:
+            raise ValueError("best_of must be 1 when using greedy sampling. Got {self.best_of}.")
+        if self.top_p < 1.0:
+            raise ValueError("top_p must be 1 when using greedy sampling.")
+        if self.top_k != -1:
+            raise ValueError("top_k must be -1 when using greedy sampling.")
+
+
+    def __repr__(self) -> str:
+        return (f"SamplingParams(n={self.n}, "
+                f"best_of={self.best_of}, "
+                f"presence_penalty={self.presence_penalty}, "
+                f"frequency_penalty={self.frequency_penalty}, "
+                f"temperature={self.temperature}, "
+                f"top_p={self.top_p}, "
+                f"top_k={self.top_k}, "
+                f"use_beam_search={self.use_beam_search}, "
+                f"stop={self.stop}, "
+                f"ignore_eos={self.ignore_eos}, "
+                f"max_tokens={self.max_tokens}, "
+                f"logprobs={self.logprobs})")

+ 231 - 0
aphrodite/common/sequence.py

@@ -0,0 +1,231 @@
+"""Sequence."""
+import copy
+import enum
+from typing import Dict, List, Optional
+
+from aphrodite.common.block import LogicalTokenBlock
+from aphrodite.common.sampling_params import SamplingParams
+
+class SequenceStatus(enum.Enum):
+    WAITING = enum.auto()
+    RUNNING = enum.auto()
+    SWAPPED = enum.auto()
+    FINISHED_STOPPED = enum.auto()
+    FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()
+
+    @staticmethod
+    def is_finished(status: "SequenceStatus") -> bool:
+        return status in [
+            SequenceStatus.FINISHED_STOPPED,
+            SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
+        ]
+
+    @staticmethod
+    def get_finished_reason(status: "SequenceStatus") -> Optional[str]:
+        if status == SequenceStatus.FINISHED_STOPPED:
+            finish_reason = "stop"
+        elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
+            finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
+        else:
+            finish_reason = None
+        return finish_reason
+
+class SequenceData:
+
+    def __init__(
+        self,
+        prompt_token_ids: List[int],
+    ) -> None:
+        self.prompt_token_ids = prompt_token_ids
+        self.output_token_ids: List[int] = []
+        self.cumulative_logprob = 0.0
+
+    def append_token_id(self, token_id: int, logprob: float) -> None:
+        self.output_token_ids.append(token_id)
+        self.cumulative_logprob += logprob
+
+    def get_len(self) -> int:
+        return len(self.output_token_ids) + len(self.prompt_token_ids)
+
+    def get_output_len(self) -> int:
+        return len(self.output_token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.prompt_token_ids + self.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        if not self.output_token_ids:
+            return self.prompt_token_ids[-1]
+        return self.output_token_ids[-1]
+
+    def __repr__(self) -> str:
+        return (f"SequenceData("
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"output_token_ids={self.output_token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob})")
+
+
+class Sequence:
+
+    def __init__(
+        self,
+        seq_id: int,
+        prompt: str,
+        prompt_token_ids: List[int],
+        block_size: int,
+    ) -> None:
+        self.seq_id = seq_id
+        self.prompt = prompt
+        self.block_size = block_size
+
+        self.data = SequenceData(prompt_token_ids)
+        self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
+        self.output_text = ""
+
+        self.logical_token_blocks: List[LogicalTokenBlock] = []
+        self._append_tokens_to_blocks(prompt_token_ids)
+        self.status = SequenceStatus.WAITING
+
+    def _append_logical_block(self) -> None:
+        # Appends a fresh logical block; called whenever the last block is
+        # full (or no block exists yet) while appending tokens.
+        block = LogicalTokenBlock(
+            block_number=len(self.logical_token_blocks),
+            block_size=self.block_size,
+        )
+        self.logical_token_blocks.append(block)
+
+    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
+        while token_ids:
+            if not self.logical_token_blocks:
+                self._append_logical_block()
+
+            last_block = self.logical_token_blocks[-1]
+            if last_block.is_full():
+                self._append_logical_block()
+                last_block = self.logical_token_blocks[-1]
+
+            num_empty_slots = last_block.get_num_empty_slots()
+            last_block.append_tokens(token_ids[:num_empty_slots])
+            token_ids = token_ids[num_empty_slots:]
+
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
+        assert token_id in logprobs
+        self._append_tokens_to_blocks([token_id])
+        self.output_logprobs.append(logprobs)
+        self.data.append_token_id(token_id, logprobs[token_id])
+
+    def get_len(self) -> int:
+        return self.data.get_len()
+
+    def get_output_len(self) -> int:
+        return self.data.get_output_len()
+
+    def get_token_ids(self) -> List[int]:
+        return self.data.get_token_ids()
+
+    def get_output_token_ids(self) -> List[int]:
+        return self.data.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        return self.data.get_last_token_id()
+
+    def get_cumulative_logprob(self) -> float:
+        return self.data.cumulative_logprob
+
+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
+    def fork(self, child_seq: 'Sequence') -> None:
+        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
+        child_seq.data = copy.deepcopy(self.data)
+        return None
+
+    def __repr__(self) -> str:
+        return (f"Sequence(seq_id={self.seq_id}, "
+                f"Status={self.status.name}, "
+                f"num_blocks={len(self.logical_token_blocks)})")
+
+
+class SequenceGroup:
+
+    def __init__(
+        self,
+        request_id: str,
+        seqs: List[Sequence],
+        sampling_params: SamplingParams,
+    ) -> None:
+        self.request_id = request_id
+        self.seqs = seqs
+        self.sampling_params = sampling_params
+        # self.arrival_time = arrival_time
+
+    def get_seqs(
+        self,
+        status: Optional[SequenceStatus] = None,
+    ) -> List[Sequence]:
+        if status is None:
+            return self.seqs
+        else:
+            return [seq for seq in self.seqs if seq.status == status]
+
+    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        return len(self.get_seqs(status))
+
+    def find(self, seq_id: int) -> Sequence:
+        for seq in self.seqs:
+            if seq.seq_id == seq_id:
+                return seq
+        raise ValueError(f'Sequence {seq_id} not found.')
+
+    def is_finished(self) -> bool:
+        return all(seq.is_finished() for seq in self.seqs)
+
+    def __repr__(self) -> str:
+        return (f"SequenceGroup(request_id={self.request_id}, " 
+                f"sampling_params={self.sampling_params}, "
+                f"num_seqs={len(self.seqs)})")
+
+
+class SequenceGroupMetadata:
+
+    def __init__(
+        self,
+        request_id: str,
+        is_prompt: bool,
+        seq_data: Dict[int, SequenceData],
+        sampling_params: SamplingParams,
+        block_tables: Dict[int, List[int]],
+    ) -> None:
+        self.request_id = request_id
+        self.is_prompt = is_prompt
+        self.seq_data = seq_data
+        self.sampling_params = sampling_params
+        self.block_tables = block_tables
+
+
+class SequenceOutputs:
+
+    def __init__(
+        self,
+        seq_id: int,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],
+    ) -> None:
+        self.seq_id = seq_id
+        self.parent_seq_id = parent_seq_id
+        self.output_token = output_token
+        self.logprobs = logprobs
+
+    def __repr__(self) -> str:
+        return (f'SequenceOutputs(seq_id={self.seq_id}, '
+                f'parent_seq_id={self.parent_seq_id}, '
+                f'output_token={self.output_token}, '
+                f'logprobs={self.logprobs})')
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SequenceOutputs):
+            return NotImplemented
+        return (self.seq_id == other.seq_id and
+                self.parent_seq_id == other.parent_seq_id and
+                self.output_token == other.output_token and
+                self.logprobs == other.logprobs)
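
For illustration, a small sketch of how a `Sequence` spreads prompt and output tokens over logical blocks, given the accessors defined above; the token IDs and block size are arbitrary:

```python
from aphrodite.common.sequence import Sequence, SequenceStatus

# Five prompt tokens with block_size=4 occupy two logical blocks.
seq = Sequence(seq_id=0, prompt="hello world",
               prompt_token_ids=[1, 2, 3, 4, 5], block_size=4)
assert len(seq.logical_token_blocks) == 2
assert seq.get_len() == 5

# Appending a sampled token records its logprob and fills the second block.
seq.append_token_id(token_id=42, logprobs={42: -0.1})
assert seq.get_output_len() == 1
assert seq.get_last_token_id() == 42

seq.status = SequenceStatus.FINISHED_STOPPED
assert seq.is_finished()
```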

+ 35 - 0
aphrodite/common/utils.py

@@ -0,0 +1,35 @@
+"""Utils."""
+import enum
+import uuid
+
+import psutil
+import torch
+
+class Device(enum.Enum):
+    GPU = enum.auto()
+    CPU = enum.auto()
+
+class Counter:
+    '''A basic counter.'''
+    def __init__(self, start: int = 0) -> None:
+        self.counter = start
+
+    def __next__(self) -> int:
+        id = self.counter
+        self.counter += 1
+        return id
+
+    def reset(self) -> None:
+        self.counter = 0
+
+
+def get_gpu_memory(gpu: int = 0) -> int:
+    """Returns the total memory of the GPU in bytes."""
+    return torch.cuda.get_device_properties(gpu).total_memory
+
+def get_cpu_memory() -> int:
+    """Returns the total CPU memory of the node in bytes."""
+    return psutil.virtual_memory().total
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
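
For illustration, a trivial sketch of the ID helpers the engine relies on:

```python
from aphrodite.common.utils import Counter, Device, random_uuid

seq_counter = Counter()
assert next(seq_counter) == 0
assert next(seq_counter) == 1

request_id = random_uuid()   # hex string without dashes
device = Device.GPU
```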

+ 229 - 0
aphrodite/processing/block_manager.py

@@ -0,0 +1,229 @@
+"""A block manager that manages token blocks."""
+from typing import Dict, List, Optional, Set, Tuple
+from aphrodite.common.block import PhysicalTokenBlock
+from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
+from aphrodite.common.utils import Device
+
+class BlockAllocator:
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(
+        self,
+        device: Device,
+        block_size: int,
+        num_blocks: int,
+    ) -> None:
+        self.device = device
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+
+        self.free_blocks: List[PhysicalTokenBlock] = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(
+                device=device, block_number=i, block_size=block_size)
+            self.free_blocks.append(block)
+
+    def allocate(self) -> PhysicalTokenBlock:
+        if not self.free_blocks:
+            raise ValueError("Out Of Memory! No free blocks are available.")
+        block = self.free_blocks.pop()
+        block.ref_count = 1
+        return block
+
+    def free(self, block: PhysicalTokenBlock) -> None:
+        if block.ref_count == 0:
+            raise ValueError(f"Double free! {block} is already freed.")
+        block.ref_count -= 1
+        if block.ref_count == 0:
+            self.free_blocks.append(block)
+
+    def get_num_free_blocks(self) -> int:
+        return len(self.free_blocks)
+
+BlockTable = List[PhysicalTokenBlock]
+
+class BlockSpaceManager:
+    """Manages the mapping between logical and physical blocks."""
+
+    def __init__(
+        self,
+        block_size: int,
+        num_gpu_blocks: int,
+        num_cpu_blocks: int,
+        watermark: float = 0.01,
+    ) -> None:
+        self.block_size = block_size
+        self.num_total_gpu_blocks = num_gpu_blocks
+        self.num_total_cpu_blocks = num_cpu_blocks
+        self.watermark = watermark
+        assert watermark >= 0.0
+
+        self.watermark_blocks = int(watermark * num_gpu_blocks)
+        self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
+        self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+
+        self.block_tables: Dict[int, BlockTable] = {}
+
+    def can_allocate(self, seq_group: SequenceGroup) -> bool:
+        """
+        NOTE: we assume that all sequences in the group share the same prompt.
+        This might not be true for preempted sequences. Needs fixing.
+        """
+        seq = seq_group.get_seqs()[0]
+        num_required_blocks = len(seq.logical_token_blocks)
+        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
+        return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
+
+    def allocate(self, seq_group: SequenceGroup) -> None:
+        seq = seq_group.get_seqs()[0]
+
+        block_table: BlockTable = []
+        for _ in range(len(seq.logical_token_blocks)):
+            block = self.gpu_allocator.allocate()
+            block.ref_count = seq_group.num_seqs()
+            block_table.append(block)
+
+        for seq in seq_group.get_seqs():
+            self.block_tables[seq.seq_id] = block_table.copy()
+
+    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
+        """
+        Simple heuristic: If there's at least one free block
+        for each sequence, we can append.
+        """
+        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
+        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
+        return num_seqs <= num_free_gpu_blocks
+
+    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+        """Allocate a physical slot for a new token."""
+        logical_blocks = seq.logical_token_blocks
+        block_table = self.block_tables[seq.seq_id]
+
+        if len(block_table) < len(logical_blocks):
+            block = self.gpu_allocator.allocate()
+            block_table.append(block)
+            return None
+
+        last_block = block_table[-1]
+        assert last_block.device == Device.GPU
+        if last_block.ref_count == 1:
+            return None
+        else:
+            new_block = self.gpu_allocator.allocate()
+            block_table[-1] = new_block
+            self.gpu_allocator.free(last_block)
+            return last_block.block_number, new_block.block_number
+    
+    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        src_block_table = self.block_tables[parent_seq.seq_id]
+        self.block_tables[child_seq.seq_id] = src_block_table.copy()
+        for block in src_block_table:
+            block.ref_count += 1
+
+    def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+        blocks: Set[PhysicalTokenBlock] = set()
+        for seq in seq_group.get_seqs():
+            if seq.is_finished():
+                continue
+            block_table = self.block_tables[seq.seq_id]
+            for block in block_table:
+                blocks.add(block)
+        return list(blocks)
+
+
+    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
+        blocks = self._get_physical_blocks(seq_group)
+        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
+        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
+        num_required_blocks = len(blocks) + num_swapped_seqs
+        return num_free_blocks - num_required_blocks >= self.watermark_blocks
+
+    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
+        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
+        for seq in seq_group.get_seqs():
+            if seq.is_finished():
+                continue
+            new_block_table: BlockTable = []
+            block_table = self.block_tables[seq.seq_id]
+
+            for cpu_block in block_table:
+                if cpu_block in mapping:
+                    gpu_block = mapping[cpu_block]
+                    gpu_block.ref_count += 1
+                else:
+                    gpu_block = self.gpu_allocator.allocate()
+                    mapping[cpu_block] = gpu_block
+                new_block_table.append(gpu_block)
+                self.cpu_allocator.free(cpu_block)
+            self.block_tables[seq.seq_id] = new_block_table
+
+        block_number_mapping = {
+            cpu_block.block_number: gpu_block.block_number
+            for cpu_block, gpu_block in mapping.items()
+        }
+        return block_number_mapping
+    
+    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
+        blocks = self._get_physical_blocks(seq_group)
+        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
+
+    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+        # GPU block -> CPU block
+        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
+        for seq in seq_group.get_seqs():
+            if seq.is_finished():
+                continue
+            new_block_table: BlockTable = []
+            block_table = self.block_tables[seq.seq_id]
+
+            for gpu_block in block_table:
+                if gpu_block in mapping:
+                    cpu_block = mapping[gpu_block]
+                    cpu_block.ref_count += 1
+                else:
+                    cpu_block = self.cpu_allocator.allocate()
+                    mapping[gpu_block] = cpu_block
+                new_block_table.append(cpu_block)
+                self.gpu_allocator.free(gpu_block)
+            self.block_tables[seq.seq_id] = new_block_table
+
+        block_number_mapping = {
+            gpu_block.block_number: cpu_block.block_number
+            for gpu_block, cpu_block in mapping.items()
+        }
+        return block_number_mapping
+
+    def _free_block_table(self, block_table: BlockTable) -> None:
+        for block in block_table:
+            if block.device == Device.GPU:
+                self.gpu_allocator.free(block)
+            else:
+                self.cpu_allocator.free(block)
+
+    def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            return
+        block_table = self.block_tables[seq.seq_id]
+        self._free_block_table(block_table)
+        del self.block_tables[seq.seq_id]
+
+    def reset(self) -> None:
+        for block_table in self.block_tables.values():
+            self._free_block_table(block_table)
+        self.block_tables.clear()
+
+    def get_block_table(self, seq: Sequence) -> List[int]:
+        block_table = self.block_tables[seq.seq_id]
+        return [block.block_number for block in block_table]
+
+    def get_num_free_gpu_blocks(self) -> int:
+        return self.gpu_allocator.get_num_free_blocks()
+
+    def get_num_free_cpu_blocks(self) -> int:
+        return self.cpu_allocator.get_num_free_blocks()
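
For illustration, a minimal sketch of the reference-counted allocator underneath `BlockSpaceManager`; sharing a block between sequences is modelled by bumping `ref_count`:

```python
from aphrodite.common.utils import Device
from aphrodite.processing.block_manager import BlockAllocator

# A tiny GPU pool of two physical blocks, 16 tokens each.
allocator = BlockAllocator(Device.GPU, block_size=16, num_blocks=2)
a = allocator.allocate()
b = allocator.allocate()
assert allocator.get_num_free_blocks() == 0

# Two sequences share block `a`, e.g. a common prompt prefix.
a.ref_count += 1

# The block returns to the free list only after the last reference is freed.
allocator.free(a)
assert allocator.get_num_free_blocks() == 0
allocator.free(a)
assert allocator.get_num_free_blocks() == 1
```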

+ 42 - 0
aphrodite/processing/policy.py

@@ -0,0 +1,42 @@
+from typing import List
+
+from aphrodite.common.sequence import SequenceGroup
+
+class Policy:
+
+    def get_priority(
+        self,
+        now: float,
+        seq_group: SequenceGroup,
+    ) -> float:
+        raise NotImplementedError
+
+    def sort_by_priority(
+        self,
+        now: float,
+        seq_groups: List[SequenceGroup],
+    ) -> List[SequenceGroup]:
+        return sorted(
+            seq_groups,
+            key=lambda seq_group: self.get_priority(now, seq_group),
+            reverse=True,
+        )
+
+class FCFS(Policy):
+
+    def get_priority(
+        self,
+        now: float,
+        seq_group: SequenceGroup,
+    ) -> float:
+        return now - seq_group.arrival_time
+
+class PolicyFactory:
+
+    _POLICY_REGISTRY = {
+        'fcfs': FCFS,
+    }
+
+    @classmethod
+    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
+        return cls._POLICY_REGISTRY[policy_name](**kwargs)
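
For illustration, a sketch of the FCFS policy with stand-in objects; note that `SequenceGroup` does not yet set `arrival_time` (the assignment is commented out in sequence.py), so real groups cannot be sorted until that lands:

```python
import time

from aphrodite.processing.policy import PolicyFactory

class _StubGroup:
    """Stand-in exposing only the attribute FCFS looks at."""
    def __init__(self, arrival_time: float) -> None:
        self.arrival_time = arrival_time

policy = PolicyFactory.get_policy(policy_name="fcfs")
now = time.time()
groups = [_StubGroup(now - 1.0), _StubGroup(now - 10.0)]

# Older groups have waited longer, so they get higher priority.
ordered = policy.sort_by_priority(now, groups)
assert ordered[0].arrival_time == now - 10.0
```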

+ 187 - 0
aphrodite/processing/scheduler.py

@@ -0,0 +1,187 @@
+import enum
+import time
+from typing import Dict, List, Optional, Tuple
+
+from aphrodite.common.config import CacheConfig, SchedulerConfig
+from aphrodite.processing.block_manager import BlockSpaceManager
+from aphrodite.processing.policy import PolicyFactory
+from aphrodite.common.logger import init_logger
+from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
+                                        SequenceGroupMetadata, SequenceOutputs,
+                                        SequenceStatus)
+
+logger = init_logger(__name__)
+
+__LOGGING_INTERVAL_SEC = 5
+
+class PreemptionMode(enum.Enum):
+    """Preemtion modes.
+
+    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory and
+    swap them back in when the sequences are resumed.
+    2. Recomputation: Discard the blocks of the preempted sequences and recompute
+    them when the sequences are resumed, treating the sequences as new prompts.
+    """
+    SWAP = enum.auto()
+    RECOMPUTE = enum.auto()
+
+class SchedulerOutputs:
+
+    def __init__(
+        self,
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        self.blocks_to_swap_in = blocks_to_swap_in
+        self.blocks_to_swap_out = blocks_to_swap_out
+        self.blocks_to_copy = blocks_to_copy
+        assert not (blocks_to_swap_in and blocks_to_swap_out)
+
+    def is_empty(self) -> bool:
+        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy)
+    
+
+class Scheduler:
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+        log_stats: bool,
+    ) -> None:
+        self.scheduler_config = scheduler_config
+        self.cache_config = cache_config
+        self.log_stats = log_stats
+
+        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
+        self.block_manager = BlockSpaceManager(
+            block_size=self.cache_config.block_size,
+            num_gpu_blocks=self.cache_config.num_gpu_blocks,
+            num_cpu_blocks=self.cache_config.num_cpu_blocks,
+        )
+
+        self.waiting: List[SequenceGroup] = []
+        self.running: List[SequenceGroup] = []
+        self.swapped: List[SequenceGroup] = []
+
+        self.last_logging_time: float = 0.0
+        self.num_input_tokens: List[Tuple[float, int]] = []
+
+    def add_seq_group(self, seq_group: SequenceGroup) -> None:
+        self.waiting.append(seq_group)
+
+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group in state_queue:
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
+    def has_unfinished_seqs(self) -> bool:
+        return self.waiting or self.running or self.swapped
+
+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+    
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+        blocks_to_swap_in: Dict[int, int] = {}
+        blocks_to_swap_out: Dict[int, int] = {}
+        blocks_to_copy: Dict[int, List[int]] = {}
+
+        now = time.time()
+
+        """
+        NOTE: We prioritize the sequence groups in the RUNNING state in order to
+        minimize the preemption overheads.
+        Preemption happens only when there's no available slot to keep all the
+        sequence groups in the RUNNING state.
+        In this case the policy is responsible for deciding which sequence groups to
+        preempt.
+        """
+        self.running = self.policy.sort_by_priority(now, self.running)
+
+        running: List[SequenceGroup] = []
+        preempted: List[SequenceGroup] = []
+        while self.running:
+            seq_group = self.running.pop(0)
+            while not self.block_manager.can_append_slot(seq_group):
+                if self.running:
+                    victim_seq_group = self.running.pop(-1)
+                    self._preempt(victim_seq_group, blocks_to_swap_out)
+                    preempted.append(victim_seq_group)
+                else:
+                    self._preempt(seq_group, blocks_to_swap_out)
+                    preempted.append(seq_group)
+                    break
+            else:
+                self._append_slot(seq_group, blocks_to_copy)
+                running.append(seq_group)
+        self.running = running
+        self.swapped = self.policy.sort_by_priority(now, self.swapped)
+        while self.swapped and not blocks_to_swap_out:
+            seq_group = self.swapped[0]
+            if seq_group in preempted:
+                break
+            if not self.block_manager.can_swap_in(seq_group):
+                break
+
+            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
+            num_curr_seqs = len(self.running)
+            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+                break
+
+            seq_group = self.swapped.pop(0)
+            self._swap_in(seq_group, blocks_to_swap_in)
+            self._append_slot(seq_group, blocks_to_copy)
+            self.running.append(seq_group)
+
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running
+        )
+        
+        prompt_group_ids: List[str] = []
+        """
+        NOTE: The sequence groups in the SWAPPED state are strictly prioritized
+        over the sequence groups in the WAITING state.
+        This is because we want to bound the amount of CPU memory taken by the
+        swapped sequence groups.
+        """
+        if not self.swapped:
+            """
+            NOTE(optimization): We don't sort the waiting queue since the preempted sequence
+            groups are added to the front and the new sequence groups are added to the back.
+            """
+            while self.waiting:
+                seq_group = self.waiting[0]
+                if seq_group in preempted:
+                    break
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                if (num_batched_tokens + num_prompt_tokens > self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+                num_curr_seqs = len(self.running)
+                if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                prompt_group_ids.append(seq_group.request_id)
+        
+        scheduler_outputs = SchedulerOutputs(
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+        if not self.log_stats:
+            return scheduler_outputs, prompt_group_ids
+            """WORK IN PROGRESS"""

+ 0 - 0
aphrodite.png → assets/aphrodite.png


+ 6 - 0
requirements.txt

@@ -0,0 +1,6 @@
+psutil
+numpy
+torch >= 2.0.0
+transformers >= 4.28.0
+mypy
+pytest