llm.py

from typing import List, Optional, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from aphrodite.engine.args_tools import EngineArgs
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import Counter


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq". If None, we assume the model weights are not
            quantized and use `dtype` to determine the data type of the
            weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            **kwargs,
        )
        self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer = tokenizer
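
    # A minimal sketch of swapping in a different tokenizer (added for
    # illustration; the tokenizer name is a placeholder). It assumes the
    # replacement tokenizer is compatible with the loaded model's vocabulary:
    #
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
    #     llm.set_tokenizer(tok)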

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if prompts is not None and prompt_token_ids is not None:
            if len(prompts) != len(prompt_token_ids):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine.
        if prompts is not None:
            num_requests = len(prompts)
        else:
            num_requests = len(prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            if prompt_token_ids is None:
                token_ids = None
            else:
                token_ids = prompt_token_ids[i]
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_engine(use_tqdm)
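
    # Sketch of a batched generate() call (added for illustration; prompts and
    # sampling values are arbitrary). Passing all prompts in one list lets the
    # engine batch them:
    #
    #     params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
    #     outputs = llm.generate(
    #         ["Once upon a time", "The quick brown fox"],
    #         sampling_params=params)
    #     for out in outputs:
    #         print(out.prompt, out.outputs[0].text)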

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests submitted before them.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs
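

# Illustrative end-to-end sketch (added for documentation; not part of the
# original module). It assumes a GPU is available and that the placeholder
# model below can be downloaded from the HuggingFace Hub.
if __name__ == "__main__":
    # Build the engine with a small placeholder model.
    llm = LLM(model="facebook/opt-125m", seed=0)

    # Example sampling parameters; the values are arbitrary.
    params = SamplingParams(temperature=0.0, max_tokens=32)

    # Batched offline generation: results come back in prompt order.
    prompts = ["Hello, my name is", "The meaning of life is"]
    for request_output in llm.generate(prompts, sampling_params=params):
        completion = request_output.outputs[0].text
        print(f"{request_output.prompt!r} -> {completion!r}")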