from typing import Optional

import torch
from transformers import PretrainedConfig

from aphrodite.common.logger import init_logger
from aphrodite.transformers_utils.config import get_config
from aphrodite.common.utils import get_cpu_memory

logger = init_logger(__name__)

_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
            Matched case-insensitively.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load (matched
            case-insensitively):
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit ID. If unspecified, will use the default
            version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.

    Raises:
        ValueError: If ``load_format``, ``tokenizer_mode`` or ``dtype`` is
            not recognized, or the requested dtype is unsupported by the GPU.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: str,
        seed: int,
        revision: Optional[str],
        max_model_len: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision

        self.hf_config = get_config(model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_load_format()
        self._verify_tokenizer_mode()

        # Must be None while calling get_max_model_len() below, so that it
        # derives the limit from the HF config instead of returning the
        # user-specified value.
        self.max_model_len: Optional[int] = None
        if max_model_len is not None:
            derived_max_model_len = self.get_max_model_len()
            if max_model_len > derived_max_model_len:
                # Only a warning: the user value is still honored below.
                logger.warning(
                    f"User-specified max_model_len ({max_model_len}) is "
                    "greater than the model's max length "
                    f"({derived_max_model_len}). Make sure the value is "
                    "correct and within the model's ctxlen.")
            self.max_model_len = max_model_len

    def _verify_load_format(self) -> None:
        """Validate ``load_format`` and normalize it to lower case."""
        load_format = self.load_format.lower()
        if load_format not in [
                "auto", "pt", "safetensors", "npcache", "dummy"
        ]:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Validate ``tokenizer_mode`` and normalize it to lower case."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be evenly split by ``parallel_config``.

        Raises:
            ValueError: If the attention heads are not divisible by the
                tensor parallel size, or the hidden layers are not divisible
                by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        """Return the model's hidden size from the HF config."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head dimension (hidden_size / num_attention_heads)."""
        # FIXME: This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of key/value heads per tensor-parallel rank.

        Despite its name, this counts KV heads:
        - Multi-query models report a single head.
        - Configs exposing ``n_head_kv`` or ``num_key_value_heads`` use that
          count, split across tensor-parallel ranks.
        - Otherwise it falls back to ``num_attention_heads``.
        """
        # For Falcon with new_decoder_architecture, the multi_query flag is
        # ignored and the head count falls through to the attributes below.
        new_decoder_arch_falcon = (
            self.hf_config.model_type == "falcon"
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            return 1
        # NOTE(review): floor division below yields 0 when the tensor
        # parallel size exceeds the KV head count; callers may need to guard.
        if getattr(self.hf_config, "n_head_kv", None) is not None:
            return (self.hf_config.n_head_kv //
                    parallel_config.tensor_parallel_size)
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_max_model_len(self) -> int:
        """Return the maximum sequence length for this model.

        A user-specified ``max_model_len`` takes precedence. Otherwise the
        smallest value among the known context-length keys of the HF config
        is returned. If none of those keys is present, the result is
        ``float("inf")`` (despite the ``int`` annotation).
        """
        if self.max_model_len is not None:
            return self.max_model_len
        max_model_len = float("inf")
        possible_keys = [
            "max_position_embeddings",
            "n_positions",
            "max_seq_len",
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        for key in possible_keys:
            max_len_key = getattr(self.hf_config, key, None)
            if max_len_key is not None:
                max_model_len = min(max_model_len, max_len_key)
        return max_model_len

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline-parallel stage."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            Aphrodite execution. Must be greater than 0.0 and at most 1.0.
        swap_space: Size of the CPU swap space per GPU (in GiB).

    Raises:
        ValueError: If ``gpu_memory_utilization`` is outside (0.0, 1.0].
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def _verify_args(self) -> None:
        """Reject GPU memory utilization values outside (0.0, 1.0]."""
        # Written as a negated chained comparison so that NaN is rejected
        # as well (every comparison with NaN is False).
        if not 0.0 < self.gpu_memory_utilization <= 1.0:
            raise ValueError(
                "GPU memory utilization must be greater than 0.0 and less "
                f"than or equal to 1.0. Got {self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the host's CPU memory.

        Raises ``ValueError`` above 70% of total CPU memory, and logs a
        warning above 40%.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME: Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.

    The total world size is ``pipeline_parallel_size * tensor_parallel_size``.

    Raises:
        NotImplementedError: If ``pipeline_parallel_size`` is greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
            # Multi-worker execution always goes through Ray.
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject parallelism modes that are not implemented."""
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
    """

    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    """Resolve the user-requested ``dtype`` string to a ``torch.dtype``.

    "auto" picks float16 for float32 (or dtype-less) checkpoints and the
    checkpoint's own dtype otherwise. Explicit names are looked up in
    ``_STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitively).

    Raises:
        ValueError: If ``dtype`` is unknown, or bfloat16 is requested on a
            GPU with compute capability below 8.0.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {dtype}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype