from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from aphrodite.common.inputs import (PromptInputs, PromptStrictInputs,
                                     TextPrompt, TextTokensPrompt,
                                     TokensPrompt, parse_and_batch_prompt)
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import Counter, deprecate_kwargs
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.lora.request import LoRARequest
from aphrodite.transformers_utils.tokenizer import get_cached_tokenizer


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.

    NOTE: For the comprehensive list of arguments, see
    :class:`~aphrodite.EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. In that case, the inputs are expected to provide
            valid `prompt_token_ids` and `None` for the prompt.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
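
    Examples:
        A minimal offline-generation sketch (illustrative only; the model
        name is a placeholder, and it assumes `LLM` and `SamplingParams`
        are importable from the `aphrodite` package root):

        >>> from aphrodite import LLM, SamplingParams
        >>> llm = LLM(model="facebook/opt-125m")
        >>> params = SamplingParams(temperature=0.8, max_tokens=32)
        >>> outputs = llm.generate(["Hello, my name is"], params)
        >>> print(outputs[0].outputs[0].text)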
  75. """
  76. DEPRECATE_LEGACY: ClassVar[bool] = False
  77. """A flag to toggle whether to deprecate the legacy generate/encode API."""
  78. @classmethod
  79. @contextmanager
  80. def deprecate_legacy_api(cls):
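        """Temporarily enable deprecation warnings for the legacy
        ``prompts``/``prompt_token_ids`` API while inside the ``with``
        block, resetting the flag to ``False`` on exit."""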
        cls.DEPRECATE_LEGACY = True
        yield
        cls.DEPRECATE_LEGACY = False

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
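        """Return the underlying HuggingFace tokenizer used by the engine."""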
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
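        """Replace the tokenizer used by the engine, wrapping it in a cached
        tokenizer if it is not already wrapped."""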
        # Because the cached tokenizer class is created dynamically, we can
        # only detect it by its class name. This may misjudge a user-defined
        # tokenizer whose class name also starts with "Cached".
        if tokenizer.__class__.__name__.startswith("Cached"):
            self.llm_engine.tokenizer.tokenizer = tokenizer
        else:
            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
                tokenizer)

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            inputs: A list of inputs to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.
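
        Example (illustrative sketch; the model name and prompt are
        placeholders):
            >>> llm = LLM(model="facebook/opt-125m")
            >>> params = SamplingParams(temperature=0.0, max_tokens=16)
            >>> outputs = llm.generate(["The capital of France is"], params)
            >>> for output in outputs:
            ...     print(output.outputs[0].text)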
  243. """
  244. if self.llm_engine.model_config.embedding_mode:
  245. raise ValueError(
  246. "LLM.generate() is only supported for generation models "
  247. "(XForCausalLM).")
  248. if prompt_token_ids is not None:
  249. inputs = self._convert_v1_inputs(
  250. prompts=cast(Optional[Union[str, List[str]]], prompts),
  251. prompt_token_ids=prompt_token_ids,
  252. )
  253. else:
  254. inputs = cast(
  255. Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
  256. prompts)
  257. if sampling_params is None:
  258. # Use default sampling params.
  259. sampling_params = SamplingParams()
  260. self._validate_and_add_requests(
  261. inputs=inputs,
  262. params=sampling_params,
  263. lora_request=lora_request,
  264. )
  265. outputs = self._run_engine(use_tqdm=use_tqdm)
  266. return AphroditeEngine.validate_outputs(outputs, RequestOutput)
  267. @overload # LEGACY: single (prompt + optional token ids)
  268. def encode(
  269. self,
  270. prompts: str,
  271. pooling_params: Optional[Union[PoolingParams,
  272. Sequence[PoolingParams]]] = None,
  273. prompt_token_ids: Optional[List[int]] = None,
  274. use_tqdm: bool = True,
  275. lora_request: Optional[LoRARequest] = None,
  276. ) -> List[EmbeddingRequestOutput]:
  277. ...
  278. @overload # LEGACY: multi (prompt + optional token ids)
  279. def encode(
  280. self,
  281. prompts: List[str],
  282. pooling_params: Optional[Union[PoolingParams,
  283. Sequence[PoolingParams]]] = None,
  284. prompt_token_ids: Optional[List[List[int]]] = None,
  285. use_tqdm: bool = True,
  286. lora_request: Optional[LoRARequest] = None,
  287. ) -> List[EmbeddingRequestOutput]:
  288. ...
  289. @overload # LEGACY: single (token ids + optional prompt)
  290. def encode(
  291. self,
  292. prompts: Optional[str] = None,
  293. pooling_params: Optional[Union[PoolingParams,
  294. Sequence[PoolingParams]]] = None,
  295. *,
  296. prompt_token_ids: List[int],
  297. use_tqdm: bool = True,
  298. lora_request: Optional[LoRARequest] = None,
  299. ) -> List[EmbeddingRequestOutput]:
  300. ...
  301. @overload # LEGACY: multi (token ids + optional prompt)
  302. def encode(
  303. self,
  304. prompts: Optional[List[str]] = None,
  305. pooling_params: Optional[Union[PoolingParams,
  306. Sequence[PoolingParams]]] = None,
  307. *,
  308. prompt_token_ids: List[List[int]],
  309. use_tqdm: bool = True,
  310. lora_request: Optional[LoRARequest] = None,
  311. ) -> List[EmbeddingRequestOutput]:
  312. ...
  313. @overload # LEGACY: single or multi token ids [pos-only]
  314. def encode(
  315. self,
  316. prompts: None,
  317. pooling_params: None,
  318. prompt_token_ids: Union[List[int], List[List[int]]],
  319. use_tqdm: bool = True,
  320. lora_request: Optional[LoRARequest] = None,
  321. ) -> List[EmbeddingRequestOutput]:
  322. ...
  323. @overload
  324. def encode(
  325. self,
  326. inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
  327. /, # We may enable `inputs` keyword after removing the old API
  328. *,
  329. pooling_params: Optional[Union[PoolingParams,
  330. Sequence[PoolingParams]]] = None,
  331. use_tqdm: bool = True,
  332. lora_request: Optional[LoRARequest] = None,
  333. ) -> List[EmbeddingRequestOutput]:
  334. ...
  335. @deprecate_kwargs("prompts",
  336. "prompt_token_ids",
  337. is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
  338. additional_message="Please use the 'inputs' parameter "
  339. "instead.")
  340. def encode(
  341. self,
  342. prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
  343. Optional[Union[str, List[str]]]] = None,
  344. pooling_params: Optional[Union[PoolingParams,
  345. Sequence[PoolingParams]]] = None,
  346. prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
  347. use_tqdm: bool = True,
  348. lora_request: Optional[LoRARequest] = None,
  349. ) -> List[EmbeddingRequestOutput]:
  350. """Generates the completions for the input prompts.
  351. NOTE: This class automatically batches the given prompts, considering
  352. the memory constraint. For the best performance, put all of your prompts
  353. into a single list and pass it to this method.
  354. Args:
  355. inputs: The inputs to the LLM. You may pass a sequence of inputs for
  356. batch inference. See
  357. :class:`~aphrodite.inputs.PromptStrictInputs`
  358. for more details about the format of each input.
  359. pooling_params: The pooling parameters for pooling. If None, we
  360. use the default pooling parameters.
  361. use_tqdm: Whether to use tqdm to display the progress bar.
  362. lora_request: LoRA request to use for generation, if any.
  363. Returns:
  364. A list of `EmbeddingRequestOutput` objects containing the
  365. generated embeddings in the same order as the input prompts.
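
        Example (illustrative sketch; the model name is a placeholder for an
        embedding model, and it assumes each output exposes an
        ``outputs.embedding`` list):
            >>> llm = LLM(model="intfloat/e5-mistral-7b-instruct")
            >>> outputs = llm.encode(["What is the capital of France?"])
            >>> print(len(outputs[0].outputs.embedding))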
  366. """
  367. if not self.llm_engine.model_config.embedding_mode:
  368. raise ValueError(
  369. "LLM.encode() is only supported for embedding models (XModel)."
  370. )
  371. if prompt_token_ids is not None:
  372. inputs = self._convert_v1_inputs(
  373. prompts=cast(Optional[Union[str, List[str]]], prompts),
  374. prompt_token_ids=prompt_token_ids,
  375. )
  376. else:
  377. inputs = cast(
  378. Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
  379. prompts)
  380. if pooling_params is None:
  381. # Use default pooling params.
  382. pooling_params = PoolingParams()
  383. self._validate_and_add_requests(
  384. inputs=inputs,
  385. params=pooling_params,
  386. lora_request=lora_request,
  387. )
  388. outputs = self._run_engine(use_tqdm=use_tqdm)
  389. return AphroditeEngine.validate_outputs(outputs,
  390. EmbeddingRequestOutput)
  391. # LEGACY
  392. def _convert_v1_inputs(
  393. self,
  394. prompts: Optional[Union[str, List[str]]],
  395. prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
  396. ):
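        """Convert the legacy ``prompts``/``prompt_token_ids`` arguments into
        a list of :class:`PromptInputs` for the new API."""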
        # skip_tokenizer_init is now checked in engine
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")
            num_requests = len(prompt_token_ids)
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        inputs: List[PromptInputs] = []
        for i in range(num_requests):
            if prompts is not None:
                if prompt_token_ids is not None:
                    item = TextTokensPrompt(
                        prompt=prompts[i],
                        prompt_token_ids=prompt_token_ids[i])
                else:
                    item = TextPrompt(prompt=prompts[i])
            else:
                if prompt_token_ids is not None:
                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
                else:
                    raise AssertionError
            inputs.append(item)

        return inputs

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
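        """Normalize a single prompt to a one-element list, check that a list
        of params matches the number of prompts, and add every request to the
        engine."""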
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]
        num_requests = len(inputs)

        if isinstance(params, list) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request,
            )

    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
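        """Assign the next request ID and submit a single request to the
        engine."""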
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    inputs,
                                    params,
                                    lora_request=lora_request)

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
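        """Step the engine until all submitted requests are finished, then
        return the finished outputs sorted by request ID. Optionally shows a
        tqdm progress bar with the running generation speed."""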
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=f"Generation Speed: {0:.2f} toks/s",
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            total_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            spd = total_toks / pbar.format_dict["elapsed"]
                            pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests that were submitted
        # before them.
        return sorted(outputs, key=lambda x: int(x.request_id))