Преглед на файлове

ignore infeasible swap requests

AlpinDale преди 8 месеца
родител
ревизия
25c2b6feca

+ 10 - 6
aphrodite/processing/block/block_table.py

@@ -41,7 +41,9 @@ class BlockTable:
     ):
     ):
         self._block_size = block_size
         self._block_size = block_size
         self._allocator = block_allocator
         self._allocator = block_allocator
-        self._blocks: Optional[List[Block]] = _blocks
+        if _blocks is None:
+            _blocks = []
+        self._blocks: List[Block] = _blocks
 
 
         # Use helper method instead of directly calculating, as blocks
         # Use helper method instead of directly calculating, as blocks
         # may not be allocated.
         # may not be allocated.
@@ -105,7 +107,7 @@ class BlockTable:
             token_ids (List[int]): The sequence of token IDs to be appended.
             token_ids (List[int]): The sequence of token IDs to be appended.
         """
         """
         assert self._is_allocated
         assert self._is_allocated
-        assert self._blocks is not None
+        assert len(self._blocks) > 0
 
 
         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                     num_lookahead_slots)
                                     num_lookahead_slots)
@@ -142,6 +144,7 @@ class BlockTable:
         blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
         blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
 
 
         for _ in range(blocks_to_allocate):
         for _ in range(blocks_to_allocate):
+            assert len(self._blocks) > 0
             self._blocks.append(
             self._blocks.append(
                 self._allocator.allocate_mutable(prev_block=self._blocks[-1],
                 self._allocator.allocate_mutable(prev_block=self._blocks[-1],
                                                  device=device))
                                                  device=device))
@@ -160,6 +163,7 @@ class BlockTable:
                 the current instance.
                 the current instance.
         """
         """
         assert self._is_allocated
         assert self._is_allocated
+        assert len(self._blocks) > 0
         forked_blocks = self._allocator.fork(self._blocks[-1])
         forked_blocks = self._allocator.fork(self._blocks[-1])
         return BlockTable(
         return BlockTable(
             block_size=self._block_size,
             block_size=self._block_size,
@@ -178,10 +182,10 @@ class BlockTable:
         assert self._is_allocated
         assert self._is_allocated
         for block in self._blocks:
         for block in self._blocks:
             self._allocator.free(block)
             self._allocator.free(block)
-        self._blocks = None
+        self._blocks = []
 
 
     @property
     @property
-    def physical_block_ids(self) -> List[int]:
+    def physical_block_ids(self) -> List[Optional[int]]:
         """Returns a list of physical block indices for the blocks in the
         """Returns a list of physical block indices for the blocks in the
         BlockTable.
         BlockTable.
 
 
@@ -236,7 +240,7 @@ class BlockTable:
 
 
     def _get_all_token_ids(self) -> List[int]:
     def _get_all_token_ids(self) -> List[int]:
         # NOTE: This function is O(seq_len); use sparingly.
         # NOTE: This function is O(seq_len); use sparingly.
-        token_ids = []
+        token_ids: List[int] = []
 
 
         if not self._is_allocated:
         if not self._is_allocated:
             return token_ids
             return token_ids
@@ -248,7 +252,7 @@ class BlockTable:
 
 
     @property
     @property
     def _is_allocated(self) -> bool:
     def _is_allocated(self) -> bool:
-        return self._blocks is not None
+        return len(self._blocks) > 0
 
 
     @property
     @property
     def _num_empty_slots(self) -> int:
     def _num_empty_slots(self) -> int:

+ 16 - 4
aphrodite/processing/block/common.py

@@ -1,5 +1,5 @@
 from collections import defaultdict
 from collections import defaultdict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Protocol
 
 
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
 from aphrodite.processing.block.interfaces import Block, BlockAllocator
 
 
@@ -7,7 +7,19 @@ BlockId = int
 RefCount = int
 RefCount = int
 
 
 
 
-class RefCounter:
+class RefCounterProtocol(Protocol):
+
+    def incr(self, block_id: BlockId) -> RefCount:
+        raise NotImplementedError
+
+    def decr(self, block_id: BlockId) -> RefCount:
+        raise NotImplementedError
+
+    def get(self, block_id: BlockId) -> RefCount:
+        raise NotImplementedError
+
+
+class RefCounter(RefCounterProtocol):
     """A class for managing reference counts for a set of block indices.
     """A class for managing reference counts for a set of block indices.
 
 
     The RefCounter class maintains a dictionary that maps block indices to their
     The RefCounter class maintains a dictionary that maps block indices to their
@@ -54,7 +66,7 @@ class RefCounter:
         return ReadOnlyRefCounter(self)
         return ReadOnlyRefCounter(self)
 
 
 
 
-class ReadOnlyRefCounter:
+class ReadOnlyRefCounter(RefCounterProtocol):
     """A read-only view of the RefCounter class.
     """A read-only view of the RefCounter class.
 
 
     The ReadOnlyRefCounter class provides a read-only interface to access the
     The ReadOnlyRefCounter class provides a read-only interface to access the
@@ -96,7 +108,7 @@ class CopyOnWriteTracker:
 
 
     def __init__(
     def __init__(
         self,
         self,
-        refcounter: RefCounter,
+        refcounter: RefCounterProtocol,
         allocator: BlockAllocator,
         allocator: BlockAllocator,
     ):
     ):
         self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
         self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)

+ 25 - 10
aphrodite/processing/block/cpu_gpu_block_allocator.py

@@ -1,7 +1,8 @@
-from typing import Dict, List, Optional
+from typing import Dict, FrozenSet, List, Optional
 
 
 from aphrodite.common.utils import Device
 from aphrodite.common.utils import Device
 from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
 from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
+                                                   BlockId,
                                                    DeviceAwareBlockAllocator)
                                                    DeviceAwareBlockAllocator)
 from aphrodite.processing.block.naive_block import (NaiveBlock,
 from aphrodite.processing.block.naive_block import (NaiveBlock,
                                                     NaiveBlockAllocator)
                                                     NaiveBlockAllocator)
@@ -59,15 +60,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         cpu_block_ids = block_ids[num_gpu_blocks:]
         cpu_block_ids = block_ids[num_gpu_blocks:]
 
 
         if allocator_type == "naive":
         if allocator_type == "naive":
-            gpu_allocator = NaiveBlockAllocator(
-                create_block=NaiveBlock,
+            gpu_allocator: BlockAllocator = NaiveBlockAllocator(
+                create_block=NaiveBlock,  # type: ignore
                 num_blocks=num_gpu_blocks,
                 num_blocks=num_gpu_blocks,
                 block_size=block_size,
                 block_size=block_size,
                 block_ids=gpu_block_ids,
                 block_ids=gpu_block_ids,
             )
             )
 
 
-            cpu_allocator = NaiveBlockAllocator(
-                create_block=NaiveBlock,
+            cpu_allocator: BlockAllocator = NaiveBlockAllocator(
+                create_block=NaiveBlock,  # type: ignore
                 num_blocks=num_cpu_blocks,
                 num_blocks=num_cpu_blocks,
                 block_size=block_size,
                 block_size=block_size,
                 block_ids=cpu_block_ids,
                 block_ids=cpu_block_ids,
@@ -107,7 +108,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             Device.GPU: gpu_block_allocator,
             Device.GPU: gpu_block_allocator,
         }
         }
 
 
-        self._block_ids_to_allocator = {}
+        self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
         for _, allocator in self._allocators.items():
         for _, allocator in self._allocators.items():
             for block_id in allocator.all_block_ids:
             for block_id in allocator.all_block_ids:
                 self._block_ids_to_allocator[block_id] = allocator
                 self._block_ids_to_allocator[block_id] = allocator
@@ -151,7 +152,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         Args:
         Args:
             block (Block): The block to be freed.
             block (Block): The block to be freed.
         """
         """
-        allocator = self._block_ids_to_allocator[block.block_id]
+        block_id = block.block_id
+        assert block_id is not None
+        allocator = self._block_ids_to_allocator[block_id]
         return allocator.free(block)
         return allocator.free(block)
 
 
     def fork(self, last_block: Block) -> List[Block]:
     def fork(self, last_block: Block) -> List[Block]:
@@ -165,7 +168,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             List[Block]: A new list of blocks that shares the same memory as the
             List[Block]: A new list of blocks that shares the same memory as the
                 original sequence.
                 original sequence.
         """
         """
-        allocator = self._block_ids_to_allocator[last_block.block_id]
+        block_id = last_block.block_id
+        assert block_id is not None
+        allocator = self._block_ids_to_allocator[block_id]
         return allocator.fork(last_block)
         return allocator.fork(last_block)
 
 
     def get_num_free_blocks(self, device: Device) -> int:
     def get_num_free_blocks(self, device: Device) -> int:
@@ -173,13 +178,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
 
 
         Args:
         Args:
             device (Device): The device for which to query the number of free
             device (Device): The device for which to query the number of free
-                blocks.
+                blocks. AssertionError is raised if None is passed.
 
 
         Returns:
         Returns:
             int: The number of free blocks available on the specified device.
             int: The number of free blocks available on the specified device.
         """
         """
         return self._allocators[device].get_num_free_blocks()
         return self._allocators[device].get_num_free_blocks()
 
 
+    def get_num_total_blocks(self, device: Device) -> int:
+        return self._allocators[device].get_num_total_blocks()
+
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
         """Clears the copy-on-write (CoW) state and returns the mapping of
             source to destination block IDs.
             source to destination block IDs.
@@ -212,5 +220,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         return self._allocators[device].get_common_computed_block_ids(
         return self._allocators[device].get_common_computed_block_ids(
             seq_block_ids)
             seq_block_ids)
 
 
-    def all_block_ids(self) -> frozenset[int]:
+    @property
+    def all_block_ids(self) -> FrozenSet[int]:
         return frozenset(self._block_ids_to_allocator.keys())
         return frozenset(self._block_ids_to_allocator.keys())
+
+    def promote_to_immutable_block(self, block: Block) -> BlockId:
+        raise NotImplementedError
+
+    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
+        raise NotImplementedError

+ 108 - 12
aphrodite/processing/block/interfaces.py

@@ -1,8 +1,10 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol
 
 
 from aphrodite.common.utils import Device
 from aphrodite.common.utils import Device
 
 
+BlockId = int
+
 
 
 class Block(ABC):
 class Block(ABC):
 
 
@@ -10,26 +12,58 @@ class Block(ABC):
     def append_token_ids(self, token_ids: List[int]) -> None:
     def append_token_ids(self, token_ids: List[int]) -> None:
         pass
         pass
 
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def block_id(self) -> Optional[int]:
     def block_id(self) -> Optional[int]:
         pass
         pass
 
 
-    @abstractproperty
+    @block_id.setter
+    @abstractmethod
+    def block_id(self, value: Optional[int]) -> None:
+        """NOTE: Do not use this API outside Block."""
+        self._block_id = value
+
+    @property
+    @abstractmethod
     def token_ids(self) -> List[int]:
     def token_ids(self) -> List[int]:
         pass
         pass
 
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def num_empty_slots(self) -> int:
     def num_empty_slots(self) -> int:
         pass
         pass
 
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def is_full(self) -> bool:
     def is_full(self) -> bool:
         pass
         pass
 
 
-    @abstractproperty
+    @property
+    @abstractmethod
     def prev_block(self) -> Optional["Block"]:
     def prev_block(self) -> Optional["Block"]:
         pass
         pass
 
 
+    @property
+    @abstractmethod
+    def computed(self) -> bool:
+        raise NotImplementedError
+
+    @computed.setter
+    @abstractmethod
+    def computed(self, value) -> bool:
+        """Should be only used by PrefixCacingAllocator"""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def last_accessed(self) -> float:
+        raise NotImplementedError
+
+    @last_accessed.setter
+    @abstractmethod
+    def last_accessed(self, last_accessed_ts: float):
+        raise NotImplementedError
+
     class Factory(Protocol):
     class Factory(Protocol):
 
 
         @abstractmethod
         @abstractmethod
@@ -43,6 +77,17 @@ class Block(ABC):
         ) -> "Block":
         ) -> "Block":
             pass
             pass
 
 
+    @property
+    @abstractmethod
+    def content_hash(self) -> Optional[int]:
+        """Return the content-based hash of the current block, or None if it is
+        not yet defined or not supported.
+
+        For the content-based hash to be defined, the current block must be
+        full.
+        """
+        return None
+
 
 
 class BlockAllocator(ABC):
 class BlockAllocator(ABC):
 
 
@@ -63,12 +108,17 @@ class BlockAllocator(ABC):
     def fork(self, last_block: Block) -> List[Block]:
     def fork(self, last_block: Block) -> List[Block]:
         pass
         pass
 
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     @abstractmethod
     def get_num_free_blocks(self) -> int:
     def get_num_free_blocks(self) -> int:
         pass
         pass
 
 
-    @abstractproperty
-    def all_block_ids(self) -> frozenset[int]:
+    @property
+    @abstractmethod
+    def all_block_ids(self) -> FrozenSet[int]:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
@@ -76,11 +126,12 @@ class BlockAllocator(ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def mark_blocks_as_accessed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
@@ -88,11 +139,21 @@ class BlockAllocator(ABC):
             self, seq_block_ids: List[List[int]]) -> List[int]:
             self, seq_block_ids: List[List[int]]) -> List[int]:
         pass
         pass
 
 
+    @abstractmethod
+    def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
+        """NOTE: This should not be used besides Block"""
+        pass
+
+    @abstractmethod
+    def promote_to_immutable_block(self, block: Block) -> BlockId:
+        """NOTE: This should not be used besides Block"""
+        pass
+
     class NoFreeBlocksError(ValueError):
     class NoFreeBlocksError(ValueError):
         pass
         pass
 
 
 
 
-class DeviceAwareBlockAllocator(BlockAllocator):
+class DeviceAwareBlockAllocator(ABC):
 
 
     @abstractmethod
     @abstractmethod
     def allocate_mutable(self, prev_block: Optional[Block],
     def allocate_mutable(self, prev_block: Optional[Block],
@@ -107,3 +168,38 @@ class DeviceAwareBlockAllocator(BlockAllocator):
     @abstractmethod
     @abstractmethod
     def get_num_free_blocks(self, device: Device) -> int:
     def get_num_free_blocks(self, device: Device) -> int:
         pass
         pass
+
+    @abstractmethod
+    def get_num_total_blocks(self, device: Device) -> int:
+        pass
+
+    @abstractmethod
+    def free(self, block: Block) -> None:
+        pass
+
+    @abstractmethod
+    def fork(self, last_block: Block) -> List[Block]:
+        pass
+
+    @property
+    @abstractmethod
+    def all_block_ids(self) -> FrozenSet[int]:
+        pass
+
+    @abstractmethod
+    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
+        pass
+
+    @abstractmethod
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        pass
+
+    @abstractmethod
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
+        pass
+
+    @abstractmethod
+    def get_common_computed_block_ids(
+            self, seq_block_ids: List[List[int]]) -> List[int]:
+        pass

+ 45 - 9
aphrodite/processing/block/naive_block.py

@@ -1,10 +1,10 @@
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Dict, FrozenSet, Iterable, List, Optional, Set
 
 
 from aphrodite.processing.block.common import (CopyOnWriteTracker, RefCounter,
 from aphrodite.processing.block.common import (CopyOnWriteTracker, RefCounter,
                                                get_all_blocks_recursively)
                                                get_all_blocks_recursively)
-from aphrodite.processing.block.interfaces import Block, BlockAllocator
+from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
+                                                   BlockId, Device)
 
 
-BlockId = int
 Refcount = int
 Refcount = int
 
 
 
 
@@ -49,8 +49,10 @@ class NaiveBlockAllocator(BlockAllocator):
             allocator=self,
             allocator=self,
         )
         )
 
 
-    def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+    def allocate_immutable(self,
+                           prev_block: Optional[Block],
+                           token_ids: List[int],
+                           device: Optional[Device] = None) -> Block:
         """Allocates a new immutable block with the given token IDs, linked to
         """Allocates a new immutable block with the given token IDs, linked to
         the previous block.
         the previous block.
 
 
@@ -63,11 +65,14 @@ class NaiveBlockAllocator(BlockAllocator):
         Returns:
         Returns:
             Block: The newly allocated immutable block.
             Block: The newly allocated immutable block.
         """
         """
+        assert device is None
         block = self.allocate_mutable(prev_block=prev_block)
         block = self.allocate_mutable(prev_block=prev_block)
         block.append_token_ids(token_ids)
         block.append_token_ids(token_ids)
         return block
         return block
 
 
-    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+    def allocate_mutable(self,
+                         prev_block: Optional[Block],
+                         device: Optional[Device] = None) -> Block:
         """Allocates a new mutable block, linked to the previous block.
         """Allocates a new mutable block, linked to the previous block.
 
 
         Args:
         Args:
@@ -78,6 +83,7 @@ class NaiveBlockAllocator(BlockAllocator):
         Returns:
         Returns:
             Block: The newly allocated mutable block.
             Block: The newly allocated mutable block.
         """
         """
+        assert device is None
         block_id = self._allocate_new_block_id()
         block_id = self._allocate_new_block_id()
         return self._create_block(
         return self._create_block(
             prev_block=prev_block,
             prev_block=prev_block,
@@ -88,6 +94,7 @@ class NaiveBlockAllocator(BlockAllocator):
         )
         )
 
 
     def free(self, block: Block) -> None:
     def free(self, block: Block) -> None:
+        assert block.block_id is not None
         self._free_block_id(block.block_id)
         self._free_block_id(block.block_id)
 
 
         # Mark the block as having no allocation.
         # Mark the block as having no allocation.
@@ -111,6 +118,7 @@ class NaiveBlockAllocator(BlockAllocator):
         for block in source_blocks:
         for block in source_blocks:
 
 
             # Increment refcount for each block.
             # Increment refcount for each block.
+            assert block.block_id is not None
             refcount = self._refcounter.incr(block.block_id)
             refcount = self._refcounter.incr(block.block_id)
             assert refcount != 1, "can't fork free'd block"
             assert refcount != 1, "can't fork free'd block"
 
 
@@ -129,6 +137,9 @@ class NaiveBlockAllocator(BlockAllocator):
     def get_num_free_blocks(self) -> int:
     def get_num_free_blocks(self) -> int:
         return len(self._free_block_indices)
         return len(self._free_block_indices)
 
 
+    def get_num_total_blocks(self) -> int:
+        return len(self._all_block_indices)
+
     def _allocate_new_block_id(self) -> BlockId:
     def _allocate_new_block_id(self) -> BlockId:
         if not self._free_block_indices:
         if not self._free_block_indices:
             raise BlockAllocator.NoFreeBlocksError()
             raise BlockAllocator.NoFreeBlocksError()
@@ -148,7 +159,7 @@ class NaiveBlockAllocator(BlockAllocator):
         return self._refcounter
         return self._refcounter
 
 
     @property
     @property
-    def all_block_ids(self):
+    def all_block_ids(self) -> FrozenSet[int]:
         return self._all_block_indices
         return self._all_block_indices
 
 
     def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
     def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
@@ -177,6 +188,7 @@ class NaiveBlockAllocator(BlockAllocator):
     def mark_blocks_as_accessed(self, block_ids: List[int],
     def mark_blocks_as_accessed(self, block_ids: List[int],
                                 now: float) -> None:
                                 now: float) -> None:
         """Mark blocks as accessed, used in prefix caching.
         """Mark blocks as accessed, used in prefix caching.
+
         Since the naive allocator does not implement prefix caching, we do
         Since the naive allocator does not implement prefix caching, we do
         nothing.
         nothing.
         """
         """
@@ -199,6 +211,9 @@ class NaiveBlockAllocator(BlockAllocator):
         """
         """
         return []
         return []
 
 
+    def promote_to_immutable_block(self, block: Block) -> BlockId:
+        raise NotImplementedError
+
 
 
 class NaiveBlock(Block):
 class NaiveBlock(Block):
     """An implementation of the Block class that does not support prefix
     """An implementation of the Block class that does not support prefix
@@ -223,13 +238,13 @@ class NaiveBlock(Block):
     """
     """
 
 
     def __init__(self,
     def __init__(self,
-                 prev_block: Block,
+                 prev_block: Optional[Block],
                  token_ids: List[int],
                  token_ids: List[int],
                  block_size: int,
                  block_size: int,
                  allocator: BlockAllocator,
                  allocator: BlockAllocator,
                  block_id: Optional[int] = None,
                  block_id: Optional[int] = None,
                  _cow_target: Optional[Block] = None):
                  _cow_target: Optional[Block] = None):
-        self._token_ids = []
+        self._token_ids: List[int] = []
         self._block_size = block_size
         self._block_size = block_size
         self._prev_block = prev_block
         self._prev_block = prev_block
         self._block_id = block_id
         self._block_id = block_id
@@ -255,6 +270,22 @@ class NaiveBlock(Block):
         assert self.num_empty_slots >= len(token_ids)
         assert self.num_empty_slots >= len(token_ids)
         self._token_ids.extend(token_ids)
         self._token_ids.extend(token_ids)
 
 
+    @property
+    def computed(self) -> bool:
+        raise NotImplementedError
+
+    @computed.setter
+    def computed(self, value) -> None:
+        raise NotImplementedError
+
+    @property
+    def last_accessed(self) -> float:
+        raise NotImplementedError
+
+    @last_accessed.setter
+    def last_accessed(self, last_accessed_ts: float):
+        raise NotImplementedError
+
     @property
     @property
     def block_id(self) -> Optional[int]:
     def block_id(self) -> Optional[int]:
         return self._block_id
         return self._block_id
@@ -275,9 +306,14 @@ class NaiveBlock(Block):
     def token_ids(self) -> List[int]:
     def token_ids(self) -> List[int]:
         return self._token_ids
         return self._token_ids
 
 
+    @property
     def block_size(self) -> int:
     def block_size(self) -> int:
         return self._block_size
         return self._block_size
 
 
     @property
     @property
     def prev_block(self) -> Optional["Block"]:
     def prev_block(self) -> Optional["Block"]:
         return self._prev_block
         return self._prev_block
+
+    @property
+    def content_hash(self) -> Optional[int]:
+        return None

+ 76 - 31
aphrodite/processing/block/prefix_caching_block.py

@@ -1,18 +1,18 @@
 """Token blocks."""
 """Token blocks."""
 from itertools import takewhile
 from itertools import takewhile
 from os.path import commonprefix
 from os.path import commonprefix
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, FrozenSet, Iterable, List, Optional
 
 
 from aphrodite.processing.block.common import (CopyOnWriteTracker,
 from aphrodite.processing.block.common import (CopyOnWriteTracker,
                                                get_all_blocks_recursively)
                                                get_all_blocks_recursively)
-from aphrodite.processing.block.interfaces import Block, BlockAllocator
+from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
+                                                   BlockId, Device)
 from aphrodite.processing.block.naive_block import (NaiveBlock,
 from aphrodite.processing.block.naive_block import (NaiveBlock,
                                                     NaiveBlockAllocator)
                                                     NaiveBlockAllocator)
 from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
 from aphrodite.processing.evictor_v2 import (EvictionPolicy, Evictor,
                                              make_evictor)
                                              make_evictor)
 
 
 PrefixHash = int
 PrefixHash = int
-BlockId = int
 
 
 # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
 # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
 # so that if we find one block still holds _DEFAULT_LAST_ACCESSED_TIME,
 # so that if we find one block still holds _DEFAULT_LAST_ACCESSED_TIME,
@@ -40,7 +40,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         num_blocks: int,
         num_blocks: int,
         block_size: int,
         block_size: int,
         block_ids: Optional[Iterable[int]] = None,
         block_ids: Optional[Iterable[int]] = None,
-        eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
+        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
     ):
     ):
         # A mapping of prefix hash to block index. All blocks which have a
         # A mapping of prefix hash to block index. All blocks which have a
         # prefix hash will be in this dict, even if they have refcount 0.
         # prefix hash will be in this dict, even if they have refcount 0.
@@ -51,7 +51,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
 
         # An allocator for blocks that do not have prefix hashes.
         # An allocator for blocks that do not have prefix hashes.
         self._hashless_allocator = NaiveBlockAllocator(
         self._hashless_allocator = NaiveBlockAllocator(
-            create_block=self._create_block,
+            create_block=self._create_block,  # type: ignore
             num_blocks=num_blocks,
             num_blocks=num_blocks,
             block_size=block_size,
             block_size=block_size,
             block_ids=block_ids,
             block_ids=block_ids,
@@ -81,7 +81,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         block_size: int,
         block_size: int,
         allocator: BlockAllocator,
         allocator: BlockAllocator,
         block_id: Optional[int] = None,
         block_id: Optional[int] = None,
-        computed: Optional[bool] = False,
+        computed: bool = False,
     ) -> Block:
     ) -> Block:
         # Bind block to self.
         # Bind block to self.
         allocator = self
         allocator = self
@@ -95,8 +95,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             computed=computed,
             computed=computed,
         )
         )
 
 
-    def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+    def allocate_immutable(self,
+                           prev_block: Optional[Block],
+                           token_ids: List[int],
+                           device: Optional[Device] = None) -> Block:
         """Allocates an immutable block with the given token IDs, reusing cached
         """Allocates an immutable block with the given token IDs, reusing cached
         blocks if possible.
         blocks if possible.
 
 
@@ -107,6 +109,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         Returns:
         Returns:
             Block: The allocated immutable block.
             Block: The allocated immutable block.
         """
         """
+        assert device is None
         assert_prefix_caching_block_or_none(prev_block)
         assert_prefix_caching_block_or_none(prev_block)
 
 
         block = self._create_block(
         block = self._create_block(
@@ -129,16 +132,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
 
         return block
         return block
 
 
-    def allocate_mutable(self, prev_block: Block) -> Block:
+    def allocate_mutable(self,
+                         prev_block: Optional[Block],
+                         device: Optional[Device] = None) -> Block:
         """Allocates a mutable block. If there are no free blocks, this will
         """Allocates a mutable block. If there are no free blocks, this will
         evict unused cached blocks.
         evict unused cached blocks.
 
 
         Args:
         Args:
             prev_block (Block): The previous block in the sequence.
             prev_block (Block): The previous block in the sequence.
+                None is not allowed unlike it is super class.
 
 
         Returns:
         Returns:
             Block: The allocated mutable block.
             Block: The allocated mutable block.
         """
         """
+        assert device is None
         assert_prefix_caching_block_or_none(prev_block)
         assert_prefix_caching_block_or_none(prev_block)
 
 
         try:
         try:
@@ -146,6 +153,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
                 prev_block=prev_block)
                 prev_block=prev_block)
 
 
             assert block.block_id not in self._blocks
             assert block.block_id not in self._blocks
+            assert block.block_id is not None
             self._blocks[block.block_id] = block
             self._blocks[block.block_id] = block
             return block
             return block
         except BlockAllocator.NoFreeBlocksError:
         except BlockAllocator.NoFreeBlocksError:
@@ -185,6 +193,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             assert block.content_hash is None
             assert block.content_hash is None
 
 
             assert block.block_id not in self._blocks
             assert block.block_id not in self._blocks
+            assert block.block_id is not None
             self._blocks[block.block_id] = block
             self._blocks[block.block_id] = block
             return block
             return block
 
 
@@ -215,6 +224,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
                 is not None), "freeing unallocated block is undefined"
                 is not None), "freeing unallocated block is undefined"
 
 
         self._free_block_id_for_block(block.block_id, block)
         self._free_block_id_for_block(block.block_id, block)
+
         block.block_id = None
         block.block_id = None
 
 
     def _free_block_id_for_block(self, block_id: BlockId,
     def _free_block_id_for_block(self, block_id: BlockId,
@@ -226,6 +236,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             # We have fork case where block would get more than one ref,
             # We have fork case where block would get more than one ref,
             # so we cannot free it from tracking if ref cnt large than 1
             # so we cannot free it from tracking if ref cnt large than 1
             if refcount <= 1:
             if refcount <= 1:
+                assert block.block_id is not None
                 del self._blocks[block.block_id]
                 del self._blocks[block.block_id]
             return self._hashless_allocator.free(block)
             return self._hashless_allocator.free(block)
 
 
@@ -234,6 +245,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         # If no longer used, add the block to the evictor.
         # If no longer used, add the block to the evictor.
         if refcount == 0:
         if refcount == 0:
             assert block.content_hash in self._cached_blocks
             assert block.content_hash in self._cached_blocks
+            assert block.block_id is not None
             del self._blocks[block.block_id]
             del self._blocks[block.block_id]
             self.evictor.add(block.block_id, block.content_hash,
             self.evictor.add(block.block_id, block.content_hash,
                              block.num_tokens_total, block.last_accessed)
                              block.num_tokens_total, block.last_accessed)
@@ -269,18 +281,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
 
         return forked_blocks
         return forked_blocks
 
 
-    def get_num_free_blocks(self) -> int:
+    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+        assert device is None
         # The number of free blocks is the number of hashless free blocks
         # The number of free blocks is the number of hashless free blocks
         # plus the number of blocks evictor could free from its list.
         # plus the number of blocks evictor could free from its list.
         return self._hashless_allocator.get_num_free_blocks(
         return self._hashless_allocator.get_num_free_blocks(
         ) + self.evictor.num_blocks
         ) + self.evictor.num_blocks
 
 
+    def get_num_total_blocks(self) -> int:
+        return self._hashless_allocator.get_num_total_blocks()
+
     @property
     @property
-    def all_block_ids(self) -> frozenset[int]:
+    def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids
         return self._hashless_allocator.all_block_ids
 
 
-    def promote_to_immutable_block(self,
-                                   block: "PrefixCachingBlock") -> BlockId:
+    def promote_to_immutable_block(self, block: Block) -> BlockId:
         """Once a mutable block is full, it can be promoted to an immutable
         """Once a mutable block is full, it can be promoted to an immutable
         block. This means that its content can be referenced by future blocks
         block. This means that its content can be referenced by future blocks
         having the same prefix.
         having the same prefix.
@@ -290,7 +305,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         block.
         block.
 
 
         Args:
         Args:
-            block (PrefixCachingBlock): The mutable block to be promoted.
+            block: The mutable block to be promoted.
 
 
         Returns:
         Returns:
             BlockId: Either the original block index, or the block index of
             BlockId: Either the original block index, or the block index of
@@ -337,6 +352,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def mark_blocks_as_accessed(self, block_ids: List[int],
     def mark_blocks_as_accessed(self, block_ids: List[int],
                                 now: float) -> None:
                                 now: float) -> None:
         """Mark blocks as accessed, used in prefix caching.
         """Mark blocks as accessed, used in prefix caching.
+
         If the block is added into evictor, we need to update corresponding
         If the block is added into evictor, we need to update corresponding
         info in evictor's metadata.
         info in evictor's metadata.
         """
         """
@@ -371,6 +387,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def get_common_computed_block_ids(
     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
             self, seq_block_ids: List[List[int]]) -> List[int]:
         """Return the block ids that are common for a given sequence group.
         """Return the block ids that are common for a given sequence group.
+
         Only those blocks that are immutable and already be marked
         Only those blocks that are immutable and already be marked
         compyted would be taken consideration.
         compyted would be taken consideration.
         """
         """
@@ -384,8 +401,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
                 takewhile(lambda block_id: self.block_is_computed(block_id),
                 takewhile(lambda block_id: self.block_is_computed(block_id),
                           seq[:-1])) for seq in seq_block_ids
                           seq[:-1])) for seq in seq_block_ids
         ]
         ]
-        res = commonprefix([ids for ids in ids_list if ids != []])
-        return res
+        # It returns a list of int although type annotation says list of string.
+        return commonprefix([
+            ids for ids in ids_list  # type: ignore
+            if ids != []
+        ])
 
 
 
 
 class PrefixCachingBlock(Block):
 class PrefixCachingBlock(Block):
@@ -402,27 +422,33 @@ class PrefixCachingBlock(Block):
         token_ids (List[int]): The initial token IDs to be stored in the block.
         token_ids (List[int]): The initial token IDs to be stored in the block.
         block_size (int): The maximum number of token IDs that can be stored in
         block_size (int): The maximum number of token IDs that can be stored in
             the block.
             the block.
-        prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
+        prefix_caching_allocator (BlockAllocator): The prefix
             caching block allocator associated with this block.
             caching block allocator associated with this block.
         block_id (Optional[int], optional): The physical block index
         block_id (Optional[int], optional): The physical block index
             of this block. Defaults to None.
             of this block. Defaults to None.
     """
     """
 
 
-    def __init__(self,
-                 prev_block: Optional["PrefixCachingBlock"],
-                 token_ids: List[int],
-                 block_size: int,
-                 prefix_caching_allocator: PrefixCachingBlockAllocator,
-                 block_id: Optional[int] = None,
-                 computed: Optional[bool] = False):
+    def __init__(
+        self,
+        prev_block: Optional[Block],
+        token_ids: List[int],
+        block_size: int,
+        prefix_caching_allocator: BlockAllocator,
+        block_id: Optional[int] = None,
+        computed: bool = False,
+    ):
+        assert isinstance(prefix_caching_allocator,
+                          PrefixCachingBlockAllocator), (
+                              "Currently this class is only tested with "
+                              "PrefixCachingBlockAllocator.")
         assert_prefix_caching_block_or_none(prev_block)
         assert_prefix_caching_block_or_none(prev_block)
 
 
         self._prev_block = prev_block
         self._prev_block = prev_block
         self._cached_content_hash: Optional[int] = None
         self._cached_content_hash: Optional[int] = None
         self._cached_num_tokens_total: Optional[int] = None
         self._cached_num_tokens_total: Optional[int] = None
         self._prefix_caching_allocator = prefix_caching_allocator
         self._prefix_caching_allocator = prefix_caching_allocator
-        self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
-        self.computed = computed
+        self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
+        self._computed = computed
 
 
         self._block = NaiveBlock(
         self._block = NaiveBlock(
             prev_block=prev_block,
             prev_block=prev_block,
@@ -433,6 +459,22 @@ class PrefixCachingBlock(Block):
             _cow_target=self,
             _cow_target=self,
         )
         )
 
 
+    @property
+    def computed(self) -> bool:
+        return self._computed
+
+    @computed.setter
+    def computed(self, value) -> None:
+        self._computed = value
+
+    @property
+    def last_accessed(self) -> float:
+        return self._last_accessed
+
+    @last_accessed.setter
+    def last_accessed(self, last_accessed_ts: float):
+        self._last_accessed = last_accessed_ts
+
     def append_token_ids(self, token_ids: List[int]) -> None:
     def append_token_ids(self, token_ids: List[int]) -> None:
         """Appends the given token IDs to the block and registers the block as
         """Appends the given token IDs to the block and registers the block as
         immutable if the block becomes full.
         immutable if the block becomes full.
@@ -473,17 +515,18 @@ class PrefixCachingBlock(Block):
     @property
     @property
     def num_tokens_total(self) -> int:
     def num_tokens_total(self) -> int:
         """return the total tokens so far.
         """return the total tokens so far.
+
         Here we iterate the block chain till to the first block, while
         Here we iterate the block chain till to the first block, while
         cache the result in local to prevent repeated computations.
         cache the result in local to prevent repeated computations.
         """
         """
         if self._cached_num_tokens_total is not None:
         if self._cached_num_tokens_total is not None:
             return self._cached_num_tokens_total
             return self._cached_num_tokens_total
 
 
-        _block = self
+        _block: Optional[Block] = self
         self._cached_num_tokens_total = 0
         self._cached_num_tokens_total = 0
 
 
-        # TODO: current implementation here is O(N^2), we expect future
-        # ones to be O(1)
+        # TODO: current implement here take O(N^2), we expect future
+        # we have O(1) here
         while _block is not None:
         while _block is not None:
             self._cached_num_tokens_total += len(_block.token_ids)
             self._cached_num_tokens_total += len(_block.token_ids)
             _block = _block.prev_block
             _block = _block.prev_block
@@ -520,8 +563,10 @@ class PrefixCachingBlock(Block):
             return None
             return None
 
 
         is_first_block = self._prev_block is None
         is_first_block = self._prev_block is None
-        prev_block_hash = (None if is_first_block else
-                           self._prev_block.content_hash)
+        prev_block_hash = (
+            None if is_first_block else
+            self._prev_block.content_hash  # type: ignore
+        )
 
 
         # Previous block exists but does not yet have a hash.
         # Previous block exists but does not yet have a hash.
         # Return no hash in this case.
         # Return no hash in this case.

+ 23 - 7
aphrodite/processing/block_manager_v1.py

@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
 """A block manager that manages token blocks."""
+import math
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from itertools import count, takewhile
 from itertools import count, takewhile
 from os.path import commonprefix
 from os.path import commonprefix
@@ -46,6 +47,10 @@ class BlockAllocatorBase(ABC):
     def get_num_free_blocks(self) -> int:
     def get_num_free_blocks(self) -> int:
         pass
         pass
 
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     @abstractmethod
     def contains_block(self, block_hash: int) -> bool:
     def contains_block(self, block_hash: int) -> bool:
         pass
         pass
@@ -130,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
         return (self.num_blocks - self.current_num_blocks +
         return (self.num_blocks - self.current_num_blocks +
                 self.evictor.num_blocks)
                 self.evictor.num_blocks)
 
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
     def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor
         return block_hash in self.cached_blocks or block_hash in self.evictor
 
 
@@ -189,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
     def get_num_free_blocks(self) -> int:
     def get_num_free_blocks(self) -> int:
         return len(self.free_blocks)
         return len(self.free_blocks)
 
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
     def contains_block(self, block_hash: int) -> bool:
         raise NotImplementedError(
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
             "Invalid codepath for uncached block allocator.")
@@ -220,9 +231,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
 
         self.block_sliding_window = None
         self.block_sliding_window = None
         if sliding_window is not None:
         if sliding_window is not None:
-            assert sliding_window % block_size == 0, (sliding_window,
-                                                      block_size)
-            self.block_sliding_window = sliding_window // block_size
+            # Round up to nearest block size to regularize sliding window
+            # allocation sizes.
+            self.block_sliding_window = math.ceil(sliding_window / block_size)
 
 
         self.watermark = watermark
         self.watermark = watermark
         assert watermark >= 0.0
         assert watermark >= 0.0
@@ -246,7 +257,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         self.block_tables: Dict[int, BlockTable] = {}
         self.block_tables: Dict[int, BlockTable] = {}
 
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
-        # FIXME: Here we assume that all sequences in the group share
+        # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
         num_required_blocks = len(seq.logical_token_blocks)
@@ -390,7 +401,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                 block_table.append(block_table[len(block_table) %
                 block_table.append(block_table[len(block_table) %
                                                self.block_sliding_window])
                                                self.block_sliding_window])
             else:
             else:
-                # The sequence has a new logical block.
+                # The sequence hash a new logical block.
                 # Allocate a new physical block.
                 # Allocate a new physical block.
                 new_block = self._allocate_last_physical_block(seq)
                 new_block = self._allocate_last_physical_block(seq)
                 block_table.append(new_block)
                 block_table.append(new_block)
@@ -443,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
 
     def can_swap_in(self,
     def can_swap_in(self,
                     seq_group: SequenceGroup,
                     seq_group: SequenceGroup,
-                    num_lookahead_slots: int = 0) -> bool:
+                    num_lookahead_slots: int = 0) -> AllocStatus:
         assert (num_lookahead_slots == 0
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
         blocks = self._get_physical_blocks(seq_group)
         blocks = self._get_physical_blocks(seq_group)
@@ -453,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # at least one free block right after the swap-in.
         # at least one free block right after the swap-in.
         # NOTE: This should match the logic in can_append_slot().
         # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
         num_required_blocks = len(blocks) + num_swapped_seqs
-        return num_free_blocks - num_required_blocks >= self.watermark_blocks
+        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
+            return AllocStatus.NEVER
+        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER
 
 
     def swap_in(self,
     def swap_in(self,
                 seq_group: SequenceGroup,
                 seq_group: SequenceGroup,

+ 8 - 5
aphrodite/processing/block_manager_v2.py

@@ -187,7 +187,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         assert seq.seq_id in self.block_tables
         assert seq.seq_id in self.block_tables
         block_ids = self.block_tables[seq.seq_id].physical_block_ids
         block_ids = self.block_tables[seq.seq_id].physical_block_ids
         assert all(b is not None for b in block_ids)
         assert all(b is not None for b in block_ids)
-        return block_ids
+        return block_ids  # type: ignore
 
 
     def access_all_blocks_in_seq(self, seq: Sequence, now: float):
     def access_all_blocks_in_seq(self, seq: Sequence, now: float):
         # Update the last accessed time of all the blocks accessed
         # Update the last accessed time of all the blocks accessed
@@ -201,7 +201,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             block_ids = []
             block_ids = []
             for block_id in block_table.physical_block_ids:
             for block_id in block_table.physical_block_ids:
                 block_ids.append(block_id)
                 block_ids.append(block_id)
-            self.block_allocator.mark_blocks_as_accessed(block_ids, now)
+            self.block_allocator.mark_blocks_as_accessed(
+                block_ids,  # type: ignore
+                now)
 
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         # The only need for mark block as computed is for prefix caching,
         # The only need for mark block as computed is for prefix caching,
@@ -224,16 +226,17 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         seq_block_ids = [
         seq_block_ids = [
             self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
             self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
         ]
         ]
+        # NOTE: This assumes seq_block_ids doesn't contain any None.
         return self.block_allocator.get_common_computed_block_ids(
         return self.block_allocator.get_common_computed_block_ids(
-            seq_block_ids)
+            seq_block_ids)  # type: ignore
 
 
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         src_block_table = self.block_tables[parent_seq.seq_id]
         src_block_table = self.block_tables[parent_seq.seq_id]
         self.block_tables[child_seq.seq_id] = src_block_table.fork()
         self.block_tables[child_seq.seq_id] = src_block_table.fork()
 
 
     def can_swap_in(self, seq_group: SequenceGroup,
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
-        return False
+                    num_lookahead_slots: int) -> AllocStatus:
+        return AllocStatus.LATER
 
 
     def swap_in(self, seq_group: SequenceGroup,
     def swap_in(self, seq_group: SequenceGroup,
                 num_lookahead_slots: int) -> Dict[int, int]:
                 num_lookahead_slots: int) -> Dict[int, int]:

+ 11 - 5
aphrodite/processing/evictor_v2.py

@@ -32,15 +32,20 @@ class Evictor(ABC):
 
 
     @abstractmethod
     @abstractmethod
     def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
     def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
-            last_accessed: int):
+            last_accessed: float):
         """Adds block to the evictor, making it a candidate for eviction"""
         """Adds block to the evictor, making it a candidate for eviction"""
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def update(self, block_id: int, last_accessed: int):
+    def update(self, block_id: int, last_accessed: float):
         """Update corresponding block's access time in metadata"""
         """Update corresponding block's access time in metadata"""
         pass
         pass
 
 
+    @abstractmethod
+    def remove(self, block_id: int):
+        """Remove a given block id from the cache."""
+        pass
+
     @abstractproperty
     @abstractproperty
     def num_blocks(self) -> int:
     def num_blocks(self) -> int:
         pass
         pass
@@ -49,12 +54,13 @@ class Evictor(ABC):
 class BlockMetaData():
 class BlockMetaData():
     """Data structure for storing key data describe cached block, so that
     """Data structure for storing key data describe cached block, so that
     evitor could use to make its decision which one to choose for eviction
     evitor could use to make its decision which one to choose for eviction
+
     Here we use physical block id as the dict key, as there maybe several
     Here we use physical block id as the dict key, as there maybe several
     blocks with the same content hash, but their physical id is unique.
     blocks with the same content hash, but their physical id is unique.
     """
     """
 
 
     def __init__(self, content_hash: int, num_hashed_tokens: int,
     def __init__(self, content_hash: int, num_hashed_tokens: int,
-                 last_accessed: int):
+                 last_accessed: float):
         self.content_hash = content_hash
         self.content_hash = content_hash
         self.num_hashed_tokens = num_hashed_tokens
         self.num_hashed_tokens = num_hashed_tokens
         self.last_accessed = last_accessed
         self.last_accessed = last_accessed
@@ -95,12 +101,12 @@ class LRUEvictor(Evictor):
         return evicted_block_id, evicted_block.content_hash
         return evicted_block_id, evicted_block.content_hash
 
 
     def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
     def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
-            last_accessed: int):
+            last_accessed: float):
         self.free_table[block_id] = BlockMetaData(content_hash,
         self.free_table[block_id] = BlockMetaData(content_hash,
                                                   num_hashed_tokens,
                                                   num_hashed_tokens,
                                                   last_accessed)
                                                   last_accessed)
 
 
-    def update(self, block_id: int, last_accessed: int):
+    def update(self, block_id: int, last_accessed: float):
         self.free_table[block_id].last_accessed = last_accessed
         self.free_table[block_id].last_accessed = last_accessed
 
 
     def remove(self, block_id: int):
     def remove(self, block_id: int):

+ 8 - 4
aphrodite/processing/interfaces.py

@@ -1,6 +1,7 @@
 import enum
 import enum
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from typing import Dict, List
 from typing import Dict, List
+from typing import Sequence as GenericSequence
 
 
 from aphrodite.common.sequence import Sequence, SequenceGroup
 from aphrodite.common.sequence import Sequence, SequenceGroup
 
 
@@ -26,11 +27,13 @@ class BlockSpaceManager(ABC):
         version = version.lower()
         version = version.lower()
 
 
         if version == "v1":
         if version == "v1":
-            from aphrodite.processing.block_manager_v1 import BlockSpaceManagerV1  # noqa: E501
+            from aphrodite.processing.block_manager_v1 import \
+                BlockSpaceManagerV1
             return BlockSpaceManagerV1
             return BlockSpaceManagerV1
 
 
         if version == "v2":
         if version == "v2":
-            from aphrodite.processing.block_manager_v2 import BlockSpaceManagerV2  # noqa: E501
+            from aphrodite.processing.block_manager_v2 import \
+                BlockSpaceManagerV2
             return BlockSpaceManagerV2
             return BlockSpaceManagerV2
 
 
         raise ValueError(f"Unknown version {version=}")
         raise ValueError(f"Unknown version {version=}")
@@ -62,7 +65,7 @@ class BlockSpaceManager(ABC):
 
 
     @abstractmethod
     @abstractmethod
     def can_swap_in(self, seq_group: SequenceGroup,
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
+                    num_lookahead_slots: int) -> AllocStatus:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
@@ -103,7 +106,8 @@ class BlockSpaceManager(ABC):
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod

+ 30 - 20
aphrodite/processing/scheduler.py

@@ -209,6 +209,8 @@ class SchedulerSwappedInOutputs:
     blocks_to_copy: Dict[int, List[int]]
     blocks_to_copy: Dict[int, List[int]]
     # The number of slots for lookahead decoding.
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
     num_lookahead_slots: int
+    # Infeasible sequence groups.
+    infeasible_seq_groups: List[SequenceGroup]
 
 
     @classmethod
     @classmethod
     def create_empty(cls) -> "SchedulerSwappedInOutputs":
     def create_empty(cls) -> "SchedulerSwappedInOutputs":
@@ -218,6 +220,7 @@ class SchedulerSwappedInOutputs:
             blocks_to_swap_in={},
             blocks_to_swap_in={},
             blocks_to_copy={},
             blocks_to_copy={},
             num_lookahead_slots=0,
             num_lookahead_slots=0,
+            infeasible_seq_groups=[],
         )
         )
 
 
 
 
@@ -335,7 +338,7 @@ class Scheduler:
             for seq_group in state_queue:
             for seq_group in state_queue:
                 if not request_ids:
                 if not request_ids:
                     # Using 'break' here may add two extra iterations,
                     # Using 'break' here may add two extra iterations,
-                    # but is acceptable to reduce complexity .
+                    # but is acceptable to reduce complexity.
                     break
                     break
                 if seq_group.request_id in request_ids:
                 if seq_group.request_id in request_ids:
                     # Appending aborted group into pending list.
                     # Appending aborted group into pending list.
@@ -395,13 +398,12 @@ class Scheduler:
         preempted: List[SequenceGroup] = []
         preempted: List[SequenceGroup] = []
         swapped_out: List[SequenceGroup] = []
         swapped_out: List[SequenceGroup] = []
 
 
-        # NOTE: Preemption happens only when there is no available slot
+        # NOTE(woosuk): Preemption happens only when there is no available slot
         # to keep all the sequence groups in the RUNNING state.
         # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         # groups to preempt.
         now = time.time()
         now = time.time()
         running_queue = policy.sort_by_priority(now, running_queue)
         running_queue = policy.sort_by_priority(now, running_queue)
-
         while running_queue:
         while running_queue:
             seq_group = running_queue[0]
             seq_group = running_queue[0]
             num_running_tokens = self._get_num_new_tokens(
             num_running_tokens = self._get_num_new_tokens(
@@ -511,14 +513,26 @@ class Scheduler:
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         now = time.time()
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
+        infeasible_seq_groups: List[SequenceGroup] = []
 
 
         leftover_swapped: Deque[SequenceGroup] = deque()
         leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
         while swapped_queue:
             seq_group = swapped_queue[0]
             seq_group = swapped_queue[0]
 
 
             # If the sequence group cannot be swapped in, stop.
             # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
+            alloc_status = self.block_manager.can_swap_in(seq_group)
+            if alloc_status == AllocStatus.LATER:
                 break
                 break
+            elif alloc_status == AllocStatus.NEVER:
+                logger.warning(
+                    "Failing the request %s because there's not enough kv "
+                    "cache blocks to run the entire sequence.",
+                    seq_group.request_id)
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                infeasible_seq_groups.append(seq_group)
+                swapped_queue.popleft()
+                continue
 
 
             lora_int_id = 0
             lora_int_id = 0
             if self.lora_enabled:
             if self.lora_enabled:
@@ -569,7 +583,9 @@ class Scheduler:
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_copy=blocks_to_copy,
             blocks_to_copy=blocks_to_copy,
             num_lookahead_slots=self._get_num_lookahead_slots(
             num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=False))
+                is_prefill=False),
+            infeasible_seq_groups=infeasible_seq_groups,
+        )
 
 
     def _schedule_prefills(
     def _schedule_prefills(
         self,
         self,
@@ -777,7 +793,8 @@ class Scheduler:
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
             blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
             blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                        swapped_in.blocks_to_copy),
                                        swapped_in.blocks_to_copy),
-            ignored_seq_groups=prefills.ignored_seq_groups,
+            ignored_seq_groups=prefills.ignored_seq_groups +
+            swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
         )
         )
 
 
@@ -877,12 +894,14 @@ class Scheduler:
     def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
     def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
         """Determine whether or not we have enough space in the KV cache to
         """Determine whether or not we have enough space in the KV cache to
         continue generation of the sequence group.
         continue generation of the sequence group.
-        """# It is True only for testing case to trigger artificial preemption.
+        """
+        # It is True only for testing case to trigger artificial preemption.
         if (self.enable_artificial_preemption
         if (self.enable_artificial_preemption
                 and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                 and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                 and self.artificial_preempt_cnt > 0):
                 and self.artificial_preempt_cnt > 0):
             self.artificial_preempt_cnt -= 1
             self.artificial_preempt_cnt -= 1
             return False
             return False
+
         # Appending slots only occurs in decoding.
         # Appending slots only occurs in decoding.
         is_prefill = False
         is_prefill = False
 
 
@@ -891,15 +910,6 @@ class Scheduler:
             num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
             num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
         )
         )
 
 
-    def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
-        # Swapping in is considered decode.
-        is_prefill = False
-
-        return self.block_manager.can_swap_in(
-            seq_group=seq_group,
-            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
-        )
-
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # This function call changes the internal states of the scheduler
@@ -970,7 +980,7 @@ class Scheduler:
         # Now that the batch has been created, we can assume all blocks in the
         # Now that the batch has been created, we can assume all blocks in the
         # batch will have been computed before the next scheduling invocation.
         # batch will have been computed before the next scheduling invocation.
         # This is because the engine assumes that a failure in model execution
         # This is because the engine assumes that a failure in model execution
-        # will crash the Aphrodite instance / will not retry.
+        # will crash the vLLM instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
             self.block_manager.mark_blocks_as_computed(
                 scheduled_seq_group.seq_group)
                 scheduled_seq_group.seq_group)
@@ -1029,11 +1039,11 @@ class Scheduler:
         # swapping. However, when the sequence group has multiple sequences
         # swapping. However, when the sequence group has multiple sequences
         # (e.g., beam search), recomputation is not currently supported. In
         # (e.g., beam search), recomputation is not currently supported. In
         # such a case, we use swapping instead.
         # such a case, we use swapping instead.
-        # FIXME: This makes our scheduling policy a bit bizarre.
+        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
         # As swapped sequences are prioritized over waiting sequences,
         # As swapped sequences are prioritized over waiting sequences,
         # sequence groups with multiple sequences are implicitly prioritized
         # sequence groups with multiple sequences are implicitly prioritized
         # over sequence groups with a single sequence.
         # over sequence groups with a single sequence.
-        # TODO: Support recomputation for sequence groups with multiple
+        # TODO(woosuk): Support recomputation for sequence groups with multiple
         # sequences. This may require a more sophisticated CUDA kernel.
         # sequences. This may require a more sophisticated CUDA kernel.
         if preemption_mode is None:
         if preemption_mode is None:
             if seq_group.get_max_num_running_seqs() == 1:
             if seq_group.get_max_num_running_seqs() == 1:
@@ -1082,7 +1092,7 @@ class Scheduler:
         blocks_to_swap_out: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
     ) -> None:
     ) -> None:
         if not self.block_manager.can_swap_out(seq_group):
         if not self.block_manager.can_swap_out(seq_group):
-            # FIXME: Abort the sequence group instead of aborting the
+            # FIXME(woosuk): Abort the sequence group instead of aborting the
             # entire engine.
             # entire engine.
             raise RuntimeError(
             raise RuntimeError(
                 "Aborted due to the lack of CPU swap space. Please increase "
                 "Aborted due to the lack of CPU swap space. Please increase "