config.py

from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os
from packaging.version import Version
from loguru import logger
import torch
from transformers import PretrainedConfig

from aphrodite.transformers_utils.config import get_config
from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
                                    get_nvcc_cuda_version)

_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be
            a branch name, a tag name, or a commit id. If unspecified, will
            use the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        load_in_4bit: Whether to load the FP16 model in bitsandbytes 4bit
            format. Works with AWQ models as well as FP16 models.
        load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
            than load_in_smooth in terms of throughput.
        load_in_smooth: Whether to load the FP16 model in smoothquant format.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context length covered by CUDA
            graphs. When a sequence has a context length larger than this, we
            fall back to eager mode.
        max_log_probs: Maximum number of log probabilities to return in the
            output. Defaults to 10.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        # dtype: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        load_in_smooth: bool = False,
        enforce_eager: bool = True,
        max_context_len_to_capture: Optional[int] = None,
        max_log_probs: int = 10,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        self.load_in_smooth = load_in_smooth
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_log_probs = max_log_probs

        if os.environ.get("APHRODITE_USE_MODELSCOPE",
                          "False").lower() == "true":
            # Download the model from the ModelScope hub.
            # Lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C

            if not os.path.exists(model):
                model_path = snapshot_download(model_id=model,
                                               cache_dir=download_dir,
                                               revision=revision)
            else:
                model_path = model
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = [
            "aqlm", "awq", "bnb", "exl2", "gguf", "gptq", "quip",
            "squeezellm", "marlin"
        ]
        rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        if self.model.endswith("gguf"):
            if self.quantization is None:
                self.quantization = "gguf"
            elif self.quantization != "gguf":
                raise ValueError(
                    f"A GGUF model cannot be used with {self.quantization} "
                    "quantization.")

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            # If the GPTQ model is serialized in marlin format, use marlin.
            if (hf_quant_method == "gptq"
                    and "is_marlin_format" in hf_quant_config
                    and hf_quant_config["is_marlin_format"]):
                hf_quant_method = "marlin"
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    "method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.load_in_4bit:
            # The kernels seem to not work with 4bit weight_only.
            if torch.cuda.get_device_capability(0)[0] < 8:
                raise ValueError(
                    "load_in_4bit quantization is not supported on GPUs with "
                    "compute capability less than 8.0.")
            if self.quantization is None:
                self.quantization = "bnb"
                self.hf_config.quantization_config = {
                    "bits": 4,
                    "quant_mode": "weight_only",
                    "quant_method": "bnb",
                    "group_size": 128,
                    "zero_point": True,
                    "from_float": True
                }
            elif self.quantization == "awq":
                logger.warning("AWQ model is being loaded in 4bit bnb format.")
                self.quantization = "bnb"
                self.hf_config.quantization_config = {
                    "zero_point": True,
                    "q_group_size": 128,
                    "w_bit": 4,
                    "version": "gemm"
                }
            elif self.quantization != "bnb":
                raise ValueError("4bit quantization is not supported in "
                                 f"{self.quantization}.")

        if self.load_in_8bit:
            if self.quantization is None:
                self.quantization = "bnb"
            elif self.quantization != "bnb":
                raise ValueError("8bit quantization is not supported in "
                                 f"{self.quantization}.")
            self.hf_config.quantization_config = {
                "bits": 8,
                "quant_mode": "llm_int8",
                "quant_method": "bnb",
                "group_size": 128,
                "zero_point": True,
                "from_float": True
            }
            self.enforce_eager = True

        if self.load_in_smooth:
            if self.quantization is None:
                self.quantization = "bnb"
            elif self.quantization != "bnb":
                raise ValueError("Smooth quantization is not supported in "
                                 f"{self.quantization}.")
            self.hf_config.quantization_config = {
                "bits": 8,
                "quant_mode": "smoothquant",
                "quant_method": "bnb",
                "group_size": 128,
                "zero_point": True,
                "from_float": True
            }
            self.enforce_eager = True

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if (is_hip()
                    and self.quantization in rocm_not_supported_quantization):
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    "supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        if hasattr(self.hf_config, "head_dim"):
            return self.hf_config.head_dim
        # FIXME: This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so that each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size

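# Illustrative usage sketch (not part of the original file): how the KV-head
# helpers above interact with tensor parallelism. The model id and parallel
# sizes are assumptions chosen for the example; constructing ModelConfig
# fetches the HF config for that model.
#
#     model_config = ModelConfig(
#         model="meta-llama/Llama-2-7b-hf",      # example model id (assumed)
#         tokenizer="meta-llama/Llama-2-7b-hf",
#         tokenizer_mode="auto",
#         trust_remote_code=False,
#         download_dir=None,
#         load_format="auto",
#         dtype="auto",
#         seed=0,
#     )
#     parallel_config = ParallelConfig(pipeline_parallel_size=1,
#                                      tensor_parallel_size=2,
#                                      worker_use_ray=False)
#     model_config.verify_with_parallel_config(parallel_config)
#     # Llama-2-7B has 32 KV heads, so with tensor_parallel_size=2 each GPU
#     # keeps max(1, 32 // 2) == 16 KV heads.
#     num_kv_heads_per_gpu = model_config.get_num_kv_heads(parallel_config)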

class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            Aphrodite execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for the KV cache storage.
        cache_quant_params_path: Path to the scales and zero points
            of the KV cache quantization when cache_dtype is int8.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        # cache_quant_params_path: Optional[str] = None,
        sliding_window: Optional[int] = None,
        context_shift: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        # self.cache_quant_params_path = cache_quant_params_path
        self.context_shift = context_shift
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        # Convert cache_config to a dict(key: str, value: str) for prometheus
        # metrics info.
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            # if self.cache_dtype in ["auto", "int8"]:
            pass
        elif self.cache_dtype == "fp8_e5m2":
            if is_hip():
                raise NotImplementedError(
                    "FP8_E5M2 KV cache is not supported on AMD GPUs yet.")
            nvcc_cuda_version = get_nvcc_cuda_version()
            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                raise ValueError(
                    "FP8 is not supported when the CUDA version is lower "
                    "than 11.8.")
            logger.info(
                "Using fp8_e5m2 data type to store the kv cache. It reduces "
                "the GPU memory footprint and boosts performance, but it may "
                "cause a slight accuracy drop. Currently we only support fp8 "
                "without scaling factors and use e5m2 as the default format.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME: Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple
        # nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)

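# Illustrative usage sketch (not part of the original file): the swap-space
# check above is per node, so CPU usage scales with tensor_parallel_size.
# The numbers below are assumptions chosen for the example.
#
#     parallel_config = ParallelConfig(pipeline_parallel_size=1,
#                                      tensor_parallel_size=4,
#                                      worker_use_ray=False)
#     cache_config = CacheConfig(block_size=16,
#                                gpu_memory_utilization=0.9,
#                                swap_space=4,        # 4 GiB per GPU
#                                cache_dtype="auto")
#     # CPU swap usage is 4 GiB * 4 GPUs = 16 GiB; the check warns once this
#     # exceeds 40% of total CPU memory and raises above 70%.
#     cache_config.verify_with_parallel_config(parallel_config)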

class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Maximum number of workers allowed to
            load the model in parallel, so loading happens in multiple
            sequential batches. This helps avoid RAM OOM when using tensor
            parallelism with large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight,
            see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        ray_workers_use_nsight: bool = False,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        if is_neuron():
            # For Neuron device support, we assign TP=1 here to avoid sharding
            # within Aphrodite directly.
            # Transformers-neuronx takes the neuron_tp_degree attribute and
            # distributes the workload to multiple NeuronCores.
            self.tensor_parallel_size = 1
            self.neuron_tp_degree = tensor_parallel_size
        else:
            self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.ray_workers_use_nsight = ray_workers_use_nsight

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        # Ray workers are not supported for the Neuron backend.
        if self.world_size > 1 and not is_neuron():
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if not self.disable_custom_all_reduce and self.world_size > 1:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif self.pipeline_parallel_size > 1:
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")
        if self.ray_workers_use_nsight and not self.worker_use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")

        # FIXME: Fix the stability issues and re-enable the custom
        # all-reduce kernel.
        if not self.disable_custom_all_reduce and self.world_size > 1:
            self.disable_custom_all_reduce = True
            logger.info(
                "Custom all-reduce kernels are temporarily disabled due to "
                "stability issues. We will re-enable them once the issues are "
                "resolved.")

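# Illustrative usage sketch (not part of the original file): world_size is
# the product of the two parallel sizes, and Ray is enabled automatically for
# multi-GPU runs on non-Neuron hardware. The sizes below are assumptions.
#
#     parallel_config = ParallelConfig(pipeline_parallel_size=1,
#                                      tensor_parallel_size=2,
#                                      worker_use_ray=False)
#     assert parallel_config.world_size == 2
#     assert parallel_config.worker_use_ray  # forced True since world_size > 1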

class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value
            # for higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes Aphrodite reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

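# Illustrative usage sketch (not part of the original file): when
# max_num_batched_tokens is left unset, it defaults to
# max(max_model_len, 2048), which always satisfies both checks above. The
# values below are assumptions.
#
#     scheduler_config = SchedulerConfig(max_num_batched_tokens=None,
#                                        max_num_seqs=256,
#                                        max_model_len=4096,
#                                        max_paddings=256)
#     assert scheduler_config.max_num_batched_tokens == 4096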

class DeviceConfig:

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection.
            if torch.cuda.is_available():
                self.device_type = "cuda"
            elif is_neuron():
                self.device_type = "neuron"
            else:
                raise RuntimeError("No supported device detected.")
        else:
            # Device type is assigned explicitly.
            self.device_type = device

        # Some device types require processing inputs on the CPU.
        if self.device_type in ["neuron"]:
            self.device = torch.device("cpu")
        else:
            # Set the device with the device type.
            self.device = torch.device(self.device_type)

    @property
    def is_neuron(self):
        return self.device_type == "neuron"

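# Illustrative usage sketch (not part of the original file): "auto" resolves
# to CUDA when a GPU is visible, while Neuron inputs are staged on the CPU.
#
#     device_config = DeviceConfig()            # picks "cuda" if available
#     neuron_config = DeviceConfig("neuron")    # hypothetical Neuron setup
#     assert neuron_config.device == torch.device("cpu")
#     assert neuron_config.is_neuron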

@dataclass
class LoRAConfig:
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if (model_config.quantization is not None
                and model_config.quantization == "gguf"):
            raise ValueError("LoRA is not supported with GGUF quantization.")

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")

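# Illustrative usage sketch (not part of the original file): the rank and
# extra-vocab sizes must come from the fixed sets checked in __post_init__,
# and max_cpu_loras falls back to max_loras when unset. The values are
# assumptions; model_config and scheduler_config are the objects from the
# sketches above.
#
#     lora_config = LoRAConfig(max_lora_rank=16, max_loras=4)
#     assert lora_config.max_cpu_loras == 4
#     lora_config.verify_with_model_config(model_config)    # inherits dtype
#     lora_config.verify_with_scheduler_config(scheduler_config)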

_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
    return torch_dtype

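# Illustrative usage sketch (not part of the original file): "auto" keeps the
# checkpoint dtype except for float32 checkpoints, which are downcast to
# float16. The config object here is a bare transformers PretrainedConfig
# used only for the example.
#
#     cfg = PretrainedConfig()
#     cfg.torch_dtype = torch.float32
#     assert _get_and_verify_dtype(cfg, "auto") == torch.float16
#     cfg.torch_dtype = torch.bfloat16
#     assert _get_and_verify_dtype(cfg, "auto") == torch.bfloat16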

def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        # The user-specified length exceeds the derived maximum, so fall back
        # to dynamic RoPE scaling.
        scaling_factor = max_model_len / derived_max_model_len
        hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
        logger.warning(
            f"User-specified max_model_len ({max_model_len}) is higher than "
            f"the original ({derived_max_model_len}). "
            "Attempting to use RoPE scaling.")
        derived_max_model_len = max_model_len
    return int(max_model_len)

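# Illustrative usage sketch (not part of the original file): requesting a
# longer context than the checkpoint supports patches hf_config.rope_scaling
# with a dynamic scaling factor. The config values are assumptions.
#
#     cfg = PretrainedConfig()
#     cfg.max_position_embeddings = 4096
#     assert _get_and_verify_max_len(cfg, max_model_len=None) == 4096
#     assert _get_and_verify_max_len(cfg, max_model_len=8192) == 8192
#     # factor = 8192 / 4096 = 2.0
#     assert cfg.rope_scaling == {"factor": 2.0, "type": "dynamic"}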