@@ -1,22 +1,29 @@
import enum
-from typing import TYPE_CHECKING, Optional, Union, ClassVar
-from dataclasses import dataclass, fields
-import os
-from packaging.version import Version
-from loguru import logger
import json
+import os
+from dataclasses import dataclass, field, fields
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union

import torch
+from loguru import logger
+from packaging.version import Version
from transformers import PretrainedConfig

-from aphrodite.transformers_utils.config import get_config, get_hf_text_config
-from aphrodite.common.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
-                                    get_nvcc_cuda_version)
+from aphrodite.common.utils import (get_cpu_memory, get_nvcc_cuda_version,
+                                    is_cpu, is_hip, is_neuron)
from aphrodite.quantization import QUANTIZATION_METHODS
+from aphrodite.transformers_utils.config import get_config, get_hf_text_config

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

+    from aphrodite.modeling.model_loader.loader import BaseModelLoader
+
+
+# If true, will load models from ModelScope instead of Hugging Face Hub.
+APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
+                                          "False").lower() == "true"
+
_GB = 1 << 30

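[Note] The new module-level APHRODITE_USE_MODELSCOPE flag replaces the per-ModelConfig ModelScope download block removed later in this patch. A minimal sketch of how such a flag can be consumed, mirroring the removed logic; the resolve_model_path helper is hypothetical and not part of this patch:

import os

APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
                                          "False").lower() == "true"

def resolve_model_path(model: str, download_dir=None, revision=None) -> str:
    """Hypothetical helper: fetch from ModelScope only when the flag is set."""
    if APHRODITE_USE_MODELSCOPE and not os.path.exists(model):
        # Lazy import so modelscope is only required when the flag is on.
        from modelscope.hub.snapshot_download import snapshot_download
        return snapshot_download(model_id=model,
                                 cache_dir=download_dir,
                                 revision=revision)
    return model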
@@ -30,18 +37,6 @@ class ModelConfig:
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
-        download_dir: Directory to download and load the weights, default to the
-            default cache directory of huggingface.
-        load_format: The format of the model weights to load:
-            "auto" will try to load the weights in the safetensors format and
-            fall back to the pytorch bin format if safetensors format is
-            not available.
-            "pt" will load the weights in the pytorch bin format.
-            "safetensors" will load the weights in the safetensors format.
-            "npcache" will load the weights in pytorch format and store
-            a numpy cache to speed up the loading.
-            "dummy" will initialize the weights with random values, which is
-            mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
@@ -50,7 +45,7 @@ class ModelConfig:
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
-            Hugging Face Hub. It can be a branch name, a tag name, or a
+            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
@@ -66,8 +61,8 @@ class ModelConfig:
        load_in_smooth: Whether to load the FP16 model in smoothquant format.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
-            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
-            be used to load activation and weight scaling factors when the
+            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
+            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
@@ -75,6 +70,8 @@ class ModelConfig:
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer.
    """

    def __init__(
@@ -83,9 +80,6 @@ class ModelConfig:
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
-        download_dir: Optional[str],
-        load_format: str,
-        # dtype: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
@@ -97,16 +91,15 @@ class ModelConfig:
        load_in_8bit: bool = False,
        load_in_smooth: bool = False,
        quantization_param_path: Optional[str] = None,
-        enforce_eager: bool = True,
+        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
-        max_log_probs: int = 10,
+        max_logprobs: int = 5,
+        skip_tokenizer_init: bool = False,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
-        self.download_dir = download_dir
-        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
@@ -118,22 +111,8 @@ class ModelConfig:
        self.quantization_param_path = quantization_param_path
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
-        self.max_log_probs = max_log_probs
-
-        if os.environ.get("APHRODITE_USE_MODELSCOPE",
-                          "False").lower() == "true":
-            # download model from ModelScope hub,
-            # lazy import so that modelscope is not required for normal use.
-            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
-            if not os.path.exists(model):
-                model_path = snapshot_download(model_id=model,
-                                               cache_dir=download_dir,
-                                               revision=revision)
-            else:
-                model_path = model
-            self.model = model_path
-            self.download_dir = model_path
-            self.tokenizer = model_path
+        self.max_logprobs = max_logprobs
+        self.skip_tokenizer_init = skip_tokenizer_init

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
@@ -141,41 +120,11 @@ class ModelConfig:
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                     max_model_len)
-        self._verify_load_format()
-        self._verify_tokenizer_mode()
+        if not self.skip_tokenizer_init:
+            self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

-    def _verify_load_format(self) -> None:
-        load_format = self.load_format.lower()
-        supported_load_format = [
-            "auto", "pt", "safetensors", "npcache", "dummy"
-        ]
-        rocm_not_supported_load_format = []
-        if load_format not in supported_load_format:
-            raise ValueError(
-                f"Unknown load format: {self.load_format}. Must be one of "
-                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip() and load_format in rocm_not_supported_load_format:
-            rocm_supported_load_format = [
-                f for f in supported_load_format
-                if (f not in rocm_not_supported_load_format)
-            ]
-            raise ValueError(
-                f"load format \'{load_format}\' is not supported in ROCm. "
-                f"Supported load format are "
-                f"{rocm_supported_load_format}")
-
-        # TODO: Remove this check once HF updates the pt weights of Mixtral.
-        architectures = getattr(self.hf_config, "architectures", [])
-        # architectures can be None instead of []
-        if architectures and "MixtralForCausalLM" in architectures \
-                and load_format == "pt":
-            raise ValueError(
-                "Currently, the 'pt' format is not supported for Mixtral. "
-                "Please use the 'safetensors' format instead. ")
-        self.load_format = load_format
-
    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
@@ -186,33 +135,33 @@ class ModelConfig:

    def _verify_quantization(self) -> None:
        supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

-        if self.model.endswith("gguf"):
-            if self.quantization is None:
-                self.quantization = "gguf"
-            elif self.quantization != "gguf":
-                raise ValueError(
-                    f"GGUF file cannot be used in ({self.quantization}).")
-
        # Parse quantization method from the HF model config, if available.
-        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
-        if hf_quant_config is not None:
-
-            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
-            # If the GPTQ model is serialized in marlin format, use marlin.
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
-                hf_quant_method = "marlin"
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
+                                or quant_cfg.get("is_marlin_format", False))
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                logger.info("The model is serialized in Marlin format. "
+                            "Using Marlin kernel.")
+                quant_method = "marlin"
+                if self.quantization == "gptq":
+                    self.quantization = quant_method
+
            if self.quantization is None:
-                self.quantization = hf_quant_method
-            elif self.quantization != hf_quant_method:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
-                    f"({hf_quant_method}) does not match the quantization "
+                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")
        if self.load_in_4bit:
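[Note] For reference, the two quantization_config layouts that the new is_format_marlin check accepts, shown as a minimal self-contained sketch. The example dicts are illustrative, not taken from a real checkpoint:

# autogptq <= 0.7.1 marks Marlin-serialized GPTQ checkpoints with a bool flag.
old_style = {"quant_method": "gptq", "is_marlin_format": True}
# autogptq >= 0.8.0 uses a string checkpoint_format field instead.
new_style = {"quant_method": "gptq", "checkpoint_format": "marlin"}

def detected_method(quant_cfg: dict) -> str:
    # Same detection logic as the hunk above, lifted out for illustration.
    quant_method = quant_cfg.get("quant_method", "").lower()
    is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                        or quant_cfg.get("is_marlin_format", False))
    if quant_method == "gptq" and is_format_marlin:
        return "marlin"
    return quant_method

assert detected_method(old_style) == detected_method(new_style) == "marlin"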
@@ -280,7 +229,7 @@ class ModelConfig:
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
-            ) and self.quantization in rocm_not_supported_quantization:
+            ) and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    "supported in ROCm.")
@@ -317,6 +266,12 @@ class ModelConfig:
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+
+        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
+        # addition to sliding window size. We check if that field is present
+        # and if it's False, return None.
        if (hasattr(self.hf_text_config, "use_sliding_window")
                and not self.hf_text_config.use_sliding_window):
            return None
@@ -329,8 +284,8 @@ class ModelConfig:
        return self.hf_text_config.hidden_size

    def get_head_size(self) -> int:
-        if hasattr(self.hf_config, "head_dim"):
-            return self.hf_config.head_dim
+        if hasattr(self.hf_text_config, "head_dim"):
+            return self.hf_text_config.head_dim
        # FIXME: This may not be true for all models.
        return (self.hf_text_config.hidden_size //
                self.hf_text_config.num_attention_heads)
@@ -397,11 +352,9 @@ class CacheConfig:
        gpu_memory_utilization: Fraction of GPU memory to use for the
            Aphrodite execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data Type for KV cache storage.
-        cache_quant_params_path: Path to the scales and zero points
-            of KV cache quantization when cache_dtype is int8.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
-            the profiled num_gpu_blocks if specified. Does nothing if None.
+        cache_dtype: Data type for kv cache storage.
+        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
+            profiled num_gpu_blocks if specified. Does nothing if None.
    """

    def __init__(
@@ -410,10 +363,9 @@ class CacheConfig:
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
-        # cache_quant_params_path: Optional[str] = None,
        num_gpu_blocks_override: Optional[int] = None,
        sliding_window: Optional[int] = None,
-        context_shift: bool = False,
+        enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
@@ -421,8 +373,7 @@ class CacheConfig:
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
-        # self.cache_quant_params_path = cache_quant_params_path
-        self.context_shift = context_shift
+        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()
@@ -443,7 +394,6 @@ class CacheConfig:

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
-            # if self.cache_dtype in ["auto", "int8"]:
            pass
        elif self.cache_dtype == "fp8":
            if not is_hip():
@@ -484,10 +434,10 @@ class CacheConfig:
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.
-
+
    Args:
-        pool_size: Number of tokenizer instances in the pool.
-        pool_type: Type of the tokenizer pool.
+        pool_size: Number of tokenizer workers in the pool.
+        pool_type: Type of the pool.
        extra_config: Additional config for the pool.
            The way the config will be used depends on the
            pool type.
@@ -498,7 +448,7 @@ class TokenizerPoolConfig:

    def __post_init__(self):
        if self.pool_type not in ("ray", ):
-            raise ValueError(f"Unknown pool type: {self.pool_type}.")
+            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")
@@ -508,14 +458,15 @@ class TokenizerPoolConfig:
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.
-
+
        If tokenizer_pool_size is 0, return None.
-
+
        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
-            tokenizer_pool_type: Type of the tokenizer pool.
+            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool.
-                The way the config will be used depends on the pool type.
+                The way the config will be used depends on the
+                pool type. This can be a JSON string (will be parsed).
        """
        if tokenizer_pool_size:
            if isinstance(tokenizer_pool_extra_config, str):
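[Note] The docstring change above reflects that tokenizer_pool_extra_config may now arrive as a JSON string (for example, from a CLI flag) and is parsed before use. A minimal sketch of the call; the factory name create_config and the import path are assumptions based on this hunk, and the "num_cpus" key is purely illustrative:

from aphrodite.common.config import TokenizerPoolConfig  # import path assumed

# The dict form and the JSON-string form should be equivalent after parsing.
cfg_a = TokenizerPoolConfig.create_config(
    tokenizer_pool_size=4,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config={"num_cpus": 1})
cfg_b = TokenizerPoolConfig.create_config(
    tokenizer_pool_size=4,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config='{"num_cpus": 1}')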
@@ -532,6 +483,65 @@ class TokenizerPoolConfig:
        return tokenizer_pool_config


+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    TENSORIZER = "tensorizer"
+
+
+@dataclass
+class LoadConfig:
+    """
+    download_dir: Directory to download and load the weights, default to the
+        default cache directory of huggingface.
+    load_format: The format of the model weights to load:
+        "auto" will try to load the weights in the safetensors format and
+            fall back to the pytorch bin format if safetensors format is
+            not available.
+        "pt" will load the weights in the pytorch bin format.
+        "safetensors" will load the weights in the safetensors format.
+        "npcache" will load the weights in pytorch format and store
+            a numpy cache to speed up the loading.
+        "dummy" will initialize the weights with random values, which is
+            mainly for profiling.
+        "tensorizer" will use CoreWeave's tensorizer library for
+            fast weight loading.
+    """
+
+    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(
+        default_factory=dict)
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(
+                model_loader_extra_config)
+        self._verify_load_format()
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}")
+
+
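[Note] A minimal usage sketch for the new LoadConfig defined in the hunk above. The values and the "some_key" entry are illustrative only:

# String load formats are normalized to the LoadFormat enum in __post_init__,
# and a JSON string passed as model_loader_extra_config is parsed into a dict.
load_config = LoadConfig(
    load_format="safetensors",
    download_dir="/tmp/aphrodite-weights",
    model_loader_extra_config='{"some_key": "some_value"}')
assert load_config.load_format is LoadFormat.SAFETENSORS
assert load_config.model_loader_extra_config == {"some_key": "some_value"}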
class ParallelConfig:
    """Configuration for the distributed execution.

@@ -546,7 +556,7 @@ class ParallelConfig:
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
-        tokenizer_pool_config: Configuration for the tokenizer pool.
+        tokenizer_pool_config: Config for the tokenizer pool.
            If None, will use synchronous tokenization.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
@@ -613,12 +623,9 @@ class SchedulerConfig:
            decoding to store KV activations of tokens which may or may not be
            accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
-            prompt latency) before scheduling the next prompt.
-        policy: Policy of sequence scheduling (`fcfs` or `reorder`).
-        reorder_window: Allowed reorder window size (in sec) for `reorder`
-            policy.
-        enable_chunked_prefill: If True, prefill requests can be chunked
-            based on the remaining max_num_batched_tokens.
+            prompt latency) before scheduling next prompt.
+        enable_chunked_prefill: If True, prefill requests can be chunked based
+            on the remaining max_num_batched_tokens.
    """

    def __init__(
@@ -629,8 +636,6 @@ class SchedulerConfig:
        use_v2_block_manager: bool = False,
        num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
-        policy: str = "fcfs",
-        reorder_window: float = 0.0,
        enable_chunked_prefill: bool = False,
    ) -> None:
        if max_num_batched_tokens is not None:
@@ -645,14 +650,14 @@ class SchedulerConfig:
            self.max_num_batched_tokens = max(max_model_len, 2048)
        if enable_chunked_prefill:
            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.use_v2_block_manager = use_v2_block_manager
        self.num_lookahead_slots = num_lookahead_slots
        self.delay_factor = delay_factor
-        self.policy = policy
-        self.reorder_window = reorder_window
        self.chunked_prefill_enabled = enable_chunked_prefill
+
        self._verify_args()

    def _verify_args(self) -> None:
@@ -665,17 +670,13 @@ class SchedulerConfig:
                "max_num_batched_tokens and makes Aphrodite reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
+
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")
-        if self.reorder_window < 0:
-            raise ValueError(f"reorder_window ({self.reorder_window}) must "
-                             "be not be negative.")
-        if self.reorder_window != 0 and self.policy != 'reorder':
-            raise ValueError("fcfs policy doesn't support reorder_window "
-                             f"({self.reorder_window}).")
+
        if self.num_lookahead_slots < 0:
            raise ValueError(
                "num_lookahead_slots "
@@ -688,14 +689,14 @@ class DeviceConfig:
    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection
-            if torch.cuda.is_available():
-                self.device_type = "cuda"
-            elif is_neuron():
+            if is_neuron():
                self.device_type = "neuron"
            elif is_cpu():
                self.device_type = "cpu"
            else:
-                raise RuntimeError("No supported device detected.")
+                # We don't call torch.cuda.is_available() here to
+                # avoid initializing CUDA before workers are forked
+                self.device_type = "cuda"
        else:
            # Device type is assigned explicitly
            self.device_type = device
@@ -710,6 +711,7 @@ class DeviceConfig:

class SpeculativeConfig:
    """Configuration for speculative decoding.
+
    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """
@@ -724,13 +726,13 @@ class SpeculativeConfig:
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
-        ngram_prompt_lookup_max: Optional[int],
-        ngram_prompt_lookup_min: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.
+
        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.
+
        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
@@ -767,6 +769,7 @@ class SpeculativeConfig:
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")
+
        assert (speculative_model is not None
                and num_speculative_tokens is not None)
@@ -786,55 +789,39 @@ class SpeculativeConfig:
        draft_code_revision = None
        draft_quantization = None

-        if speculative_model == "[ngram]":
-            assert (ngram_prompt_lookup_max is not None
-                    and ngram_prompt_lookup_max > 0)
-            if ngram_prompt_lookup_min is None:
-                ngram_prompt_lookup_min = 0
-            else:
-                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
-            draft_model_config = target_model_config
-            draft_parallel_config = target_parallel_config
-        else:
-            ngram_prompt_lookup_max = 0
-            ngram_prompt_lookup_min = 0
-            draft_model_config = ModelConfig(
-                model=speculative_model,
-                download_dir=target_model_config.download_dir,
-                load_format=target_model_config.load_format,
-                tokenizer=target_model_config.tokenizer,
-                tokenizer_mode=target_model_config.tokenizer_mode,
-                trust_remote_code=target_model_config.trust_remote_code,
-                dtype=target_model_config.dtype,
-                seed=target_model_config.seed,
-                revision=draft_revision,
-                code_revision=draft_code_revision,
-                tokenizer_revision=target_model_config.tokenizer_revision,
-                max_model_len=None,
-                quantization=draft_quantization,
-                enforce_eager=target_model_config.enforce_eager,
-                max_context_len_to_capture=target_model_config.
-                max_context_len_to_capture,
-                max_log_probs=target_model_config.max_log_probs,
-            )
-
-            draft_model_config.max_model_len = (
-                SpeculativeConfig._maybe_override_draft_max_model_len(
-                    speculative_max_model_len,
-                    draft_model_config.max_model_len,
-                    target_model_config.max_model_len,
-                ))
-
-            draft_parallel_config = (
-                SpeculativeConfig.create_draft_parallel_config(
-                    target_parallel_config))
+        draft_model_config = ModelConfig(
+            model=speculative_model,
+            tokenizer=target_model_config.tokenizer,
+            tokenizer_mode=target_model_config.tokenizer_mode,
+            trust_remote_code=target_model_config.trust_remote_code,
+            dtype=target_model_config.dtype,
+            seed=target_model_config.seed,
+            revision=draft_revision,
+            code_revision=draft_code_revision,
+            tokenizer_revision=target_model_config.tokenizer_revision,
+            max_model_len=None,
+            quantization=draft_quantization,
+            enforce_eager=target_model_config.enforce_eager,
+            max_context_len_to_capture=target_model_config.
+            max_context_len_to_capture,
+            max_logprobs=target_model_config.max_logprobs,
+        )
+
+        draft_model_config.max_model_len = (
+            SpeculativeConfig._maybe_override_draft_max_model_len(
+                speculative_max_model_len,
+                draft_model_config.max_model_len,
+                target_model_config.max_model_len,
+            ))
+
+        draft_parallel_config = (
+            SpeculativeConfig.create_draft_parallel_config(
+                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
-            ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min,
        )

    @staticmethod
@@ -847,8 +834,10 @@ class SpeculativeConfig:
        the draft_max_model_len, but may be the target_max_model_len if it is
        less than the draft_max_model_len, or may be speculative_max_model_len
        if it is specified.
+
        This is necessary so that sequences do not exceed the capacity of the
        draft model or the target model.
+
        speculative_max_model_len is mainly used for testing that sequences can
        skip speculation.
        """
@@ -874,6 +863,7 @@ class SpeculativeConfig:
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.
+
        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
@@ -899,10 +889,9 @@ class SpeculativeConfig:
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
-        ngram_prompt_lookup_max: int,
-        ngram_prompt_lookup_min: int,
    ):
        """Create a SpeculativeConfig object.
+
        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
@@ -912,8 +901,6 @@ class SpeculativeConfig:
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
-        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
-        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

        self._verify_args()
@@ -930,16 +917,14 @@ class SpeculativeConfig:
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.
+
        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
-        if self.ngram_prompt_lookup_max > 0:
-            draft_model = "[ngram]"
-        else:
-            draft_model = self.draft_model_config.model
+        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
@@ -980,9 +965,12 @@ class LoRAConfig:
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
-        if (model_config.quantization is not None
-                and model_config.quantization == "gguf"):
-            raise ValueError("LoRA is not supported with GGUF quantization.")
+        if model_config.quantization and model_config.quantization not in [
+                "awq", "gptq"
+        ]:
+            # TODO support all other quants
+            logger.warning(f"{model_config.quantization} quantization is not "
+                           "tested with LoRA yet.")

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        if scheduler_config.max_num_batched_tokens > 65528:
@@ -999,11 +987,14 @@ class VisionLanguageConfig:

    class ImageInputType(enum.Enum):
        """Image input type into the vision language model.
+
        An image roughly goes through the following transformation:
        Raw image --> pixel values --> image features --> image embeddings.
+
        The difference between different image input types is where the
        image encoder (pixel values --> image features) is run.
        Different image input types also correspond to different tensor shapes.
+
        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
        IMAGE_FEATURES: (1, 576, 1024).
        """
@@ -1075,7 +1066,7 @@ def _get_and_verify_dtype(
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
-        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
+        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
@@ -1110,16 +1101,20 @@ def _get_and_verify_max_len(
        "max_seq_len",
        # ChatGLM2
        "seq_length",
+        # Command-R
+        "model_max_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
+    max_len_key = None
    for key in possible_keys:
-        max_len_key = getattr(hf_config, key, None)
-        if max_len_key is not None:
-            derived_max_model_len = min(derived_max_model_len, max_len_key)
-            break
+        max_len = getattr(hf_config, key, None)
+        if max_len is not None:
+            max_len_key = key if max_len < derived_max_model_len \
+                else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
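[Note] With the change above, every matching length key is folded into the running minimum instead of stopping at the first hit, and max_len_key remembers which key produced it. A small worked example of the same loop, using a plain dict with illustrative values in place of the HF config object:

derived_max_model_len = float("inf")
possible_keys = ["max_seq_len", "model_max_length"]
hf_config = {"max_seq_len": 8192, "model_max_length": 4096}

max_len_key = None
for key in possible_keys:
    max_len = hf_config.get(key)  # stand-in for getattr(hf_config, key, None)
    if max_len is not None:
        max_len_key = key if max_len < derived_max_model_len else max_len_key
        derived_max_model_len = min(derived_max_model_len, max_len)

assert derived_max_model_len == 4096 and max_len_key == "model_max_length"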
@@ -1167,7 +1162,7 @@ class DecodingConfig:
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
-            raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
+            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")

@@ -1182,6 +1177,7 @@ class EngineConfig:
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
+    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]