config.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712
  1. from typing import Optional, Union, ClassVar
  2. from dataclasses import dataclass
  3. import os
  4. from packaging.version import Version
  5. from loguru import logger
  6. import torch
  7. from transformers import PretrainedConfig
  8. from aphrodite.transformers_utils.config import get_config
  9. from aphrodite.common.utils import (get_cpu_memory, is_hip,
  10. get_nvcc_cuda_version)
  11. _GB = 1 << 30
  12. class ModelConfig:
  13. """Configuration for the model.
  14. Args:
  15. model: Name or path of the huggingface model to use.
  16. tokenizer: Name or path of the huggingface tokenizer to use.
  17. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  18. available, and "slow" will always use the slow tokenizer.
  19. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  20. downloading the model and tokenizer.
  21. download_dir: Directory to download and load the weights, default to the
  22. default cache directory of huggingface.
  23. load_format: The format of the model weights to load:
  24. "auto" will try to load the weights in the safetensors format and
  25. fall back to the pytorch bin format if safetensors format is
  26. not available.
  27. "pt" will load the weights in the pytorch bin format.
  28. "safetensors" will load the weights in the safetensors format.
  29. "npcache" will load the weights in pytorch format and store
  30. a numpy cache to speed up the loading.
  31. "dummy" will initialize the weights with random values, which is
  32. mainly for profiling.
  33. dtype: Data type for model weights and activations. The "auto" option
  34. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  35. for BF16 models.
  36. seed: Random seed for reproducibility.
  37. revision: The specific model version to use. It can be a branch name,
  38. a tag name, or a commit id. If unspecified, will use the default
  39. version.
  40. tokenizer_revision: The specific tokenizer version to use. It can be a
  41. branch name, a tag name, or a commit id. If unspecified, will use
  42. the default version.
  43. max_model_len: Maximum length of a sequence (including prompt and
  44. output). If None, will be derived from the model.
  45. quantization: Quantization method that was used to quantize the model
  46. weights. If None, we assume the model weights are not quantized.
  47. load_in_4bit: Whether to load the FP16 model in bitsandbytes 4bit
  48. format. Works with AWQ models as well as FP16.
  49. load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
  50. than load_in_smooth in terms of throughput.
  51. load_in_smooth: Whether to load the FP16 model in smoothquant format.
  52. enforce_eager: Whether to enforce eager execution. If True, we will
  53. disable CUDA graph and always execute the model in eager mode.
  54. If False, we will use CUDA graph and eager execution in hybrid.
  55. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
  56. When a sequence has context length larger than this, we fall back
  57. to eager mode.
  58. """
  59. def __init__(
  60. self,
  61. model: str,
  62. tokenizer: str,
  63. tokenizer_mode: str,
  64. trust_remote_code: bool,
  65. download_dir: Optional[str],
  66. load_format: str,
  67. dtype: str,
  68. seed: int,
  69. revision: Optional[str] = None,
  70. tokenizer_revision: Optional[str] = None,
  71. max_model_len: Optional[int] = None,
  72. quantization: Optional[str] = None,
  73. load_in_4bit: bool = False,
  74. load_in_8bit: bool = False,
  75. load_in_smooth: bool = False,
  76. enforce_eager: bool = False,
  77. max_context_len_to_capture: Optional[int] = None,
  78. max_log_probs: int = 10,
  79. ) -> None:
  80. self.model = model
  81. self.tokenizer = tokenizer
  82. self.tokenizer_mode = tokenizer_mode
  83. self.trust_remote_code = trust_remote_code
  84. self.download_dir = download_dir
  85. self.load_format = load_format
  86. self.seed = seed
  87. self.revision = revision
  88. self.tokenizer_revision = tokenizer_revision
  89. self.quantization = quantization
  90. self.load_in_4bit = load_in_4bit
  91. self.load_in_8bit = load_in_8bit
  92. self.load_in_smooth = load_in_smooth
  93. self.enforce_eager = enforce_eager
  94. self.max_context_len_to_capture = max_context_len_to_capture
  95. self.max_log_probs = max_log_probs
  96. if os.environ.get("APHRODITE_USE_MODELSCOPE",
  97. "False").lower() == "true":
  98. # download model from ModelScope hub,
  99. # lazy import so that modelscope is not required for normal use.
  100. from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
  101. model_path = snapshot_download(model_id=model,
  102. cache_dir=download_dir,
  103. revision=revision)
  104. self.model = model_path
  105. self.download_dir = model_path
  106. self.tokenizer = model_path
  107. self.hf_config = get_config(self.model, trust_remote_code, revision)
  108. self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
  109. self.max_model_len = _get_and_verify_max_len(self.hf_config,
  110. max_model_len)
  111. self._verify_load_format()
  112. self._verify_tokenizer_mode()
  113. self._verify_quantization()
  114. self._verify_cuda_graph()
  115. def _verify_load_format(self) -> None:
  116. load_format = self.load_format.lower()
  117. supported_load_format = [
  118. "auto", "pt", "safetensors", "npcache", "dummy"
  119. ]
  120. rocm_not_supported_load_format = []
  121. if load_format not in supported_load_format:
  122. raise ValueError(
  123. f"Unknown load format: {self.load_format}. Must be one of "
  124. "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
  125. if is_hip() and load_format in rocm_not_supported_load_format:
  126. rocm_supported_load_format = [
  127. f for f in supported_load_format
  128. if (f not in rocm_not_supported_load_format)
  129. ]
  130. raise ValueError(
  131. f"load format \'{load_format}\' is not supported in ROCm. "
  132. f"Supported load format are "
  133. f"{rocm_supported_load_format}")
  134. # TODO: Remove this check once HF updates the pt weights of Mixtral.
  135. architectures = getattr(self.hf_config, "architectures", [])
  136. if "MixtralForCausalLM" in architectures and load_format == "pt":
  137. raise ValueError(
  138. "Currently, the 'pt' format is not supported for Mixtral. "
  139. "Please use the 'safetensors' format instead. ")
  140. self.load_format = load_format
  141. def _verify_tokenizer_mode(self) -> None:
  142. tokenizer_mode = self.tokenizer_mode.lower()
  143. if tokenizer_mode not in ["auto", "slow"]:
  144. raise ValueError(
  145. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  146. "either 'auto' or 'slow'.")
  147. self.tokenizer_mode = tokenizer_mode
  148. def _verify_quantization(self) -> None:
  149. supported_quantization = [
  150. "aqlm", "awq", "bnb", "exl2", "gguf", "gptq", "quip", "squeezellm",
  151. "marlin"
  152. ]
  153. rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
  154. if self.quantization is not None:
  155. self.quantization = self.quantization.lower()
  156. if self.model.endswith("gguf"):
  157. if self.quantization is None:
  158. self.quantization = "gguf"
  159. elif self.quantization != "gguf":
  160. raise ValueError(
  161. f"GGUF file cannot be used in ({self.quantization}).")
  162. # Parse quantization method from the HF model config, if available.
  163. hf_quant_config = getattr(self.hf_config, "quantization_config", None)
  164. if hf_quant_config is not None:
  165. hf_quant_method = str(hf_quant_config["quant_method"]).lower()
  166. # If the GPTQ model is serialized in marlin format, use marlin.
  167. if (hf_quant_method == "gptq"
  168. and "is_marlin_format" in hf_quant_config
  169. and hf_quant_config["is_marlin_format"]):
  170. hf_quant_method = "marlin"
  171. if self.quantization is None:
  172. self.quantization = hf_quant_method
  173. elif self.quantization != hf_quant_method:
  174. raise ValueError(
  175. "Quantization method specified in the model config "
  176. f"({hf_quant_method}) does not match the quantization "
  177. f"method specified in the `quantization` argument "
  178. f"({self.quantization}).")
  179. if self.load_in_4bit:
  180. # the kernels seem to not work with 4bit weight_only
  181. if torch.cuda.get_device_capability(0)[0] < 8:
  182. raise ValueError(
  183. "load_in_4bit quantization is not supported on GPUs with "
  184. "compute capability less than 8.0.")
  185. if self.quantization is None:
  186. self.quantization = "bnb"
  187. self.hf_config.quantization_config = {
  188. "bits": 4,
  189. "quant_mode": "weight_only",
  190. "quant_method": "bnb",
  191. "group_size": 128,
  192. "zero_point": True,
  193. "from_float": True
  194. }
  195. elif self.quantization == "awq":
  196. logger.warning("AWQ model is being loaded in 4bit bnb format.")
  197. self.quantization = "bnb"
  198. self.hf_config.quantization_config = {
  199. "zero_point": True,
  200. "q_group_size": 128,
  201. "w_bit": 4,
  202. "version": "gemm"
  203. }
  204. elif self.quantization != "bnb":
  205. raise ValueError("4bit quantization is not supported in "
  206. f"{self.quantization}.")
  207. if self.load_in_8bit:
  208. if self.quantization is None:
  209. self.quantization = "bnb"
  210. elif self.quantization != "bnb":
  211. raise ValueError("8bit quantization is not supported in "
  212. f"{self.quantization}.")
  213. self.hf_config.quantization_config = {
  214. "bits": 8,
  215. "quant_mode": "llm_int8",
  216. "quant_method": "bnb",
  217. "group_size": 128,
  218. "zero_point": True,
  219. "from_float": True
  220. }
  221. self.enforce_eager = True
  222. if self.load_in_smooth:
  223. if self.quantization is None:
  224. self.quantization = "bnb"
  225. elif self.quantization != "bnb":
  226. raise ValueError("Smooth quantization is not supported in "
  227. f"{self.quantization}.")
  228. self.hf_config.quantization_config = {
  229. "bits": 8,
  230. "quant_mode": "smoothquant",
  231. "quant_method": "bnb",
  232. "group_size": 128,
  233. "zero_point": True,
  234. "from_float": True
  235. }
  236. self.enforce_eager = True
  237. if self.quantization is not None:
  238. if self.quantization not in supported_quantization:
  239. raise ValueError(
  240. f"Unknown quantization method: {self.quantization}. Must "
  241. f"be one of {supported_quantization}.")
  242. if is_hip(
  243. ) and self.quantization in rocm_not_supported_quantization:
  244. raise ValueError(
  245. f"{self.quantization} quantization is currently not "
  246. "supported in ROCm.")
  247. if self.quantization != "marlin":
  248. logger.warning(
  249. f"{self.quantization} quantization is not fully "
  250. "optimized yet. The speed can be slower than "
  251. "non-quantized models.")
  252. def _verify_cuda_graph(self) -> None:
  253. if self.max_context_len_to_capture is None:
  254. self.max_context_len_to_capture = self.max_model_len
  255. self.max_context_len_to_capture = min(self.max_context_len_to_capture,
  256. self.max_model_len)
  257. def verify_with_parallel_config(
  258. self,
  259. parallel_config: "ParallelConfig",
  260. ) -> None:
  261. total_num_attention_heads = self.hf_config.num_attention_heads
  262. tensor_parallel_size = parallel_config.tensor_parallel_size
  263. if total_num_attention_heads % tensor_parallel_size != 0:
  264. raise ValueError(
  265. f"Total number of attention heads ({total_num_attention_heads})"
  266. " must be divisible by tensor parallel size "
  267. f"({tensor_parallel_size}).")
  268. total_num_hidden_layers = self.hf_config.num_hidden_layers
  269. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  270. if total_num_hidden_layers % pipeline_parallel_size != 0:
  271. raise ValueError(
  272. f"Total number of hidden layers ({total_num_hidden_layers}) "
  273. "must be divisible by pipeline parallel size "
  274. f"({pipeline_parallel_size}).")
  275. def get_sliding_window(self) -> Optional[int]:
  276. return getattr(self.hf_config, "sliding_window", None)
  277. def get_vocab_size(self) -> int:
  278. return self.hf_config.vocab_size
  279. def get_hidden_size(self) -> int:
  280. return self.hf_config.hidden_size
  281. def get_head_size(self) -> int:
  282. if hasattr(self.hf_config, "head_dim"):
  283. return self.hf_config.head_dim
  284. # FIXME: This may not be true for all models.
  285. return self.hf_config.hidden_size // self.hf_config.num_attention_heads
  286. def get_total_num_kv_heads(self) -> int:
  287. """Returns the total number of KV heads."""
  288. # For GPTBigCode & Falcon:
  289. # NOTE: for falcon, when new_decoder_architecture is True, the
  290. # multi_query flag is ignored and we use n_head_kv for the number of
  291. # KV heads.
  292. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
  293. new_decoder_arch_falcon = (
  294. self.hf_config.model_type in falcon_model_types
  295. and getattr(self.hf_config, "new_decoder_architecture", False))
  296. if not new_decoder_arch_falcon and getattr(self.hf_config,
  297. "multi_query", False):
  298. # Multi-query attention, only one KV head.
  299. # Currently, tensor parallelism is not supported in this case.
  300. return 1
  301. attributes = [
  302. # For Falcon:
  303. "n_head_kv",
  304. "num_kv_heads",
  305. # For LLaMA-2:
  306. "num_key_value_heads",
  307. # For ChatGLM:
  308. "multi_query_group_num",
  309. ]
  310. for attr in attributes:
  311. num_kv_heads = getattr(self.hf_config, attr, None)
  312. if num_kv_heads is not None:
  313. return num_kv_heads
  314. # For non-grouped-query attention models, the number of KV heads is
  315. # equal to the number of attention heads.
  316. return self.hf_config.num_attention_heads
  317. def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
  318. """Returns the number of KV heads per GPU."""
  319. total_num_kv_heads = self.get_total_num_kv_heads()
  320. # If tensor parallelism is used, we divide the number of KV heads by
  321. # the tensor parallel size. We will replicate the KV heads in the
  322. # case where the number of KV heads is smaller than the tensor
  323. # parallel size so each GPU has at least one KV head.
  324. return max(1,
  325. total_num_kv_heads // parallel_config.tensor_parallel_size)
  326. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  327. total_num_hidden_layers = self.hf_config.num_hidden_layers
  328. return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  329. class CacheConfig:
  330. """Configuration for the KV cache.
  331. Args:
  332. block_size: Size of a cache block in number of tokens.
  333. gpu_memory_utilization: Fraction of GPU memory to use for the
  334. Aphrodite execution.
  335. swap_space: Size of the CPU swap space per GPU (in GiB).
  336. cache_dtype: Data Type for KV cache storage.
  337. cache_quant_params_path: Path to the scales and zero points
  338. of KV cache quantization when cache_dtype is int8.
  339. """
  340. def __init__(
  341. self,
  342. block_size: int,
  343. gpu_memory_utilization: float,
  344. swap_space: int,
  345. cache_dtype: str,
  346. cache_quant_params_path: Optional[str] = None,
  347. sliding_window: Optional[int] = None,
  348. context_shift: bool = False,
  349. ) -> None:
  350. self.block_size = block_size
  351. self.gpu_memory_utilization = gpu_memory_utilization
  352. self.swap_space_bytes = swap_space * _GB
  353. self.cache_dtype = cache_dtype
  354. self.sliding_window = sliding_window
  355. self.cache_quant_params_path = cache_quant_params_path
  356. self.context_shift = context_shift
  357. self._verify_args()
  358. self._verify_cache_dtype()
  359. # Will be set after profiling.
  360. self.num_gpu_blocks = None
  361. self.num_cpu_blocks = None
  362. def _verify_args(self) -> None:
  363. if self.gpu_memory_utilization > 1.0:
  364. raise ValueError(
  365. "GPU memory utilization must be less than 1.0. Got "
  366. f"{self.gpu_memory_utilization}.")
  367. def _verify_cache_dtype(self) -> None:
  368. if self.cache_dtype in ["auto", "int8"]:
  369. pass
  370. elif self.cache_dtype == "fp8_e5m2":
  371. nvcc_cuda_version = get_nvcc_cuda_version()
  372. if nvcc_cuda_version < Version("11.8"):
  373. raise ValueError(
  374. "FP8 is not supported when cuda version is lower than "
  375. "11.8. If you think you have the correct cuda version, "
  376. "please make sure you've properly exported CUDA_HOME.")
  377. device_name = torch.cuda.get_device_name()
  378. if "AMD" in device_name:
  379. raise NotImplementedError(
  380. "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
  381. logger.info(
  382. "Using fp8_e5m2 data type to store kv cache. It reduces "
  383. "the GPU memory footprint and boosts the performance. "
  384. "But it may cause slight accuracy drop. "
  385. "Currently we only support fp8 without scaling factors and "
  386. "make e5m2 as a default format.")
  387. else:
  388. raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
  389. def verify_with_parallel_config(
  390. self,
  391. parallel_config: "ParallelConfig",
  392. ) -> None:
  393. total_cpu_memory = get_cpu_memory()
  394. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  395. # group are in the same node. However, the GPUs may span multiple nodes.
  396. num_gpus_per_node = parallel_config.tensor_parallel_size
  397. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  398. msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
  399. f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
  400. "allocated for the swap space.")
  401. if cpu_memory_usage > 0.7 * total_cpu_memory:
  402. raise ValueError("Too large swap space. " + msg)
  403. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  404. logger.warning("Possibly too large swap space. " + msg)
  405. class ParallelConfig:
  406. """Configuration for the distributed execution.
  407. Args:
  408. pipeline_parallel_size: Number of pipeline parallel groups.
  409. tensor_parallel_size: Number of tensor parallel groups.
  410. worker_use_ray: Whether to use Ray for model workers. Will be set to
  411. True if either pipeline_parallel_size or tensor_parallel_size is
  412. greater than 1.
  413. disable_custom_all_reduce: Disable the custom all-reduce kernel and
  414. fall back to NCCL.
  415. """
  416. def __init__(
  417. self,
  418. pipeline_parallel_size: int,
  419. tensor_parallel_size: int,
  420. worker_use_ray: bool,
  421. max_parallel_loading_workers: Optional[int] = None,
  422. disable_custom_all_reduce: bool = False,
  423. ) -> None:
  424. self.pipeline_parallel_size = pipeline_parallel_size
  425. self.tensor_parallel_size = tensor_parallel_size
  426. self.worker_use_ray = worker_use_ray
  427. self.max_parallel_loading_workers = max_parallel_loading_workers
  428. self.disable_custom_all_reduce = disable_custom_all_reduce
  429. self.world_size = pipeline_parallel_size * tensor_parallel_size
  430. if self.world_size > 1:
  431. self.worker_use_ray = True
  432. self._verify_args()
  433. def _verify_args(self) -> None:
  434. if self.pipeline_parallel_size > 1:
  435. raise NotImplementedError(
  436. "Pipeline parallelism is not supported yet.")
  437. if is_hip():
  438. self.disable_custom_all_reduce = True
  439. logger.info(
  440. "Disabled the custom all-reduce kernel because it is not "
  441. "supported on AMD GPUs.")
  442. elif self.pipeline_parallel_size > 1:
  443. self.disable_custom_all_reduce = True
  444. logger.info(
  445. "Disabled the custom all-reduce kernel because it is not "
  446. "supported with pipeline parallelism.")
  447. class SchedulerConfig:
  448. """Scheduler configuration.
  449. Args:
  450. max_num_batched_tokens: Maximum number of tokens to be processed in
  451. a single iteration.
  452. max_num_seqs: Maximum number of sequences to be processed in a single
  453. iteration.
  454. max_model_len: Maximum length of a sequence (including prompt
  455. and generated text).
  456. max_paddings: Maximum number of paddings to be added to a batch.
  457. """
  458. def __init__(
  459. self,
  460. max_num_batched_tokens: Optional[int],
  461. max_num_seqs: int,
  462. max_model_len: int,
  463. max_paddings: int,
  464. ) -> None:
  465. if max_num_batched_tokens is not None:
  466. self.max_num_batched_tokens = max_num_batched_tokens
  467. else:
  468. # If max_model_len is too short, use 2048 as the default value for
  469. # higher throughput.
  470. self.max_num_batched_tokens = max(max_model_len, 2048)
  471. self.max_num_seqs = max_num_seqs
  472. self.max_model_len = max_model_len
  473. self.max_paddings = max_paddings
  474. self._verify_args()
  475. def _verify_args(self) -> None:
  476. if self.max_num_batched_tokens < self.max_model_len:
  477. raise ValueError(
  478. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  479. f"smaller than max_model_len ({self.max_model_len}). "
  480. "This effectively limits the maximum sequence length to "
  481. "max_num_batched_tokens and makes Aphrodite reject longer "
  482. "sequences. Please increase max_num_batched_tokens or "
  483. "decrease max_model_len.")
  484. if self.max_num_batched_tokens < self.max_num_seqs:
  485. raise ValueError(
  486. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  487. "be greater than or equal to max_num_seqs "
  488. f"({self.max_num_seqs}).")
  489. class DeviceConfig:
  490. def __init__(self, device: str = "cuda") -> None:
  491. self.device = torch.device(device)
  492. @dataclass
  493. class LoRAConfig:
  494. max_lora_rank: int
  495. max_loras: int
  496. max_cpu_loras: Optional[int] = None
  497. lora_dtype: Optional[torch.dtype] = None
  498. lora_extra_vocab_size: int = 256
  499. # This is a constant.
  500. lora_vocab_padding_size: ClassVar[int] = 256
  501. def __post_init__(self):
  502. # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
  503. possible_max_ranks = (8, 16, 32, 64)
  504. possible_lora_extra_vocab_size = (0, 256, 512)
  505. if self.max_lora_rank not in possible_max_ranks:
  506. raise ValueError(
  507. f"max_lora_rank ({self.max_lora_rank}) must be one of "
  508. f"{possible_max_ranks}.")
  509. if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
  510. raise ValueError(
  511. f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
  512. f"must be one of {possible_lora_extra_vocab_size}.")
  513. if self.max_loras < 1:
  514. raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
  515. if self.max_cpu_loras is None:
  516. self.max_cpu_loras = self.max_loras
  517. elif self.max_cpu_loras < self.max_loras:
  518. raise ValueError(
  519. f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
  520. f"max_num_seqs ({self.max_loras})")
  521. def verify_with_model_config(self, model_config: ModelConfig):
  522. if self.lora_dtype in (None, "auto"):
  523. self.lora_dtype = model_config.dtype
  524. elif isinstance(self.lora_dtype, str):
  525. self.lora_dtype = getattr(torch, self.lora_dtype)
  526. if (model_config.quantization is not None
  527. and model_config.quantization == "gguf"):
  528. raise ValueError("LoRA is not supported with GGUF quantization.")
  529. def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
  530. if scheduler_config.max_num_batched_tokens > 65528:
  531. raise ValueError(
  532. "Due to limitations of the custom LoRA CUDA kernel, "
  533. "max_num_batched_tokens must be <= 65528 when "
  534. "LoRA is enabled.")
  535. _STR_DTYPE_TO_TORCH_DTYPE = {
  536. "half": torch.float16,
  537. "float16": torch.float16,
  538. "float": torch.float32,
  539. "float32": torch.float32,
  540. "bfloat16": torch.bfloat16,
  541. }
  542. _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  543. def _get_and_verify_dtype(
  544. config: PretrainedConfig,
  545. dtype: Union[str, torch.dtype],
  546. ) -> torch.dtype:
  547. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  548. # because config.torch_dtype can be None.
  549. config_dtype = getattr(config, "torch_dtype", None)
  550. if config_dtype is None:
  551. config_dtype = torch.float32
  552. if isinstance(dtype, str):
  553. dtype = dtype.lower()
  554. if dtype == "auto":
  555. if config_dtype == torch.float32:
  556. # Following the common practice, we use float16 for float32
  557. # models.
  558. torch_dtype = torch.float16
  559. else:
  560. torch_dtype = config_dtype
  561. else:
  562. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  563. raise ValueError(f"Unknown dtype: {dtype}")
  564. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  565. elif isinstance(dtype, torch.dtype):
  566. torch_dtype = dtype
  567. else:
  568. raise ValueError(f"Unknown dtype: {dtype}")
  569. if is_hip() and torch_dtype == torch.float32:
  570. rocm_supported_dtypes = [
  571. k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
  572. if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
  573. ]
  574. raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
  575. f"Supported dtypes are {rocm_supported_dtypes}")
  576. # Verify the dtype.
  577. if torch_dtype != config_dtype:
  578. if torch_dtype == torch.float32:
  579. # Upcasting to float32 is allowed.
  580. pass
  581. elif config_dtype == torch.float32:
  582. # Downcasting from float32 to float16 or bfloat16 is allowed.
  583. pass
  584. else:
  585. # Casting between float16 and bfloat16 is allowed with a warning.
  586. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  587. return torch_dtype
  588. def _get_and_verify_max_len(
  589. hf_config: PretrainedConfig,
  590. max_model_len: Optional[int],
  591. ) -> int:
  592. """Get and verify the model's maximum length."""
  593. derived_max_model_len = float("inf")
  594. possible_keys = [
  595. # OPT
  596. "max_position_embeddings",
  597. # GPT-2
  598. "n_positions",
  599. # MPT
  600. "max_seq_len",
  601. # ChatGLM2
  602. "seq_length",
  603. # Others
  604. "max_sequence_length",
  605. "max_seq_length",
  606. "seq_len",
  607. ]
  608. for key in possible_keys:
  609. max_len_key = getattr(hf_config, key, None)
  610. if max_len_key is not None:
  611. derived_max_model_len = min(derived_max_model_len, max_len_key)
  612. if derived_max_model_len == float("inf"):
  613. if max_model_len is not None:
  614. # If max_model_len is specified, we use it.
  615. return max_model_len
  616. default_max_len = 2048
  617. logger.warning(
  618. "The model's config.json does not contain any of the following "
  619. "keys to determine the original maximum length of the model: "
  620. f"{possible_keys}. Assuming the model's maximum length is "
  621. f"{default_max_len}.")
  622. derived_max_model_len = default_max_len
  623. rope_scaling = getattr(hf_config, "rope_scaling", None)
  624. if rope_scaling is not None:
  625. assert "factor" in rope_scaling
  626. scaling_factor = rope_scaling["factor"]
  627. if rope_scaling["type"] == "yarn":
  628. derived_max_model_len = rope_scaling[
  629. "original_max_position_embeddings"]
  630. derived_max_model_len *= scaling_factor
  631. if max_model_len is None:
  632. max_model_len = derived_max_model_len
  633. elif max_model_len > derived_max_model_len:
  634. # hope this works
  635. scaling_factor = max_model_len / derived_max_model_len
  636. hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
  637. logger.warning(
  638. f"User-specified max_model_len {max_model_len} is higher than "
  639. f"the original {derived_max_model_len}. "
  640. "Attempting to use RoPE scaling.")
  641. derived_max_model_len = max_model_len
  642. return int(max_model_len)