"""Token blocks.""" from typing import List from aphrodite.common.utils import Device _BLANK_TOKEN_ID = -1 DEFAULT_LAST_ACCESSED_TIME = -1 class LogicalTokenBlock: """A block that stores a contiguous chunk of tokens from left to right. Logical blocks are used to represent the states of the corresponding physical blocks in the KV cache. """ def __init__( self, block_number: int, block_size: int, ) -> None: self.block_number = block_number self.block_size = block_size self.token_ids = [_BLANK_TOKEN_ID] * block_size self.num_tokens = 0 def is_empty(self) -> bool: return self.num_tokens == 0 def get_num_empty_slots(self) -> int: return self.block_size - self.num_tokens def is_full(self) -> bool: return self.num_tokens == self.block_size def append_tokens(self, token_ids: List[int]) -> None: assert len(token_ids) <= self.get_num_empty_slots() curr_idx = self.num_tokens self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids self.num_tokens += len(token_ids) def get_token_ids(self) -> List[int]: return self.token_ids[:self.num_tokens] def get_last_token_id(self) -> int: assert self.num_tokens > 0 return self.token_ids[self.num_tokens - 1] class PhysicalTokenBlock: """Represents the state of a block in the KV cache.""" def __init__( self, device: Device, block_number: int, block_size: int, block_hash: int, num_hashed_tokens: int, ) -> None: self.device = device self.block_number = block_number self.block_size = block_size self.block_hash = block_hash self.num_hashed_tokens = num_hashed_tokens self.ref_count = 0 self.last_accessed = DEFAULT_LAST_ACCESSED_TIME self.computed = False def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' f'num_hashed_tokens={self.num_hashed_tokens}, ' f'ref_count={self.ref_count}, ' f'last_accessed={self.last_accessed}, ' f'computed={self.computed})') # Mapping: logical block number -> physical block. BlockTable = List[PhysicalTokenBlock]