1
0

config.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from typing import Optional
  2. import torch
  3. from transformers import PretrainedConfig
  4. from aphrodite.common.logger import init_logger
  5. from aphrodite.transformers_utils.config import get_config
  6. from aphrodite.common.utils import get_cpu_memory
# Module-level logger shared by every config class in this file.
logger = init_logger(__name__)
# Number of bytes in one GiB (2**30); multiplied with sizes given in GiB.
_GB = 1 << 30
  9. class ModelConfig:
  10. """Configuration for the model.
  11. Args:
  12. model: Name or path of the huggingface model to use.
  13. tokenizer: Name or path of the huggingface tokenizer to use.
  14. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
  15. available, and "slow" will always use the slow tokenizer.
  16. trust_remote_code: Trust remote code (e.g., from HuggingFace) when
  17. downloading the model and tokenizer.
  18. download_dir: Directory to download and load the weights, default to the
  19. default cache directory of huggingface.
  20. load_format: The format of the model weights to load:
  21. "auto" will try to load the weights in the safetensors format and
  22. fall back to the pytorch bin format if safetensors format is
  23. not available.
  24. "pt" will load the weights in the pytorch bin format.
  25. "safetensors" will load the weights in the safetensors format.
  26. "npcache" will load the weights in pytorch format and store
  27. a numpy cache to speed up the loading.
  28. "dummy" will initialize the weights with random values, which is
  29. mainly for profiling.
  30. dtype: Data type for model weights and activations. The "auto" option
  31. will use FP16 precision for FP32 and FP16 models, and BF16 precision
  32. for BF16 models.
  33. seed: Random seed for reproducibility.
  34. revision: The specific model version to use. It can be a branch name,
  35. a tag name, or a commit ID. If unspecified, will use the default
  36. version.
  37. max_model_len: Maximum length of a sequence (including prompt and output).
  38. If None, will be derived from the model.
  39. """
  40. def __init__(
  41. self,
  42. model: str,
  43. tokenizer: str,
  44. tokenizer_mode: str,
  45. trust_remote_code: bool,
  46. download_dir: Optional[str],
  47. load_format: str,
  48. dtype: str,
  49. seed: int,
  50. revision: Optional[str],
  51. max_model_len: Optional[int] = None,
  52. ) -> None:
  53. self.model = model
  54. self.tokenizer = tokenizer
  55. self.tokenizer_mode = tokenizer_mode
  56. self.trust_remote_code = trust_remote_code
  57. self.download_dir = download_dir
  58. self.load_format = load_format
  59. self.seed = seed
  60. self.revision = revision
  61. self.hf_config = get_config(model, trust_remote_code, revision)
  62. self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
  63. self._verify_load_format()
  64. self._verify_tokenizer_mode()
  65. self.max_model_len = None
  66. if max_model_len is not None:
  67. derived_max_model_len = self.get_max_model_len()
  68. if max_model_len > derived_max_model_len:
  69. logger.warning(
  70. f"User-specified max_model_len ({max_model_len}) is "
  71. f"greater than the model's max length ({derived_max_model_len}). "
  72. f"Make sure the value is correct and within the model's ctxlen.")
  73. self.max_model_len = max_model_len
  74. def _verify_load_format(self) -> None:
  75. load_format = self.load_format.lower()
  76. if load_format not in [
  77. "auto", "pt", "safetensors", "npcache", "dummy"
  78. ]:
  79. raise ValueError(
  80. f"Unknown load format: {self.load_format}. Must be one of "
  81. "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
  82. self.load_format = load_format
  83. def _verify_tokenizer_mode(self) -> None:
  84. tokenizer_mode = self.tokenizer_mode.lower()
  85. if tokenizer_mode not in ["auto", "slow"]:
  86. raise ValueError(
  87. f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
  88. "either 'auto' or 'slow'.")
  89. self.tokenizer_mode = tokenizer_mode
  90. def verify_with_parallel_config(
  91. self,
  92. parallel_config: "ParallelConfig",
  93. ) -> None:
  94. total_num_attention_heads = self.hf_config.num_attention_heads
  95. tensor_parallel_size = parallel_config.tensor_parallel_size
  96. if total_num_attention_heads % tensor_parallel_size != 0:
  97. raise ValueError(
  98. f"Total number of attention heads ({total_num_attention_heads})"
  99. " must be divisible by tensor parallel size "
  100. f"({tensor_parallel_size}).")
  101. total_num_hidden_layers = self.hf_config.num_hidden_layers
  102. pipeline_parallel_size = parallel_config.pipeline_parallel_size
  103. if total_num_hidden_layers % pipeline_parallel_size != 0:
  104. raise ValueError(
  105. f"Total number of hidden layers ({total_num_hidden_layers}) "
  106. "must be divisible by pipeline parallel size "
  107. f"({pipeline_parallel_size}).")
  108. def get_hidden_size(self) -> int:
  109. return self.hf_config.hidden_size
  110. def get_head_size(self) -> int:
  111. # FIXME: This may not be true for all models.
  112. return self.hf_config.hidden_size // self.hf_config.num_attention_heads
  113. def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
  114. new_decoder_arch_falcon = (
  115. self.hf_config.model_type == "falcon"
  116. and getattr(self.hf_config, "new_decoder_architecture", False))
  117. if not new_decoder_arch_falcon and getattr(self.hf_config,
  118. "multi_query", False):
  119. # Multi-query attention, only one KV head.
  120. return 1
  121. if getattr(self.hf_config, "n_head_kv", None) is not None:
  122. return (self.hf_config.n_head_kv //
  123. parallel_config.tensor_parallel_size)
  124. if getattr(self.hf_config, "num_key_value_heads", None) is not None:
  125. return (self.hf_config.num_key_value_heads //
  126. parallel_config.tensor_parallel_size)
  127. total_num_attention_heads = self.hf_config.num_attention_heads
  128. return total_num_attention_heads // parallel_config.tensor_parallel_size
  129. def get_max_model_len(self) -> int:
  130. if self.max_model_len is not None:
  131. return self.max_model_len
  132. max_model_len = float("inf")
  133. possible_keys = [
  134. "max_position_embeddings",
  135. "n_positions",
  136. "max_seq_len",
  137. "max_sequence_length",
  138. "max_seq_length",
  139. "seq_len",
  140. ]
  141. for key in possible_keys:
  142. max_len_key = getattr(self.hf_config, key, None)
  143. if max_len_key is not None:
  144. max_model_len = min(max_model_len, max_len_key)
  145. return max_model_len
  146. def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
  147. total_num_hidden_layers = self.hf_config.num_hidden_layers
  148. return total_num_hidden_layers // parallel_config.pipeline_parallel_size
  149. class CacheConfig:
  150. """Configuration for the KV cache.
  151. Args:
  152. block_size: Size of a cache block in number of tokens.
  153. gpu_memory_utilization: Fraction of GPU memory to use for the
  154. Aphrodite execution.
  155. swap_space: Size of the CPU swap space per GPU (in GiB).
  156. """
  157. def __init__(
  158. self,
  159. block_size: int,
  160. gpu_memory_utilization: float,
  161. swap_space: int,
  162. ) -> None:
  163. self.block_size = block_size
  164. self.gpu_memory_utilization = gpu_memory_utilization
  165. self.swap_space_bytes = swap_space * _GB
  166. self._verify_args()
  167. # Will be set after profiling.
  168. self.num_gpu_blocks = None
  169. self.num_cpu_blocks = None
  170. def _verify_args(self) -> None:
  171. if self.gpu_memory_utilization > 1.0:
  172. raise ValueError(
  173. "GPU memory utilization must be less than 1.0. Got "
  174. f"{self.gpu_memory_utilization}.")
  175. def verify_with_parallel_config(
  176. self,
  177. parallel_config: "ParallelConfig",
  178. ) -> None:
  179. total_cpu_memory = get_cpu_memory()
  180. # FIXME: Here, it is assumed that the GPUs in a tensor parallel
  181. # group are in the same node. However, the GPUs may span multiple nodes.
  182. num_gpus_per_node = parallel_config.tensor_parallel_size
  183. cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
  184. msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
  185. f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
  186. "allocated for the swap space.")
  187. if cpu_memory_usage > 0.7 * total_cpu_memory:
  188. raise ValueError("Too large swap space. " + msg)
  189. elif cpu_memory_usage > 0.4 * total_cpu_memory:
  190. logger.warning("Possibly too large swap space. " + msg)
  191. class ParallelConfig:
  192. """Configuration for the distributed execution.
  193. Args:
  194. pipeline_parallel_size: Number of pipeline parallel groups.
  195. tensor_parallel_size: Number of tensor parallel groups.
  196. worker_use_ray: Whether to use Ray for model workers. Will be set to
  197. True if either pipeline_parallel_size or tensor_parallel_size is
  198. greater than 1.
  199. """
  200. def __init__(
  201. self,
  202. pipeline_parallel_size: int,
  203. tensor_parallel_size: int,
  204. worker_use_ray: bool,
  205. ) -> None:
  206. self.pipeline_parallel_size = pipeline_parallel_size
  207. self.tensor_parallel_size = tensor_parallel_size
  208. self.worker_use_ray = worker_use_ray
  209. self.world_size = pipeline_parallel_size * tensor_parallel_size
  210. if self.world_size > 1:
  211. self.worker_use_ray = True
  212. self._verify_args()
  213. def _verify_args(self) -> None:
  214. if self.pipeline_parallel_size > 1:
  215. raise NotImplementedError(
  216. "Pipeline parallelism is not supported yet.")
  217. class SchedulerConfig:
  218. """Scheduler configuration.
  219. Args:
  220. max_num_batched_tokens: Maximum number of tokens to be processed in
  221. a single iteration.
  222. max_num_seqs: Maximum number of sequences to be processed in a single
  223. iteration.
  224. max_model_len: Maximum length of a sequence (including prompt
  225. and generated text).
  226. """
  227. def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
  228. max_model_len: int) -> None:
  229. self.max_num_batched_tokens = max_num_batched_tokens
  230. self.max_num_seqs = max_num_seqs
  231. self.max_model_len = max_model_len
  232. _STR_DTYPE_TO_TORCH_DTYPE = {
  233. "half": torch.float16,
  234. "float16": torch.float16,
  235. "float": torch.float32,
  236. "float32": torch.float32,
  237. "bfloat16": torch.bfloat16,
  238. }
  239. def _get_and_verify_dtype(
  240. config: PretrainedConfig,
  241. dtype: str,
  242. ) -> torch.dtype:
  243. # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
  244. # because config.torch_dtype can be None.
  245. config_dtype = getattr(config, "torch_dtype", None)
  246. if config_dtype is None:
  247. config_dtype = torch.float32
  248. dtype = dtype.lower()
  249. if dtype == "auto":
  250. if config_dtype == torch.float32:
  251. # Following the common practice, we use float16 for float32 models.
  252. torch_dtype = torch.float16
  253. else:
  254. torch_dtype = config_dtype
  255. else:
  256. if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
  257. raise ValueError(f"Unknown dtype: {dtype}")
  258. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  259. # Verify the dtype.
  260. if torch_dtype != config_dtype:
  261. if torch_dtype == torch.float32:
  262. # Upcasting to float32 is allowed.
  263. pass
  264. elif config_dtype == torch.float32:
  265. # Downcasting from float32 to float16 or bfloat16 is allowed.
  266. pass
  267. else:
  268. # Casting between float16 and bfloat16 is allowed with a warning.
  269. logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
  270. # Check if the GPU supports the dtype.
  271. if torch_dtype == torch.bfloat16:
  272. compute_capability = torch.cuda.get_device_capability()
  273. if compute_capability[0] < 8:
  274. gpu_name = torch.cuda.get_device_name()
  275. raise ValueError(
  276. "Bfloat16 is only supported on GPUs with compute capability "
  277. f"of at least 8.0. Your {gpu_name} GPU has compute capability "
  278. f"{compute_capability[0]}.{compute_capability[1]}.")
  279. return torch_dtype