123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- """A CPU worker class."""
- from typing import Dict, List, Optional, Tuple
- import torch
- import torch.distributed
- from loguru import logger
- import aphrodite.common.envs as envs
- from aphrodite.attention import get_attn_backend
- from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
- LoRAConfig, ModelConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig)
- from aphrodite.common.sequence import ExecuteModelRequest
- from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
- from aphrodite.distributed import (ensure_model_parallel_initialized,
- init_distributed_environment)
- from aphrodite.modeling import set_random_seed
- from aphrodite.worker.cpu_model_runner import CPUModelRunner
- from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
- LoraNotSupportedWorkerBase,
- WorkerInput)
# OpenMP thread-binding spec, read from the environment once at import time.
# CPUWorker treats the value "all" specially; any other value is split on "|"
# and indexed by the worker's rank.
APHRODITE_CPU_OMP_THREADS_BIND: str = envs.APHRODITE_CPU_OMP_THREADS_BIND
class CPUCacheEngine:
    """Manages the KV cache for the CPU backend.

    This class allocates one KV-cache tensor per model layer on the CPU and
    provides the cache operations the CPU worker needs. Block copies are
    delegated to the attention backend; swap_in/swap_out are not supported
    and raise NotImplementedError.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # NOTE: In CacheConfig, num_gpu_blocks actually holds the number of
        # CPU blocks for the CPU backend, so that the scheduler's KV cache
        # management can be reused unchanged.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        self.dtype = self._resolve_cache_dtype(cache_config.cache_dtype,
                                               model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    @staticmethod
    def _resolve_cache_dtype(cache_dtype: str,
                             model_config: ModelConfig) -> torch.dtype:
        """Map a cache dtype string to the torch dtype used for the KV cache.

        "auto" means "use the model's dtype". Any other value is looked up in
        STR_DTYPE_TO_TORCH_DTYPE, and an unknown value raises KeyError.
        Shared by __init__ and get_cache_block_size so the allocated element
        type and the sizing arithmetic always agree.
        """
        if cache_dtype == "auto":
            return model_config.dtype
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU.

        Args:
            num_blocks: Number of cache blocks per layer.

        Returns:
            One tensor per layer, shaped as the attention backend's
            get_kv_cache_shape dictates. Contents are uninitialized
            (torch.empty).
        """
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        return [
            torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Not supported on the CPU backend; always raises."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Not supported on the CPU backend; always raises."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy KV cache blocks within every layer's cache.

        Args:
            src_to_dsts: Block-copy mapping forwarded unchanged to the
                attention backend's copy_blocks. CPUWorker passes an int64
                CPU tensor viewed as shape (num_copies, 2), presumably
                (src, dst) block-id pairs.
        """
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        The size counts both keys and values:
        num_layers * 2 * block_size * num_kv_heads * head_size elements,
        times the byte width of the resolved cache dtype.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)

        dtype = CPUCacheEngine._resolve_cache_dtype(cache_dtype, model_config)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total
class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        """Store configs, resolve this rank's OpenMP binding, build the runner.

        Raises:
            ValueError: If APHRODITE_CPU_OMP_THREADS_BIND is not "all" and
                has no "|"-separated entry for ``rank``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from aphrodite.common.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity. The spec is either "all" or a
        # "|"-separated list holding one entry per rank.
        omp_cpuids = APHRODITE_CPU_OMP_THREADS_BIND
        if omp_cpuids == "all":
            self.local_omp_cpuid = "all"
        else:
            per_rank_cpuids = omp_cpuids.split("|")
            if not 0 <= rank < len(per_rank_cpuids):
                # Fail with an actionable message instead of a bare
                # "list index out of range".
                raise ValueError(
                    f"APHRODITE_CPU_OMP_THREADS_BIND={omp_cpuids!r} has "
                    f"{len(per_rank_cpuids)} '|'-separated entries, but an "
                    f"entry for rank {rank} is required.")
            self.local_omp_cpuid = per_rank_cpuids[rank]

        self.model_runner: CPUModelRunner = CPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
            prompt_adapter_config=self.prompt_adapter_config,
            is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # None until initialize_cache() runs, so the ``kv_cache`` property
        # returns None (as its Optional type promises) instead of raising
        # AttributeError.
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

    def init_device(self) -> None:
        """Set up OpenMP thread affinity, the distributed env, and the seed."""
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            logger.info(ret)

        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model weights via the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since Aphrodite assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of Aphrodite without generalizing
        it to different devices.

        Returns:
            (num_gpu_blocks, num_cpu_blocks). The first value is the number
            of CPU blocks that fit, and the second is always 0.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # NOTE: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.

        Raises:
            ValueError: If the block count is non-positive or too small to
                hold ``max_model_len`` tokens.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # NOTE: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.

        Raises:
            ValueError: If ``num_cpu_blocks`` is <= 0, or if the cache holds
                fewer tokens (block_size * num_cpu_blocks) than the model's
                max_model_len.
        """
        if num_cpu_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `APHRODITE_CPU_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`APHRODITE_CPU_KVCACHE_SPACE` or decreasing `max_model_len` "
                "when initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create one CPUCacheEngine per virtual engine and warm its memory.

        The number of virtual engines equals pipeline_parallel_size. The
        engines are indexed by ``WorkerInput.virtual_engine`` in
        execute_worker.
        """
        num_virtual_engines = self.parallel_config.pipeline_parallel_size
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(num_virtual_engines)
        ]
        self.cpu_cache = [engine.cpu_cache for engine in self.cache_engine]
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(cache is not None for cache in self.cpu_cache)

        # Populate the cache to warm up the memory: the tensors come from
        # torch.empty and are otherwise left uninitialized.
        for per_layer_caches in self.cpu_cache:
            for layer_cache in per_layer_caches:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        """Broadcast step metadata only when tensor parallelism is in use."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Per-virtual-engine lists of per-layer KV cache tensors.

        None until initialize_cache() has run.
        """
        return self.cpu_cache

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        """Apply this step's pending block copies to the selected cache."""
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build the WorkerInput for one step from ``execute_model_req``.

        Block copies are converted to an int64 CPU tensor viewed as shape
        (num_copies, 2). Swap lists must be empty, because CPUCacheEngine
        does not implement swapping.
        """
        assert execute_model_req is not None
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        # Resolves to the module-level helper imported from
        # aphrodite.distributed, not to this method.
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(
            self.cache_config.block_size, self.cache_config.cache_dtype,
            self.model_config, self.parallel_config)
|