config.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
import os
from dataclasses import dataclass
from typing import ClassVar, Optional, Union

from packaging.version import Version
import torch
from transformers import PretrainedConfig

from aphrodite.common.logger import init_logger
from aphrodite.transformers_utils.config import get_config
from aphrodite.common.utils import (get_cpu_memory, is_hip,
                                    get_nvcc_cuda_version)
  11. logger = init_logger(__name__)
  12. _GB = 1 << 30
  13. class ModelConfig:
  14. """Configuration for the model.
  15. Args:
  16. model: Name or path of the huggingface model to use.
  17. tokenizer: Name or path of the huggingface tokenizer to use.
  18. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  19. available, and "slow" will always use the slow tokenizer.
  20. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  21. downloading the model and tokenizer.
  22. download_dir: Directory to download and load the weights, default to the
  23. default cache directory of huggingface.
  24. load_format: The format of the model weights to load:
  25. "auto" will try to load the weights in the safetensors format and
  26. fall back to the pytorch bin format if safetensors format is
  27. not available.
  28. "pt" will load the weights in the pytorch bin format.
  29. "safetensors" will load the weights in the safetensors format.
  30. "npcache" will load the weights in pytorch format and store
  31. a numpy cache to speed up the loading.
  32. "dummy" will initialize the weights with random values, which is
  33. mainly for profiling.
  34. dtype: Data type for model weights and activations. The "auto" option
  35. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  36. for BF16 models.
  37. seed: Random seed for reproducibility.
  38. revision: The specific model version to use. It can be a branch name,
  39. a tag name, or a commit id. If unspecified, will use the default
  40. version.
  41. tokenizer_revision: The specific tokenizer version to use. It can be a
  42. branch name, a tag name, or a commit id. If unspecified, will use
  43. the default version.
  44. max_model_len: Maximum length of a sequence (including prompt and
  45. output). If None, will be derived from the model.
  46. quantization: Quantization method that was used to quantize the model
  47. weights. If None, we assume the model weights are not quantized.
  48. enforce_eager: Whether to enforce eager execution. If True, we will
  49. disable CUDA graph and always execute the model in eager mode.
  50. If False, we will use CUDA graph and eager execution in hybrid.
  51. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
  52. When a sequence has context length larger than this, we fall back
  53. to eager mode.
  54. """
  55. def __init__(
  56. self,
  57. model: str,
  58. tokenizer: str,
  59. tokenizer_mode: str,
  60. trust_remote_code: bool,
  61. download_dir: Optional[str],
  62. load_format: str,
  63. dtype: Union[str, torch.dtype],
  64. seed: int,
  65. revision: Optional[str] = None,
  66. tokenizer_revision: Optional[str] = None,
  67. max_model_len: Optional[int] = None,
  68. quantization: Optional[str] = None,
  69. enforce_eager: bool = False,
  70. max_context_len_to_capture: Optional[int] = None,
  71. ) -> None:
  72. self.model = model
  73. self.tokenizer = tokenizer
  74. self.tokenizer_mode = tokenizer_mode
  75. self.trust_remote_code = trust_remote_code
  76. self.download_dir = download_dir
  77. self.load_format = load_format
  78. self.seed = seed
  79. self.revision = revision
  80. self.tokenizer_revision = tokenizer_revision
  81. self.quantization = quantization
  82. self.enforce_eager = enforce_eager
  83. self.max_context_len_to_capture = max_context_len_to_capture
  84. if os.environ.get("APHRODITE_USE_MODELSCOPE",
  85. "False").lower() == "true":
  86. # download model from ModelScope hub,
  87. # lazy import so that modelscope is not required for normal use.
  88. from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
  89. model_path = snapshot_download(model_id=model,
  90. cache_dir=download_dir,
  91. revision=revision)
  92. self.model = model_path
  93. self.download_dir = model_path
  94. self.tokenizer = model_path
  95. self.hf_config = get_config(self.model, trust_remote_code, revision)
  96. self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
  97. self.max_model_len = _get_and_verify_max_len(self.hf_config,
  98. max_model_len)
  99. self._verify_load_format()
  100. self._verify_tokenizer_mode()
  101. self._verify_quantization()
  102. self._verify_cuda_graph()
  103. def _verify_load_format(self) -> None:
  104. load_format = self.load_format.lower()
  105. supported_load_format = [
  106. "auto", "pt", "safetensors", "npcache", "dummy"
  107. ]
  108. rocm_not_supported_load_format = []
  109. if load_format not in supported_load_format:
  110. raise ValueError(
  111. f"Unknown load format: {self.load_format}. Must be one of "
  112. "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
  113. if is_hip() and load_format in rocm_not_supported_load_format:
  114. rocm_supported_load_format = [
  115. f for f in supported_load_format
  116. if (f not in rocm_not_supported_load_format)
  117. ]
  118. raise ValueError(
  119. f"load format \'{load_format}\' is not supported in ROCm. "
  120. f"Supported load format are "
  121. f"{rocm_supported_load_format}")
  122. # TODO: Remove this check once HF updates the pt weights of Mixtral.
  123. architectures = getattr(self.hf_config, "architectures", [])
  124. if "MixtralForCausalLM" in architectures and load_format == "pt":
  125. raise ValueError(
  126. "Currently, the 'pt' format is not supported for Mixtral. "
  127. "Please use the 'safetensors' format instead. ")
  128. self.load_format = load_format
  129. def _verify_tokenizer_mode(self) -> None:
  130. tokenizer_mode = self.tokenizer_mode.lower()
  131. if tokenizer_mode not in ["auto", "slow"]:
  132. raise ValueError(
  133. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  134. "either 'auto' or 'slow'.")
  135. self.tokenizer_mode = tokenizer_mode
  136. def _verify_quantization(self) -> None:
  137. supported_quantization = ["awq", "gguf", "gptq", "quip", "squeezellm"]
  138. rocm_not_supported_quantization = ["awq", "quip"]
  139. if self.quantization is not None:
  140. self.quantization = self.quantization.lower()
  141. # Parse quantization method from the HF model config, if available.
  142. hf_quant_config = getattr(self.hf_config, "quantization_config", None)
  143. if hf_quant_config is not None:
  144. hf_quant_method = str(hf_quant_config["quant_method"]).lower()
  145. if self.quantization is None:
  146. self.quantization = hf_quant_method
  147. elif self.quantization != hf_quant_method:
  148. raise ValueError(
  149. "Quantization method specified in the model config "
  150. f"({hf_quant_method}) does not match the quantization "
  151. f"method specified in the `quantization` argument "
  152. f"({self.quantization}).")
  153. if self.quantization is not None:
  154. if self.quantization not in supported_quantization:
  155. raise ValueError(
  156. f"Unknown quantization method: {self.quantization}. Must "
  157. f"be one of {supported_quantization}.")
  158. if is_hip(
  159. ) and self.quantization in rocm_not_supported_quantization:
  160. raise ValueError(
  161. f"{self.quantization} quantization is currently not "
  162. "supported in ROCm.")
  163. logger.warning(f"{self.quantization} quantization is not fully "
  164. "optimized yet. The speed can be slower than "
  165. "non-quantized models.")
  166. def _verify_cuda_graph(self) -> None:
  167. if self.max_context_len_to_capture is None:
  168. self.max_context_len_to_capture = self.max_model_len
  169. self.max_context_len_to_capture = min(self.max_context_len_to_capture,
  170. self.max_model_len)
  171. def verify_with_parallel_config(
  172. self,
  173. parallel_config: "ParallelConfig",
  174. ) -> None:
  175. total_num_attention_heads = self.hf_config.num_attention_heads
  176. tensor_parallel_size = parallel_config.tensor_parallel_size
  177. if total_num_attention_heads % tensor_parallel_size != 0:
  178. raise ValueError(
  179. f"Total number of attention heads ({total_num_attention_heads})"
  180. " must be divisible by tensor parallel size "
  181. f"({tensor_parallel_size}).")
  182. total_num_hidden_layers = self.hf_config.num_hidden_layers
  183. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  184. if total_num_hidden_layers % pipeline_parallel_size != 0:
  185. raise ValueError(
  186. f"Total number of hidden layers ({total_num_hidden_layers}) "
  187. "must be divisible by pipeline parallel size "
  188. f"({pipeline_parallel_size}).")
  189. def get_sliding_window(self) -> Optional[int]:
  190. return getattr(self.hf_config, "sliding_window", None)
  191. def get_vocab_size(self) -> int:
  192. return self.hf_config.vocab_size
  193. def get_hidden_size(self) -> int:
  194. return self.hf_config.hidden_size
  195. def get_head_size(self) -> int:
  196. if hasattr(self.hf_config, "head_dim"):
  197. return self.hf_config.head_dim
  198. # FIXME: This may not be true for all models.
  199. return self.hf_config.hidden_size // self.hf_config.num_attention_heads
  200. def get_total_num_kv_heads(self) -> int:
  201. """Returns the total number of KV heads."""
  202. # For GPTBigCode & Falcon:
  203. # NOTE: for falcon, when new_decoder_architecture is True, the
  204. # multi_query flag is ignored and we use n_head_kv for the number of
  205. # KV heads.
  206. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
  207. new_decoder_arch_falcon = (
  208. self.hf_config.model_type in falcon_model_types
  209. and getattr(self.hf_config, "new_decoder_architecture", False))
  210. if not new_decoder_arch_falcon and getattr(self.hf_config,
  211. "multi_query", False):
  212. # Multi-query attention, only one KV head.
  213. # Currently, tensor parallelism is not supported in this case.
  214. return 1
  215. attributes = [
  216. # For Falcon:
  217. "n_head_kv",
  218. "num_kv_heads",
  219. # For LLaMA-2:
  220. "num_key_value_heads",
  221. # For ChatGLM:
  222. "multi_query_group_num",
  223. ]
  224. for attr in attributes:
  225. num_kv_heads = getattr(self.hf_config, attr, None)
  226. if num_kv_heads is not None:
  227. return num_kv_heads
  228. # For non-grouped-query attention models, the number of KV heads is
  229. # equal to the number of attention heads.
  230. return self.hf_config.num_attention_heads
  231. def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
  232. """Returns the number of KV heads per GPU."""
  233. total_num_kv_heads = self.get_total_num_kv_heads()
  234. # If tensor parallelism is used, we divide the number of KV heads by
  235. # the tensor parallel size. We will replicate the KV heads in the
  236. # case where the number of KV heads is smaller than the tensor
  237. # parallel size so each GPU has at least one KV head.
  238. return max(1,
  239. total_num_kv_heads // parallel_config.tensor_parallel_size)
  240. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  241. total_num_hidden_layers = self.hf_config.num_hidden_layers
  242. return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  243. class CacheConfig:
  244. """Configuration for the KV cache.
  245. Args:
  246. block_size: Size of a cache block in number of tokens.
  247. gpu_memory_utilization: Fraction of GPU memory to use for the
  248. Aphrodite execution.
  249. swap_space: Size of the CPU swap space per GPU (in GiB).
  250. cache_dtype: Data Type for KV cache storage.
  251. """
  252. def __init__(
  253. self,
  254. block_size: int,
  255. gpu_memory_utilization: float,
  256. swap_space: int,
  257. cache_dtype: str,
  258. sliding_window: Optional[int] = None,
  259. ) -> None:
  260. self.block_size = block_size
  261. self.gpu_memory_utilization = gpu_memory_utilization
  262. self.swap_space_bytes = swap_space * _GB
  263. self.cache_dtype = cache_dtype
  264. self.sliding_window = sliding_window
  265. self._verify_args()
  266. self._verify_cache_dtype()
  267. # Will be set after profiling.
  268. self.num_gpu_blocks = None
  269. self.num_cpu_blocks = None
  270. def _verify_args(self) -> None:
  271. if self.gpu_memory_utilization > 1.0:
  272. raise ValueError(
  273. "GPU memory utilization must be less than 1.0. Got "
  274. f"{self.gpu_memory_utilization}.")
  275. def _verify_cache_dtype(self) -> None:
  276. if self.cache_dtype == "auto":
  277. pass
  278. elif self.cache_dtype == "fp8_e5m2":
  279. nvcc_cuda_version = get_nvcc_cuda_version()
  280. if nvcc_cuda_version < Version("11.8"):
  281. raise ValueError(
  282. "FP8 is not supported when cuda version is lower than "
  283. "11.8. If you think you have the correct cuda version, "
  284. "please make sure you've properly exported CUDA_HOME.")
  285. device_name = torch.cuda.get_device_name()
  286. if "AMD" in device_name:
  287. raise NotImplementedError(
  288. "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
  289. logger.info(
  290. "Using fp8_e5m2 data type to store kv cache. It reduces "
  291. "the GPU memory footprint and boosts the performance. "
  292. "But it may cause slight accuracy drop. "
  293. "Currently we only support fp8 without scaling factors and "
  294. "make e5m2 as a default format.")
  295. else:
  296. raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
  297. def verify_with_parallel_config(
  298. self,
  299. parallel_config: "ParallelConfig",
  300. ) -> None:
  301. total_cpu_memory = get_cpu_memory()
  302. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  303. # group are in the same node. However, the GPUs may span multiple nodes.
  304. num_gpus_per_node = parallel_config.tensor_parallel_size
  305. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  306. msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
  307. f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
  308. "allocated for the swap space.")
  309. if cpu_memory_usage > 0.7 * total_cpu_memory:
  310. raise ValueError("Too large swap space. " + msg)
  311. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  312. logger.warning("Possibly too large swap space. " + msg)
  313. class ParallelConfig:
  314. """Configuration for the distributed execution.
  315. Args:
  316. pipeline_parallel_size: Number of pipeline parallel groups.
  317. tensor_parallel_size: Number of tensor parallel groups.
  318. worker_use_ray: Whether to use Ray for model workers. Will be set to
  319. True if either pipeline_parallel_size or tensor_parallel_size is
  320. greater than 1.
  321. disable_custom_all_reduce: Disable the custom all-reduce kernel and
  322. fall back to NCCL.
  323. """
  324. def __init__(
  325. self,
  326. pipeline_parallel_size: int,
  327. tensor_parallel_size: int,
  328. worker_use_ray: bool,
  329. max_parallel_loading_workers: Optional[int] = None,
  330. disable_custom_all_reduce: bool = False,
  331. ) -> None:
  332. self.pipeline_parallel_size = pipeline_parallel_size
  333. self.tensor_parallel_size = tensor_parallel_size
  334. self.worker_use_ray = worker_use_ray
  335. self.max_parallel_loading_workers = max_parallel_loading_workers
  336. self.disable_custom_all_reduce = disable_custom_all_reduce
  337. self.world_size = pipeline_parallel_size * tensor_parallel_size
  338. if self.world_size > 1:
  339. self.worker_use_ray = True
  340. self._verify_args()
  341. def _verify_args(self) -> None:
  342. if self.pipeline_parallel_size > 1:
  343. raise NotImplementedError(
  344. "Pipeline parallelism is not supported yet.")
  345. if is_hip():
  346. self.disable_custom_all_reduce = True
  347. logger.info(
  348. "Disabled the custom all-reduce kernel because it is not "
  349. "supported on AMD GPUs.")
  350. elif self.pipeline_parallel_size > 1:
  351. self.disable_custom_all_reduce = True
  352. logger.info(
  353. "Disabled the custom all-reduce kernel because it is not "
  354. "supported with pipeline parallelism.")
  355. class SchedulerConfig:
  356. """Scheduler configuration.
  357. Args:
  358. max_num_batched_tokens: Maximum number of tokens to be processed in
  359. a single iteration.
  360. max_num_seqs: Maximum number of sequences to be processed in a single
  361. iteration.
  362. max_model_len: Maximum length of a sequence (including prompt
  363. and generated text).
  364. max_paddings: Maximum number of paddings to be added to a batch.
  365. """
  366. def __init__(
  367. self,
  368. max_num_batched_tokens: Optional[int],
  369. max_num_seqs: int,
  370. max_model_len: int,
  371. max_paddings: int,
  372. ) -> None:
  373. if max_num_batched_tokens is not None:
  374. self.max_num_batched_tokens = max_num_batched_tokens
  375. else:
  376. # If max_model_len is too short, use 2048 as the default value for
  377. # higher throughput.
  378. self.max_num_batched_tokens = max(max_model_len, 2048)
  379. self.max_num_seqs = max_num_seqs
  380. self.max_model_len = max_model_len
  381. self.max_paddings = max_paddings
  382. self._verify_args()
  383. def _verify_args(self) -> None:
  384. if self.max_num_batched_tokens < self.max_model_len:
  385. raise ValueError(
  386. f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
  387. f"smaller than max_model_len ({self.max_model_len}). "
  388. "This effectively limits the maximum sequence length to "
  389. "max_num_batched_tokens and makes Aphrodite reject longer "
  390. "sequences. Please increase max_num_batched_tokens or "
  391. "decrease max_model_len.")
  392. if self.max_num_batched_tokens < self.max_num_seqs:
  393. raise ValueError(
  394. f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
  395. "be greater than or equal to max_num_seqs "
  396. f"({self.max_num_seqs}).")
  397. @dataclass
  398. class LoRAConfig:
  399. max_lora_rank: int
  400. max_loras: int
  401. max_cpu_loras: Optional[int] = None
  402. lora_dtype: Optional[torch.dtype] = None
  403. lora_extra_vocab_size: int = 256
  404. # This is a constant.
  405. lora_vocab_padding_size: ClassVar[int] = 256
  406. def __post_init__(self):
  407. # Keep this in sync with kernels/punica/bgmv/bgmv_config.h
  408. possible_max_ranks = (8, 16, 32, 64)
  409. possible_lora_extra_vocab_size = (0, 256, 512)
  410. if self.max_lora_rank not in possible_max_ranks:
  411. raise ValueError(
  412. f"max_lora_rank ({self.max_lora_rank}) must be one of "
  413. f"{possible_max_ranks}.")
  414. if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
  415. raise ValueError(
  416. f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
  417. f"must be one of {possible_lora_extra_vocab_size}.")
  418. if self.max_loras < 1:
  419. raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
  420. if self.max_cpu_loras is None:
  421. self.max_cpu_loras = self.max_loras
  422. elif self.max_cpu_loras < self.max_loras:
  423. raise ValueError(
  424. f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
  425. f"max_num_seqs ({self.max_loras})")
  426. def verify_with_model_config(self, model_config: ModelConfig):
  427. if self.lora_dtype in (None, "auto"):
  428. self.lora_dtype = model_config.dtype
  429. elif isinstance(self.lora_dtype, str):
  430. self.lora_dtype = getattr(torch, self.lora_dtype)
  431. if model_config.quantization is not None:
  432. raise ValueError(
  433. "LoRA is not supported with quantized models yet.")
  434. def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
  435. if scheduler_config.max_num_batched_tokens > 65528:
  436. raise ValueError(
  437. "Due to limitations of the custom LoRA CUDA kernel, "
  438. "max_num_batched_tokens must be <= 65528 when "
  439. "LoRA is enabled.")
  440. _STR_DTYPE_TO_TORCH_DTYPE = {
  441. "half": torch.float16,
  442. "float16": torch.float16,
  443. "float": torch.float32,
  444. "float32": torch.float32,
  445. "bfloat16": torch.bfloat16,
  446. }
  447. _ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
  448. def _get_and_verify_dtype(
  449. config: PretrainedConfig,
  450. dtype: Union[str, torch.dtype],
  451. ) -> torch.dtype:
  452. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  453. # because config.torch_dtype can be None.
  454. config_dtype = getattr(config, "torch_dtype", None)
  455. if config_dtype is None:
  456. config_dtype = torch.float32
  457. if isinstance(dtype, str):
  458. dtype = dtype.lower()
  459. if dtype == "auto":
  460. if config_dtype == torch.float32:
  461. # Following the common practice, we use float16 for float32
  462. # models.
  463. torch_dtype = torch.float16
  464. else:
  465. torch_dtype = config_dtype
  466. else:
  467. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  468. raise ValueError(f"Unknown dtype: {dtype}")
  469. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  470. elif isinstance(dtype, torch.dtype):
  471. torch_dtype = dtype
  472. else:
  473. raise ValueError(f"Unknown dtype: {dtype}")
  474. if is_hip() and torch_dtype == torch.float32:
  475. rocm_supported_dtypes = [
  476. k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
  477. if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
  478. ]
  479. raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
  480. f"Supported dtypes are {rocm_supported_dtypes}")
  481. # Verify the dtype.
  482. if torch_dtype != config_dtype:
  483. if torch_dtype == torch.float32:
  484. # Upcasting to float32 is allowed.
  485. pass
  486. elif config_dtype == torch.float32:
  487. # Downcasting from float32 to float16 or bfloat16 is allowed.
  488. pass
  489. else:
  490. # Casting between float16 and bfloat16 is allowed with a warning.
  491. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  492. return torch_dtype
  493. def _get_and_verify_max_len(
  494. hf_config: PretrainedConfig,
  495. max_model_len: Optional[int],
  496. ) -> int:
  497. """Get and verify the model's maximum length."""
  498. derived_max_model_len = float("inf")
  499. possible_keys = [
  500. # OPT
  501. "max_position_embeddings",
  502. # GPT-2
  503. "n_positions",
  504. # MPT
  505. "max_seq_len",
  506. # ChatGLM2
  507. "seq_length",
  508. # Others
  509. "max_sequence_length",
  510. "max_seq_length",
  511. "seq_len",
  512. ]
  513. for key in possible_keys:
  514. max_len_key = getattr(hf_config, key, None)
  515. if max_len_key is not None:
  516. derived_max_model_len = min(derived_max_model_len, max_len_key)
  517. if derived_max_model_len == float("inf"):
  518. if max_model_len is not None:
  519. # If max_model_len is specified, we use it.
  520. return max_model_len
  521. default_max_len = 2048
  522. logger.warning(
  523. "The model's config.json does not contain any of the following "
  524. "keys to determine the original maximum length of the model: "
  525. f"{possible_keys}. Assuming the model's maximum length is "
  526. f"{default_max_len}.")
  527. derived_max_model_len = default_max_len
  528. rope_scaling = getattr(hf_config, "rope_scaling", None)
  529. if rope_scaling is not None:
  530. assert "factor" in rope_scaling
  531. scaling_factor = rope_scaling["factor"]
  532. if rope_scaling["type"] == "yarn":
  533. derived_max_model_len = rope_scaling[
  534. "original_max_position_embeddings"]
  535. derived_max_model_len *= scaling_factor
  536. if max_model_len is None:
  537. max_model_len = derived_max_model_len
  538. elif max_model_len > derived_max_model_len:
  539. # hope this works
  540. scaling_factor = max_model_len / derived_max_model_len
  541. hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
  542. logger.warning(
  543. f"User-specified max_model_len {max_model_len} is higher than "
  544. f"the original {derived_max_model_len}. "
  545. "Attempting to use RoPE scaling.")
  546. derived_max_model_len = max_model_len
  547. return int(max_model_len)