123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- """Token blocks."""
- import weakref
- from collections import defaultdict
- from typing import Dict, List
- from aphrodite.common.utils import Device
# Filler value written into every slot of a freshly created token block.
_BLANK_TOKEN_ID = -1

# Initial `last_accessed` value assigned to a newly built physical block.
DEFAULT_LAST_ACCESSED_TIME = -1

# A token block is a plain list of token ids.
TokensBlock = List[int]


class BlockPool:
    """Recycling pool for the token-id lists that back logical blocks.

    Logical blocks are created in bulk when requests arrive and dropped in
    bulk when they finish. Building and discarding the `token_ids` list for
    each of them is costly, so lists handed back by finished requests are
    kept here, grouped by length, and given out again to later requests.
    """

    def __init__(self) -> None:
        # Free lists keyed by block size (the length of each stored list).
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> TokensBlock:
        """Return a token list of length `block_size`.

        A previously returned list of that size is reused when one is
        available. Reused lists are handed out as-is: their old contents
        are not reset. Otherwise a new list filled with `_BLANK_TOKEN_ID`
        is created.
        """
        # `get` does not invoke the defaultdict factory, so asking for a size
        # that was never returned does not add an empty entry to the pool.
        recycled = self.pool.get(block_size)
        if not recycled:
            return [_BLANK_TOKEN_ID] * block_size
        return recycled.pop()

    def del_block(self, block: TokensBlock) -> None:
        """Put `block` back into the pool, filed under its current length."""
        size = len(block)
        self.pool[size].append(block)


# Module-wide pool instance.
_BLOCK_POOL = BlockPool()
class LogicalTokenBlock:
    """Fixed-capacity buffer of token ids, filled from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.num_tokens = 0

        # Backing storage comes from the shared pool instead of a new list.
        storage = _BLOCK_POOL.alloc_block(block_size)
        self.token_ids = storage

        # Give the storage back to the pool once this object is collected.
        # NOTE: a finalizer is used rather than __del__, because __del__
        # gives no guarantee about finalization order: `self.token_ids`
        # could be gone before `self`, and the chance to return the list
        # to the pool would be lost.
        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
                                           storage)

    def is_empty(self) -> bool:
        """Return True when no token has been stored yet."""
        return self.num_tokens == 0

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into this block."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """Return True when the token count equals the block size."""
        return self.block_size == self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        """Store `token_ids` in the next free slots.

        Asserts that the tokens fit into the remaining free slots.
        """
        incoming = len(token_ids)
        assert incoming <= self.get_num_empty_slots()
        start = self.num_tokens
        stop = start + incoming
        # Same-length slice assignment keeps the list length unchanged.
        self.token_ids[start:stop] = token_ids
        self.num_tokens = stop

    def get_token_ids(self) -> List[int]:
        """Return a new list holding only the tokens stored so far."""
        filled = self.num_tokens
        return self.token_ids[:filled]

    def get_last_token_id(self) -> int:
        """Return the most recently stored token; asserts non-empty."""
        assert self.num_tokens > 0
        last_index = self.num_tokens - 1
        return self.token_ids[last_index]
class PhysicalTokenBlock:
    """Represents the state of a block in the KV cache.

    The constructor arguments are stored unchanged as attributes. Every new
    block additionally starts with `ref_count` 0, `last_accessed` set to
    `DEFAULT_LAST_ACCESSED_TIME`, and `computed` False.
    """

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        # Bookkeeping fields that always start from the same values.
        self.ref_count = 0
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
        self.computed = False

        # Values supplied by the caller.
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

    def __repr__(self) -> str:
        """Return a debug string; block_size and block_hash are omitted."""
        description = (f'PhysicalTokenBlock(device={self.device}, '
                       f'block_number={self.block_number}, '
                       f'num_hashed_tokens={self.num_hashed_tokens}, '
                       f'ref_count={self.ref_count}, '
                       f'last_accessed={self.last_accessed}, '
                       f'computed={self.computed})')
        return description


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
|