aphrodite_engine.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. import copy
  2. from collections import defaultdict
  3. import os
  4. import time
  5. from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
  6. Union)
  7. from loguru import logger
  8. import aphrodite
  9. from aphrodite.lora.request import LoRARequest
  10. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  11. SchedulerConfig, LoRAConfig, DeviceConfig)
  12. from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
  13. from aphrodite.engine.args_tools import EngineArgs
  14. from aphrodite.engine.metrics import StatLogger, Stats
  15. from aphrodite.engine.ray_tools import (RayWorkerAphrodite, initialize_cluster,
  16. ray)
  17. from aphrodite.common.logger import setup_logger
  18. from aphrodite.common.outputs import RequestOutput
  19. from aphrodite.common.sampling_params import SamplingParams
  20. from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
  21. SequenceGroupOutput, SequenceOutput,
  22. SequenceStatus, Logprob)
  23. from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
  24. TokenizerGroup)
  25. from aphrodite.common.utils import (Counter, set_cuda_visible_devices, get_ip,
  26. get_open_port)
# The placement-group scheduling strategy is only imported when `ray` is
# truthy (it is used by `AphroditeEngine._init_workers_ray`).
if ray:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Needed for type annotations only; never imported at runtime.
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

# Interval, in seconds, handed to `StatLogger` as its `local_interval`.
_LOCAL_LOGGING_INTERVAL_SEC = 5
  32. class AphroditeEngine:
  33. """An LLM engine that receives requests and generates texts.
  34. This is the main class for the Aphrodite engine. It receives requests
  35. from clients and generates texts from the LLM. It includes a tokenizer, a
  36. language model (possibly distributed across multiple GPUs), and GPU memory
  37. space allocated for intermediate states (aka KV cache). This class utilizes
  38. iteration-level scheduling and efficient memory management to maximize the
  39. serving throughput.
  40. The `LLM` class wraps this class for offline batched inference and the
  41. `AsyncAphrodite` class wraps this class for online serving.
  42. NOTE: The config arguments are derived from the `EngineArgs` class. For the
  43. comprehensive list of arguments, see `EngineArgs`.
  44. Args:
  45. model_config: The configuration related to the LLM model.
  46. cache_config: The configuration related to the KV cache memory
  47. management.
  48. parallel_config: The configuration related to distributed execution.
  49. scheduler_config: The configuration related to the request scheduler.
  50. device_config: The configuration related to the device.
  51. lora_config: The configuration related to LoRA.
  52. placement_group: Ray placement group for distributed execution.
  53. Required for distributed execution.
  54. log_stats: Whether to log statistics.
  55. """
  56. def __init__(
  57. self,
  58. model_config: ModelConfig,
  59. cache_config: CacheConfig,
  60. parallel_config: ParallelConfig,
  61. scheduler_config: SchedulerConfig,
  62. device_config: DeviceConfig,
  63. lora_config: Optional[LoRAConfig],
  64. placement_group: Optional["PlacementGroup"],
  65. log_stats: bool,
  66. ) -> None:
  67. logger.info(
  68. f"Initializing the Aphrodite Engine (v{aphrodite.__version__}) "
  69. "with the following config:\n"
  70. f"Model = {model_config.model!r}\n"
  71. f"DataType = {model_config.dtype}\n"
  72. f"Model Load Format = {model_config.load_format}\n"
  73. f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
  74. f"Disable Custom All-Reduce = "
  75. f"{parallel_config.disable_custom_all_reduce}\n"
  76. f"Quantization Format = {model_config.quantization}\n"
  77. f"Context Length = {model_config.max_model_len}\n"
  78. f"Enforce Eager Mode = {model_config.enforce_eager}\n"
  79. f"KV Cache Data Type = {cache_config.cache_dtype}\n"
  80. f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
  81. f"Device = {device_config.device}")
  82. # TODO: Print more configs in debug mode.
  83. self.model_config = model_config
  84. self.cache_config = cache_config
  85. self.lora_config = lora_config
  86. self.parallel_config = parallel_config
  87. self.scheduler_config = scheduler_config
  88. self.device_config = device_config
  89. self.log_stats = log_stats
  90. self._verify_args()
  91. self._init_tokenizer()
  92. self.seq_counter = Counter()
  93. # Create the parallel GPU workers.
  94. if self.parallel_config.worker_use_ray:
  95. # Disable Ray usage stats collection.
  96. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  97. if ray_usage != "1":
  98. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  99. self._init_workers_ray(placement_group)
  100. else:
  101. self._init_workers()
  102. # Profile the memory usage and initialize the cache.
  103. self._init_cache()
  104. # Create the scheduler.
  105. self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
  106. # Metric Logging.
  107. if self.log_stats:
  108. self.stat_logger = StatLogger(
  109. local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
  110. labels=dict(model_name=model_config.model))
  111. def get_tokenizer_for_seq(self, sequence: Sequence):
  112. return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
  113. def _init_workers(self):
  114. # Lazy import the Worker to avoid importing torch.cuda/xformers
  115. # before CUDA_VISIBLE_DEVICES is set in the Worker
  116. # pylint: disable=import-outside-toplevel
  117. from aphrodite.task_handler.worker import Worker
  118. assert self.parallel_config.world_size == 1, (
  119. "Ray is required if parallel_config.world_size > 1.")
  120. self.workers: List[Worker] = []
  121. distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
  122. self.driver_worker = Worker(
  123. self.model_config,
  124. self.parallel_config,
  125. self.scheduler_config,
  126. self.device_config,
  127. local_rank=0,
  128. rank=0,
  129. distributed_init_method=distributed_init_method,
  130. lora_config=self.lora_config,
  131. kv_cache_dtype=self.cache_config.cache_dtype,
  132. kv_quant_params_path=(self.cache_config.cache_quant_params_path),
  133. is_driver_worker=True,
  134. )
  135. self._run_workers("init_model")
  136. self._run_workers("load_model")
  137. def _init_tokenizer(self, **tokenizer_init_kwargs):
  138. init_kwargs = dict(
  139. enable_lora=bool(self.lora_config),
  140. max_num_seqs=self.scheduler_config.max_num_seqs,
  141. max_input_length=None,
  142. tokenizer_mode=self.model_config.tokenizer_mode,
  143. trust_remote_code=self.model_config.trust_remote_code,
  144. revision=self.model_config.tokenizer_revision)
  145. init_kwargs.update(tokenizer_init_kwargs)
  146. self.tokenizer: TokenizerGroup = TokenizerGroup(
  147. self.model_config.tokenizer, **init_kwargs)
    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        """Create the Ray workers and the local driver worker, then load
        the model on all of them.

        Args:
            placement_group: Ray placement group whose GPU bundles each get
                one `RayWorkerAphrodite` actor.
            **ray_remote_kwargs: Extra options forwarded to `ray.remote`.

        Raises:
            ValueError: If no GPU bundle landed on the driver's node.
        """
        # With a single tensor-parallel rank, the Ray GPU request is the
        # configured memory-utilization fraction; otherwise a whole GPU.
        if self.parallel_config.tensor_parallel_size == 1:
            num_gpus = self.cache_config.gpu_memory_utilization
        else:
            num_gpus = 1

        self.driver_dummy_worker: RayWorkerAphrodite = None
        self.workers: List[RayWorkerAphrodite] = []

        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            # Skip bundles that do not reserve a GPU.
            if not bundle.get("GPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
            else:
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any GPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "GPU node.")

        driver_node_id, driver_gpu_ids = ray.get(
            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
        worker_node_and_gpu_ids = ray.get(
            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])

        # Group global ranks and GPU ids by node. The driver is rank 0 and
        # the remote workers are ranks 1..N in `self.workers` order.
        node_workers = defaultdict(list)
        node_gpus = defaultdict(list)

        node_workers[driver_node_id].append(0)
        node_gpus[driver_node_id].extend(driver_gpu_ids)
        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
                                               start=1):
            node_workers[node_id].append(i)
            node_gpus[node_id].extend(gpu_ids)
        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        # Set CUDA_VISIBLE_DEVICES for the driver.
        set_cuda_visible_devices(node_gpus[driver_node_id])
        # Each remote worker is given every GPU id collected for its node.
        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
            worker.set_cuda_visible_devices.remote(node_gpus[node_id])

        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"

        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        # pylint: disable=import-outside-toplevel
        from aphrodite.task_handler.worker import Worker

        # Initialize torch distributed process group for the workers.
        # The configs are deep-copied before being captured by the worker
        # factories below and handed to the driver worker.
        model_config = copy.deepcopy(self.model_config)
        parallel_config = copy.deepcopy(self.parallel_config)
        scheduler_config = copy.deepcopy(self.scheduler_config)
        device_config = copy.deepcopy(self.device_config)

        for rank, (worker, (node_id,
                            _)) in enumerate(zip(self.workers,
                                                 worker_node_and_gpu_ids),
                                             start=1):
            # The local rank is the position of this global rank within its
            # node's rank list.
            local_rank = node_workers[node_id].index(rank)
            # `rank`/`local_rank` are bound as lambda defaults so that each
            # factory keeps the values of its own iteration.
            worker.init_worker.remote(
                lambda rank=rank, local_rank=local_rank: Worker(
                    model_config,
                    parallel_config,
                    scheduler_config,
                    device_config,
                    local_rank,
                    rank,
                    distributed_init_method,
                    lora_config=self.lora_config,
                    kv_cache_dtype=self.cache_config.cache_dtype,
                    kv_quant_params_path=
                    (self.cache_config.cache_quant_params_path),
                ))

        driver_rank = 0
        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
        self.driver_worker = Worker(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            driver_local_rank,
            driver_rank,
            distributed_init_method,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
            is_driver_worker=True,
        )

        self._run_workers("init_model", cupy_port=get_open_port())
        # Model loading concurrency is capped by
        # `parallel_config.max_parallel_loading_workers`.
        self._run_workers(
            "load_model",
            max_concurrent_workers=self.parallel_config.
            max_parallel_loading_workers,
        )
  251. def _verify_args(self) -> None:
  252. self.model_config.verify_with_parallel_config(self.parallel_config)
  253. self.cache_config.verify_with_parallel_config(self.parallel_config)
  254. if self.lora_config:
  255. self.lora_config.verify_with_model_config(self.model_config)
  256. self.lora_config.verify_with_scheduler_config(
  257. self.scheduler_config)
  258. def _init_cache(self) -> None:
  259. # ruff: noqa: E501
  260. """Profiles the memory usage and initializes the KV cache.
  261. The engine will first conduct a profiling of the existing memory usage.
  262. Then, it calculate the maximum possible number of GPU and CPU blocks
  263. that can be allocated with the remaining free memory.
  264. More details can be found in the
  265. # pylint: disable=line-too-long
  266. :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
  267. from class :class:`~aphrodite.task_handler.Worker`.
  268. Afterwards, as there may be multiple workers,
  269. we take the minimum number of blocks across all workers
  270. to ensure this can be applied to all of them.
  271. Finally, the engine will initialize the KV cache
  272. with the calculated number of blocks.
  273. .. tip::
  274. You may limit the usage of GPU memory
  275. by adjusting the `gpu_memory_utilization` parameters.
  276. """
  277. # Get the maximum number of blocks that can be allocated on GPU and CPU.
  278. num_blocks = self._run_workers(
  279. "profile_num_available_blocks",
  280. block_size=self.cache_config.block_size,
  281. gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
  282. cpu_swap_space=self.cache_config.swap_space_bytes,
  283. cache_dtype=self.cache_config.cache_dtype,
  284. )
  285. # Since we use a shared centralized controller, we take the minimum
  286. # number of blocks across all workers to make sure all the memory
  287. # operators can be applied to all workers.
  288. num_gpu_blocks = min(b[0] for b in num_blocks)
  289. num_cpu_blocks = min(b[1] for b in num_blocks)
  290. # FIXME: Change to debug log.
  291. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  292. f"# CPU blocks: {num_cpu_blocks}")
  293. logger.info(
  294. f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"
  295. )
  296. if num_gpu_blocks <= 0:
  297. raise ValueError("No available memory for the cache blocks. "
  298. "Try increasing `gpu_memory_utilization` when "
  299. "initializing the engine.")
  300. max_seq_len = self.cache_config.block_size * num_gpu_blocks
  301. logger.info(f"Maximum sequence length allowed in the cache: "
  302. f"{max_seq_len}")
  303. if self.model_config.max_model_len > max_seq_len:
  304. raise ValueError(
  305. f"The model's max seq len ({self.model_config.max_model_len}) "
  306. "is larger than the maximum number of tokens that can be "
  307. f"stored in KV cache ({max_seq_len}). Try increasing "
  308. "`gpu_memory_utilization` or decreasing `max_model_len` when "
  309. "initializing the engine.")
  310. self.cache_config.num_gpu_blocks = num_gpu_blocks
  311. self.cache_config.num_cpu_blocks = num_cpu_blocks
  312. # Initialize the cache.
  313. self._run_workers("init_cache_engine", cache_config=self.cache_config)
  314. # Warm up the model. This includes capturing the model into CUDA graph
  315. # if enforce_eager is False.
  316. self._run_workers("warm_up_model")
  317. @classmethod
  318. def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
  319. """Creates an LLM engine from the engine arguments."""
  320. # Create the engine configs.
  321. engine_configs = engine_args.create_engine_configs()
  322. parallel_config = engine_configs[2]
  323. # Initialize the cluster.
  324. placement_group = initialize_cluster(parallel_config)
  325. # Create the LLM engine.
  326. engine = cls(*engine_configs,
  327. placement_group,
  328. log_stats=not engine_args.disable_log_stats)
  329. return engine
  330. def encode_request(
  331. self,
  332. request_id: str,
  333. prompt: Optional[str],
  334. prompt_token_ids: Optional[List[int]] = None,
  335. lora_request: Optional[LoRARequest] = None,
  336. ):
  337. if prompt_token_ids is None:
  338. assert prompt is not None
  339. prompt_token_ids = self.tokenizer.encode(request_id=request_id,
  340. prompt=prompt,
  341. lora_request=lora_request)
  342. return prompt_token_ids
  343. def add_request(
  344. self,
  345. request_id: str,
  346. prompt: Optional[str],
  347. sampling_params: SamplingParams,
  348. prompt_token_ids: Optional[List[int]] = None,
  349. arrival_time: Optional[float] = None,
  350. lora_request: Optional[LoRARequest] = None,
  351. ) -> None:
  352. """Add a request to the engine's request pool.
  353. The request is added to the request pool and will be processed by the
  354. scheduler as `engine.step()` is called. The exact scheduling policy is
  355. determined by the scheduler.
  356. Args:
  357. request_id: The unique ID of the request.
  358. prompt: The prompt string. Can be None if prompt_token_ids is
  359. provided.
  360. sampling_params: The sampling parameters for text generation.
  361. prompt_token_ids: The token IDs of the prompt. If None, we
  362. use the tokenizer to convert the prompts to token IDs.
  363. arrival_time: The arrival time of the request. If None, we use
  364. the current monotonic time.
  365. Details:
  366. - Set arrival_time to the current time if it is None.
  367. - Set prompt_token_ids to the encoded prompt if it is None.
  368. - Create `best_of` number of :class:`~aphrodite.Sequence` objects.
  369. - Create a :class:`~aphrodite.SequenceGroup` object
  370. from the list of :class:`~aphrodite.Sequence`.
  371. - Add the :class:`~aphrodite.SequenceGroup` object to the scheduler.
  372. Example:
  373. >>> # initialize engine
  374. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  375. >>> # set request arguments
  376. >>> example_prompt = "Who is the president of the United States?"
  377. >>> sampling_params = SamplingParams(temperature=0.0)
  378. >>> request_id = 0
  379. >>>
  380. >>> # add the request to the engine
  381. >>> engine.add_request(
  382. >>> str(request_id),
  383. >>> example_prompt,
  384. >>> SamplingParams(temperature=0.0))
  385. >>> # continue the request processing
  386. >>> ...
  387. """
  388. if lora_request is not None and not self.lora_config:
  389. raise ValueError(f"Got lora_request {lora_request} but LoRA is "
  390. "not enabled!")
  391. max_log_probs = self.get_model_config().max_log_probs
  392. if (sampling_params.logprobs
  393. and sampling_params.logprobs > max_log_probs) or (
  394. sampling_params.prompt_logprobs
  395. and sampling_params.prompt_logprobs > max_log_probs):
  396. raise ValueError(f"Cannot request more than "
  397. f"{max_log_probs} logprobs.")
  398. if arrival_time is None:
  399. arrival_time = time.monotonic()
  400. prompt_token_ids = self.encode_request(
  401. request_id=request_id,
  402. prompt=prompt,
  403. prompt_token_ids=prompt_token_ids,
  404. lora_request=lora_request)
  405. # Create the sequences.
  406. block_size = self.cache_config.block_size
  407. seq_id = next(self.seq_counter)
  408. seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
  409. lora_request)
  410. # Create the sequence group.
  411. seq_group = SequenceGroup(request_id, [seq], sampling_params,
  412. arrival_time, lora_request)
  413. # Add the sequence group to the scheduler.
  414. self.scheduler.add_seq_group(seq_group)
  415. def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
  416. """Aborts a request(s) with the given ID.
  417. Args:
  418. request_id: The ID(s) of the request to abort.
  419. Details:
  420. - Refer to the
  421. :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
  422. from class :class:`~aphrodite.processing.scheduler.Scheduler`.
  423. Example:
  424. >>> # initialize engine and add a request with request_id
  425. >>> request_id = str(0)
  426. >>> # abort the request
  427. >>> engine.abort_request(request_id)
  428. """
  429. self.scheduler.abort_seq_group(request_id)
  430. def get_model_config(self) -> ModelConfig:
  431. """Gets the model configuration."""
  432. return self.model_config
  433. def get_num_unfinished_requests(self) -> int:
  434. """Gets the number of unfinished requests."""
  435. return self.scheduler.get_num_unfinished_seq_groups()
  436. def has_unfinished_requests(self) -> bool:
  437. """Returns True if there are unfinished requests."""
  438. return self.scheduler.has_unfinished_seqs()
  439. def _check_beam_search_early_stopping(
  440. self,
  441. early_stopping: Union[bool, str],
  442. sampling_params: SamplingParams,
  443. best_running_seq: Sequence,
  444. current_worst_seq: Sequence,
  445. ) -> bool:
  446. assert sampling_params.use_beam_search
  447. length_penalty = sampling_params.length_penalty
  448. if early_stopping is True:
  449. return True
  450. current_worst_score = (current_worst_seq.get_beam_search_score(
  451. length_penalty=length_penalty,
  452. eos_token_id=self.get_tokenizer_for_seq(
  453. current_worst_seq).eos_token_id))
  454. if early_stopping is False:
  455. highest_attainable_score = (best_running_seq.get_beam_search_score(
  456. length_penalty=length_penalty,
  457. eos_token_id=self.get_tokenizer_for_seq(
  458. best_running_seq).eos_token_id))
  459. else:
  460. assert early_stopping == "never"
  461. if length_penalty > 0.0:
  462. # If length_penalty > 0.0, beam search will prefer longer
  463. # sequences. The highest attainable score calculation is
  464. # based on the longest possible sequence length in this case.
  465. max_possible_length = max(
  466. best_running_seq.get_prompt_len() +
  467. sampling_params.max_tokens,
  468. self.scheduler_config.max_model_len)
  469. highest_attainable_score = (
  470. best_running_seq.get_beam_search_score(
  471. length_penalty=length_penalty,
  472. eos_token_id=self.get_tokenizer_for_seq(
  473. best_running_seq).eos_token_id,
  474. seq_len=max_possible_length))
  475. else:
  476. # Otherwise, beam search will prefer shorter sequences. The
  477. # highest attainable score calculation is based on the current
  478. # sequence length.
  479. highest_attainable_score = (
  480. best_running_seq.get_beam_search_score(
  481. length_penalty=length_penalty,
  482. eos_token_id=self.get_tokenizer_for_seq(
  483. best_running_seq).eos_token_id))
  484. return current_worst_score >= highest_attainable_score
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to a sequence group.

        Prompt logprobs, when present, are stored on the group. Each running
        parent sequence is then matched with the samples that name it as
        parent: a parent with no samples is aborted and freed, extra samples
        fork new child sequences, and the last sample is appended to the
        parent itself. Every resulting sequence is decoded and checked for
        stop conditions.

        Without beam search, forked children are added to the group and
        finished parents are freed in the block manager. With beam search,
        the best `best_of` finished sequences are kept, the early-stopping
        criterion is checked, and the remaining running candidates are either
        continued (top `best_of`) or dropped.

        Args:
            seq_group: The scheduled sequence group to update in place.
            outputs: The sampler output for this group.
        """
        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        # Map each running parent's seq_id to the samples generated from it.
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child.persistent_data = child_sample.persistent_data
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            parent.persistent_data = last_child_sample.persistent_data
            child_seqs.append((parent, parent))

        # Decode the new token and evaluate stop conditions for every
        # sequence produced this step.
        for seq, _ in child_seqs:
            self._decode_sequence(seq, seq_group.sampling_params)
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)
  def _process_model_outputs(
          self, output: SamplerOutput,
          scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
      """Apply one step's sampler output to the scheduled sequence groups.

      Args:
          output: Sampler results for this step. Entries are paired
              positionally (via ``zip``) with
              ``scheduler_outputs.scheduled_seq_groups``.
          scheduler_outputs: The scheduling decision for this step; it
              supplies both the scheduled and the ignored sequence groups.

      Returns:
          One ``RequestOutput`` per scheduled group, followed by one per
          ignored group.
      """
      scheduled = scheduler_outputs.scheduled_seq_groups

      # When ``cache_config.context_shift`` is set, every scheduled group has
      # its blocks marked as computed before the outputs are applied. The
      # original comment describes this flag as "prefix caching": later
      # requests should not try to recompute those blocks.
      if self.cache_config.context_shift:
          for group in scheduled:
              self.scheduler.mark_blocks_as_computed(group)

      # Update each scheduled group with its share of the model output.
      for group, group_output in zip(scheduled, output):
          self._process_sequence_group_outputs(group, group_output)

      # Finished groups are released only after all outputs were applied.
      self.scheduler.free_finished_seq_groups()

      # Scheduled groups are reported first, ignored groups afterwards.
      request_outputs: List[RequestOutput] = [
          RequestOutput.from_seq_group(group) for group in scheduled
      ]
      request_outputs += [
          RequestOutput.from_seq_group(group)
          for group in scheduler_outputs.ignored_seq_groups
      ]

      if self.log_stats:
          self.stat_logger.log(self._get_stats(scheduler_outputs))
      return request_outputs
  def step(self) -> List[RequestOutput]:
      """Run a single decoding iteration and return the new results.

      .. figure:: https://i.imgur.com/sv2HssD.png
          :alt: Overview of the step function
          :align: center

          Overview of the step function.

      Details:
          - Step 1: Ask the scheduler which sequences run in this iteration
            and which token blocks must be swapped in, swapped out or
            copied.

              - Depending on the scheduling policy, sequences may be
                `preempted/reordered`.
              - A Sequence Group (SG) is a group of sequences generated
                from the same prompt.

          - Step 2: Call the workers to execute the model.
          - Step 3: Process the model output. This mainly includes:

              - Decoding the relevant outputs.
              - Updating the scheduled sequence groups with the model
                outputs according to their `sampling parameters`
                (`use_beam_search` or not).
              - Freeing the finished sequence groups.

          - Finally, create and return the newly generated results.

      Returns:
          Whatever ``_process_model_outputs`` returns for this iteration.

      Example::

          # Please see the example/ folder for more detailed examples.

          # initialize engine and request arguments
          engine = AphroditeEngine.from_engine_args(engine_args)
          example_inputs = [(0, "What is LLM?",
                             SamplingParams(temperature=0.0))]

          # Start the engine with an event loop
          while True:
              if example_inputs:
                  req_id, prompt, sampling_params = example_inputs.pop(0)
                  engine.add_request(str(req_id), prompt, sampling_params)

              # continue the request processing
              request_outputs = engine.step()
              for request_output in request_outputs:
                  if request_output.finished:
                      # return or show the request output
                      ...

              if not (engine.has_unfinished_requests() or example_inputs):
                  break
      """
      seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

      if scheduler_outputs.is_empty():
          # Nothing to execute; the empty output is still post-processed so
          # that ignored groups are reported.
          output = []
      else:
          # Only the driver worker receives the scheduling payload.
          driver_kwargs = {
              "seq_group_metadata_list": seq_group_metadata_list,
              "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
              "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
              "blocks_to_copy": scheduler_outputs.blocks_to_copy,
          }
          worker_results = self._run_workers("execute_model",
                                             driver_kwargs=driver_kwargs)
          # Only the driver worker returns the sampling results, and its
          # result is the first entry.
          output = worker_results[0]

      return self._process_model_outputs(output, scheduler_outputs)
  def do_log_stats(self) -> None:
      """Log stats on demand, for use when no requests are active.

      Does nothing unless ``self.log_stats`` is truthy. The stats are
      gathered with ``scheduler_outputs=None``.
      """
      if not self.log_stats:
          return
      self.stat_logger.log(self._get_stats(scheduler_outputs=None))
  def _get_stats(self,
                 scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
      """Build the ``Stats`` snapshot that is handed to the stat logger.

      Args:
          scheduler_outputs: The scheduling decision of the iteration being
              reported, or ``None`` when there is none. With ``None`` the
              token counts stay 0 and the latency lists stay empty.

      Returns:
          A ``Stats`` holding the timestamp, scheduler queue sizes, cache
          usage and per-iteration token and latency figures.
      """
      now = time.monotonic()

      # Cache usage is a fraction (1.0 - free / total), not a percentage.
      gpu_blocks_total = self.cache_config.num_gpu_blocks
      gpu_blocks_free = self.scheduler.block_manager.get_num_free_gpu_blocks()
      gpu_cache_usage = 1.0 - (gpu_blocks_free / gpu_blocks_total)

      cpu_blocks_total = self.cache_config.num_cpu_blocks
      if cpu_blocks_total > 0:
          cpu_blocks_free = (
              self.scheduler.block_manager.get_num_free_cpu_blocks())
          cpu_cache_usage = 1.0 - (cpu_blocks_free / cpu_blocks_total)
      else:
          # No CPU blocks configured: report zero rather than divide by zero.
          cpu_cache_usage = 0.

      # Scheduler queue sizes.
      num_running = len(self.scheduler.running)
      num_swapped = len(self.scheduler.swapped)
      num_waiting = len(self.scheduler.waiting)

      # Per-iteration figures; these defaults apply when there is no
      # scheduler output.
      num_prompt_tokens = 0
      num_generation_tokens = 0
      time_to_first_tokens = []
      time_per_output_tokens = []
      time_e2e_requests = []

      if scheduler_outputs is not None:
          is_prompt_run = scheduler_outputs.prompt_run

          if is_prompt_run:
              # A prompt run counts every prompt token, plus one generated
              # token per sequence in each scheduled group.
              num_prompt_tokens = sum(
                  len(group.prompt_token_ids)
                  for group in scheduler_outputs.scheduled_seq_groups)
              num_generation_tokens = sum(
                  group.num_seqs()
                  for group in scheduler_outputs.scheduled_seq_groups)
          else:
              num_generation_tokens = scheduler_outputs.num_batched_tokens

          latencies = []
          for group in scheduler_outputs.scheduled_seq_groups:
              # Time since the group's last token. Per the original comment,
              # this call also updates ``metrics.last_token_time``.
              latencies.append(group.get_last_latency(now))
              # End-to-end time is recorded only for finished groups.
              if group.is_finished():
                  time_e2e_requests.append(now - group.metrics.arrival_time)

          # The same latencies are time-to-first-token on a prompt run and
          # time-per-output-token otherwise.
          if is_prompt_run:
              time_to_first_tokens = latencies
          else:
              time_per_output_tokens = latencies

      return Stats(
          now=now,
          num_running=num_running,
          num_swapped=num_swapped,
          num_waiting=num_waiting,
          gpu_cache_usage=gpu_cache_usage,
          cpu_cache_usage=cpu_cache_usage,
          num_prompt_tokens=num_prompt_tokens,
          num_generation_tokens=num_generation_tokens,
          time_to_first_tokens=time_to_first_tokens,
          time_per_output_tokens=time_per_output_tokens,
          time_e2e_requests=time_e2e_requests,
      )
  793. def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
  794. logprobs: Dict[int, Logprob],
  795. all_input_ids: List[int]) -> None:
  796. if not logprobs:
  797. return
  798. for token_id, sample_logprob in logprobs.items():
  799. if (sample_logprob.decoded_token is None and token_id != -1):
  800. all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
  801. # pylint: disable=unused-variable
  802. _, new_text, prefix_offset, read_offset = detokenize_incrementally(
  803. self.get_tokenizer_for_seq(seq),
  804. all_input_ids=all_input_ids_with_logprob,
  805. prev_tokens=seq.tokens,
  806. prefix_offset=seq.prefix_offset,
  807. read_offset=seq.read_offset,
  808. skip_special_tokens=prms.skip_special_tokens,
  809. spaces_between_special_tokens=prms.
  810. spaces_between_special_tokens,
  811. )
  812. sample_logprob.decoded_token = new_text
  def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
      """Decode the newest token of ``seq`` and append its text.

      The latest logprob entries are decoded first. The sequence is then
      detokenized incrementally, and ``seq.tokens``, ``seq.prefix_offset``,
      ``seq.read_offset`` and ``seq.output_text`` are updated in place.

      Args:
          seq: The sequence to update.
          prms: Sampling parameters; the special-token options are read.
      """
      token_ids = seq.get_token_ids()

      # Logprobs are decoded before the sequence's offsets are advanced.
      self._decode_logprobs(seq, prms, seq.output_logprobs[-1], token_ids)

      decoded = detokenize_incrementally(
          self.get_tokenizer_for_seq(seq),
          all_input_ids=token_ids,
          prev_tokens=seq.tokens,
          prefix_offset=seq.prefix_offset,
          read_offset=seq.read_offset,
          skip_special_tokens=prms.skip_special_tokens,
          spaces_between_special_tokens=prms.spaces_between_special_tokens,
      )
      fresh_tokens, fresh_text, next_prefix_offset, next_read_offset = decoded

      # The first decode installs the token list; later ones extend it.
      if seq.tokens is None:
          seq.tokens = fresh_tokens
      else:
          seq.tokens.extend(fresh_tokens)

      seq.prefix_offset = next_prefix_offset
      seq.read_offset = next_read_offset
      seq.output_text += fresh_text
  835. def _check_stop(self, seq: Sequence,
  836. sampling_params: SamplingParams) -> None:
  837. """Stop the finished sequences."""
  838. for stop_str in sampling_params.stop:
  839. if seq.output_text.endswith(stop_str):
  840. self._finalize_sequence(seq, sampling_params, stop_str)
  841. seq.status = SequenceStatus.FINISHED_STOPPED
  842. return
  843. if seq.get_last_token_id() in sampling_params.stop_token_ids:
  844. stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
  845. seq.get_last_token_id())
  846. self._finalize_sequence(seq, sampling_params, stop_str)
  847. seq.status = SequenceStatus.FINISHED_STOPPED
  848. return
  849. # Check if the sequence has reached max_model_len.
  850. if seq.get_len() > self.scheduler_config.max_model_len:
  851. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  852. return
  853. # Check if the sequence has reached max_tokens.
  854. if seq.get_output_len() == sampling_params.max_tokens:
  855. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  856. return
  857. # Check if the sequence has generated the EOS token.
  858. if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
  859. == self.get_tokenizer_for_seq(seq).eos_token_id):
  860. seq.status = SequenceStatus.FINISHED_STOPPED
  861. return
  862. def _finalize_sequence(self, seq: Sequence,
  863. sampling_params: SamplingParams,
  864. stop_string: str) -> None:
  865. if not sampling_params.include_stop_str_in_output and stop_string:
  866. # Truncate the output text so that the stop string is
  867. # not included in the output.
  868. seq.output_text = seq.output_text[:-len(stop_string)]
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Forward an ``add_lora`` call to the workers.

      Args:
          lora_request: The adapter to add; its ``lora_int_id`` must be
              greater than 0 (enforced with an ``assert``).

      Returns:
          The result of ``_run_workers``.
          NOTE(review): the annotation says ``bool``, but ``_run_workers``
          in this file returns a list of per-worker results - confirm what
          callers expect.
      """
      lora_id = lora_request.lora_int_id
      assert lora_id > 0, "lora_id must be greater than 0."
      worker_results = self._run_workers("add_lora",
                                         lora_request=lora_request)
      return worker_results
  def remove_lora(self, lora_id: int) -> bool:
      """Forward a ``remove_lora`` call to the workers.

      Args:
          lora_id: Id of the adapter to remove; must be greater than 0
              (enforced with an ``assert``).

      Returns:
          The result of ``_run_workers``.
          NOTE(review): the annotation says ``bool``, but ``_run_workers``
          in this file returns a list of per-worker results - confirm what
          callers expect.
      """
      assert lora_id > 0, "lora_id must be greater than 0."
      worker_results = self._run_workers("remove_lora", lora_id=lora_id)
      return worker_results
  def list_loras(self) -> List[int]:
      """Forward a ``list_loras`` call to the workers.

      Returns:
          The result of ``_run_workers``.
          NOTE(review): the annotation says ``List[int]``, but
          ``_run_workers`` in this file returns one entry per worker -
          confirm what callers expect.
      """
      worker_results = self._run_workers("list_loras")
      return worker_results
  883. def _run_workers(
  884. self,
  885. method: str,
  886. *args,
  887. driver_args: Optional[List[Any]] = None,
  888. driver_kwargs: Optional[Dict[str, Any]] = None,
  889. max_concurrent_workers: Optional[int] = None,
  890. **kwargs,
  891. ) -> Any:
  892. """Runs the given method on all workers."""
  893. if max_concurrent_workers:
  894. raise NotImplementedError(
  895. "max_concurrent_workers is not supported yet.")
  896. # Start the ray workers first.
  897. ray_worker_outputs = [
  898. worker.execute_method.remote(method, *args, **kwargs)
  899. for worker in self.workers
  900. ]
  901. if driver_args is None:
  902. driver_args = args
  903. if driver_kwargs is None:
  904. driver_kwargs = kwargs
  905. # Start the driver worker after all the ray workers.
  906. driver_worker_output = getattr(self.driver_worker,
  907. method)(*driver_args, **driver_kwargs)
  908. # Get the results of the ray workers.
  909. if self.workers:
  910. ray_worker_outputs = ray.get(ray_worker_outputs)
  911. return [driver_worker_output] + ray_worker_outputs
  def check_health(self) -> None:
      """Check that the engine is healthy.

      Delegates to ``_check_if_any_actor_is_dead``, which raises if it
      finds a dead worker. Returns ``None`` otherwise.
      """
      self._check_if_any_actor_is_dead()
  915. def _check_if_any_actor_is_dead(self):
  916. if not self.parallel_config.worker_use_ray:
  917. return
  918. if not self.workers:
  919. return
  920. dead_actors = []
  921. for actor in self.workers:
  922. actor_state = ray.state.actors(actor._ray_actor_id.hex())
  923. if actor_state["State"] == "DEAD":
  924. dead_actors.append(actor)
  925. if dead_actors:
  926. raise RuntimeError("At least one Worker is dead. "
  927. f"Dead workers: {dead_actors}")
  # Call setup_logger(), which is defined outside this view.
  # NOTE(review): the original indentation is lost, so it is unclear whether
  # this call runs at module import time or belongs to the preceding method -
  # presumably module level; confirm against the real file.
  setup_logger()