123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- """Token blocks."""
- from typing import List
- from aphrodite.common.utils import Device
- _BLANK_TOKEN_ID = -1  # Placeholder that LogicalTokenBlock writes into every token slot at construction.
- DEFAULT_LAST_ACCESSED_TIME = -1  # Initial `last_accessed` value of a freshly built PhysicalTokenBlock.
class LogicalTokenBlock:
    """Fixed-capacity container of token ids, filled from left to right.

    A logical block mirrors the state of its corresponding physical block
    in the KV cache. Slots that have not been written yet hold
    `_BLANK_TOKEN_ID`.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        """Create an empty block.

        Args:
            block_number: Number identifying this logical block.
            block_size: Capacity of the block, in tokens.
        """
        self.block_number = block_number
        self.block_size = block_size

        # Pre-size the storage so that appends only overwrite blank slots.
        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        # Count of leading slots that currently hold real tokens.
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """Return True when no token has been stored yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into this block."""
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        """Return True when every slot holds a token."""
        return self.get_num_empty_slots() == 0

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write `token_ids` into the next free slots.

        The caller must not pass more tokens than `get_num_empty_slots()`
        reports; this is enforced with an assertion.
        """
        count = len(token_ids)
        assert count <= self.get_num_empty_slots()

        start = self.num_tokens
        end = start + count
        self.token_ids[start:end] = token_ids
        self.num_tokens = end

    def get_token_ids(self) -> List[int]:
        """Return a new list with the stored tokens, excluding blank slots."""
        filled = self.num_tokens
        return self.token_ids[:filled]

    def get_last_token_id(self) -> int:
        """Return the most recently appended token id.

        The block must contain at least one token (asserted).
        """
        assert self.num_tokens > 0
        last_idx = self.num_tokens - 1
        return self.token_ids[last_idx]
class PhysicalTokenBlock:
    """State record for a single block of the KV cache."""

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        """Store the block's identity and reset its bookkeeping fields.

        Args:
            device: Device this block belongs to.
            block_number: Number identifying this physical block.
            block_size: Size of the block; stored as given.
            block_hash: Hash value associated with the block; stored as
                given.
            num_hashed_tokens: Stored as given. Presumably the number of
                tokens covered by `block_hash` -- confirm against callers.
        """
        # Identity supplied by the caller.
        self.device = device
        self.block_number = block_number
        self.block_size = block_size

        # Hash metadata supplied by the caller.
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

        # Bookkeeping: every new block starts with a zero reference count,
        # the default last-accessed time, and the computed flag cleared.
        self.ref_count = 0
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
        self.computed = False

    def __repr__(self) -> str:
        """Return a debug string of the block's fields.

        `block_size` and `block_hash` are not included in the output.
        """
        description = (f'PhysicalTokenBlock(device={self.device}, '
                       f'block_number={self.block_number}, '
                       f'num_hashed_tokens={self.num_hashed_tokens}, '
                       f'ref_count={self.ref_count}, '
                       f'last_accessed={self.last_accessed}, '
                       f'computed={self.computed})')
        return description
- # Mapping: logical block number (the list index) -> physical block.
- BlockTable = List[PhysicalTokenBlock]
|