# aphrodite_engine.py (34 KB)
  1. import copy
  2. import os
  3. import time
  4. from functools import partial
  5. from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
  6. import psutil
  7. from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig)
  9. from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
  10. from aphrodite.engine.args_tools import EngineArgs
  11. from aphrodite.engine.metrics import record_metrics
  12. from aphrodite.engine.ray_tools import RayWorker, initialize_cluster, ray
  13. from aphrodite.common.logger import init_logger
  14. from aphrodite.common.outputs import RequestOutput
  15. from aphrodite.common.sampling_params import SamplingParams
  16. from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
  17. SequenceGroupOutput, SequenceOutput,
  18. SequenceStatus)
  19. from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
  20. get_tokenizer)
  21. from aphrodite.common.utils import Counter
  22. if ray:
  23. from ray.air.util.torch_dist import init_torch_dist_process_group
  24. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  25. if TYPE_CHECKING:
  26. from ray.util.placement_group import PlacementGroup
  27. logger = init_logger(__name__)
  28. _LOGGING_INTERVAL_SEC = 5
  29. class AphroditeEngine:
  30. """An LLM engine that receives requests and generates texts.
  31. This is the main class for the Aphrodite engine. It receives requests
  32. from clients and generates texts from the LLM. It includes a tokenizer, a
  33. language model (possibly distributed across multiple GPUs), and GPU memory
  34. space allocated for intermediate states (aka KV cache). This class utilizes
  35. iteration-level scheduling and efficient memory management to maximize the
  36. serving throughput.
  37. The `LLM` class wraps this class for offline batched inference and the
  38. `AsyncAphrodite` class wraps this class for online serving.
  39. NOTE: The config arguments are derived from the `EngineArgs` class. For the
  40. comprehensive list of arguments, see `EngineArgs`.
  41. Args:
  42. model_config: The configuration related to the LLM model.
  43. cache_config: The configuration related to the KV cache memory
  44. management.
  45. parallel_config: The configuration related to distributed execution.
  46. scheduler_config: The configuration related to the request scheduler.
  47. distributed_init_method: The initialization method for distributed
  48. execution. See `torch.distributed.init_process_group` for details.
  49. stage_devices: The list of devices for each stage. Each stage is a list
  50. of (rank, node_resource, device) tuples.
  51. log_stats: Whether to log statistics.
  52. """
  53. def __init__(
  54. self,
  55. model_config: ModelConfig,
  56. cache_config: CacheConfig,
  57. parallel_config: ParallelConfig,
  58. scheduler_config: SchedulerConfig,
  59. distributed_init_method: str,
  60. placement_group: Optional["PlacementGroup"],
  61. log_stats: bool,
  62. ) -> None:
  63. logger.info(
  64. "Initializing the Aphrodite Engine with the following config:\n"
  65. f"Model = {model_config.model!r}\n"
  66. f"Tokenizer = {model_config.tokenizer!r}\n"
  67. f"tokenizer_mode = {model_config.tokenizer_mode}\n"
  68. f"revision = {model_config.revision}\n"
  69. f"trust_remote_code = {model_config.trust_remote_code}\n"
  70. f"DataType = {model_config.dtype}\n"
  71. f"Download Directory = {model_config.download_dir!r}\n"
  72. f"Model Load Format = {model_config.load_format}\n"
  73. f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
  74. f"Quantization Format = {model_config.quantization}\n"
  75. f"Sampler Seed = {model_config.seed}\n"
  76. f"Context Length = {model_config.max_model_len}\n"
  77. f"Enforce Eager Mode = {model_config.enforce_eager}\n"
  78. f"KV Cache DataType = {cache_config.cache_dtype}\n"
  79. f"Seed = {model_config.seed}")
  80. # TODO: Print more configs in debug mode.
  81. self.model_config = model_config
  82. self.cache_config = cache_config
  83. self.parallel_config = parallel_config
  84. self.scheduler_config = scheduler_config
  85. self.log_stats = log_stats
  86. self._verify_args()
  87. self.tokenizer = get_tokenizer(
  88. model_config.tokenizer,
  89. tokenizer_mode=model_config.tokenizer_mode,
  90. trust_remote_code=model_config.trust_remote_code,
  91. revision=model_config.revision)
  92. self.seq_counter = Counter()
  93. # Create the parallel GPU workers.
  94. if self.parallel_config.worker_use_ray:
  95. # Disable Ray usage stats collection.
  96. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
  97. if ray_usage != "1":
  98. os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
  99. self._init_workers_ray(placement_group)
  100. else:
  101. self._init_workers(distributed_init_method)
  102. # Profile the memory usage and initialize the cache.
  103. self._init_cache()
  104. # Create the scheduler.
  105. self.scheduler = Scheduler(scheduler_config, cache_config)
  106. # Logging.
  107. self.last_logging_time = 0.0
  108. # List of (timestamp, num_tokens)
  109. self.num_prompt_tokens: List[Tuple[float, int]] = []
  110. # List of (timestamp, num_tokens)
  111. self.num_generation_tokens: List[Tuple[float, int]] = []
  112. def _init_workers(self, distributed_init_method: str):
  113. # Lazy import the Worker to avoid importing torch.cuda/xformers
  114. # before CUDA_VISIBLE_DEVICES is set in the Worker
  115. from aphrodite.task_handler.worker import Worker # pylint: disable=import-outside-toplevel
  116. assert self.parallel_config.world_size == 1, (
  117. "Ray is required if parallel_config.world_size > 1.")
  118. self.workers: List[Worker] = []
  119. worker = Worker(
  120. self.model_config,
  121. self.parallel_config,
  122. self.scheduler_config,
  123. 0,
  124. distributed_init_method,
  125. )
  126. self.workers.append(worker)
  127. self._run_workers(
  128. "init_model",
  129. get_all_outputs=True,
  130. )
  131. self._run_workers(
  132. "load_model",
  133. get_all_outputs=True,
  134. max_concurrent_workers=self.parallel_config.
  135. max_parallel_loading_workers,
  136. )
  137. def _init_workers_ray(self, placement_group: "PlacementGroup",
  138. **ray_remote_kwargs):
  139. # Lazy import the Worker to avoid importing torch.cuda/xformers
  140. # before CUDA_VISIBLE_DEVICES is set in the Worker
  141. from aphrodite.task_handler.worker import Worker # pylint: disable=import-outside-toplevel
  142. self.workers: List[Worker] = []
  143. for bundle in placement_group.bundle_specs:
  144. if not bundle.get("GPU", 0):
  145. continue
  146. if self.parallel_config.tensor_parallel_size == 1:
  147. num_gpus = self.cache_config.gpu_memory_utilization
  148. else:
  149. num_gpus = 1
  150. worker = ray.remote(
  151. num_cpus=0,
  152. num_gpus=num_gpus,
  153. scheduling_strategy=PlacementGroupSchedulingStrategy(
  154. placement_group=placement_group,
  155. placement_group_capture_child_tasks=True),
  156. **ray_remote_kwargs,
  157. )(RayWorker).remote(self.model_config.trust_remote_code)
  158. self.workers.append(worker)
  159. # Initialize torch distributed process group for the workers.
  160. init_torch_dist_process_group(self.workers, backend="nccl")
  161. model_config = copy.deepcopy(self.model_config)
  162. parallel_config = copy.deepcopy(self.parallel_config)
  163. scheduler_config = copy.deepcopy(self.scheduler_config)
  164. self._run_workers("init_worker",
  165. get_all_outputs=True,
  166. worker_init_fn=lambda: Worker(
  167. model_config,
  168. parallel_config,
  169. scheduler_config,
  170. None,
  171. None,
  172. ))
  173. self._run_workers(
  174. "init_model",
  175. get_all_outputs=True,
  176. )
  177. self._run_workers(
  178. "load_model",
  179. get_all_outputs=True,
  180. max_concurrent_workers=self.parallel_config.
  181. max_parallel_loading_workers,
  182. )
  183. # HACK
  184. # After running ray.init(), ray processes affinity is set to (0,1).
  185. # (or whatever the CPU scheduler fancies)
  186. # We however want the actual workers that are being used,
  187. # so we call here since calling after ray.init() and everything else.
  188. # We reassign each ray process by taking the
  189. # modulus of the number of cpu_cores available.
  190. # Issue: https://github.com/PygmalionAI/aphrodite-engine/issues/115
  191. # The solution is similar to the taskset solution linked above.
  192. current_process = psutil.Process()
  193. ray_threads = 0
  194. logical_cores = psutil.cpu_count(logical=True)
  195. physical_cores = psutil.cpu_count(logical=False)
  196. ht_scale = physical_cores / logical_cores
  197. for process in current_process.children(recursive=True):
  198. # process.pid
  199. if "ray::" in process.name():
  200. process.cpu_affinity([ray_threads])
  201. ray_threads += int(1 * ht_scale) if ht_scale > 1.0 else 1
  202. ray_threads = ray_threads % logical_cores
  203. def _verify_args(self) -> None:
  204. self.model_config.verify_with_parallel_config(self.parallel_config)
  205. self.cache_config.verify_with_parallel_config(self.parallel_config)
  206. def _init_cache(self) -> None:
  207. """Profiles the memory usage and initializes the KV cache."""
  208. # Get the maximum number of blocks that can be allocated on GPU and CPU.
  209. num_blocks = self._run_workers(
  210. "profile_num_available_blocks",
  211. get_all_outputs=True,
  212. block_size=self.cache_config.block_size,
  213. gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
  214. cpu_swap_space=self.cache_config.swap_space_bytes,
  215. cache_dtype=self.cache_config.cache_dtype,
  216. )
  217. # Since we use a shared centralized controller, we take the minimum
  218. # number of blocks across all workers to make sure all the memory
  219. # operators can be applied to all workers.
  220. num_gpu_blocks = min(b[0] for b in num_blocks)
  221. num_cpu_blocks = min(b[1] for b in num_blocks)
  222. # FIXME: Change to debug log.
  223. logger.info(f"# GPU blocks: {num_gpu_blocks}, "
  224. f"# CPU blocks: {num_cpu_blocks}")
  225. if num_gpu_blocks <= 0:
  226. raise ValueError("No available memory for the cache blocks. "
  227. "Try increasing `gpu_memory_utilization` when "
  228. "initializing the engine.")
  229. max_seq_len = self.cache_config.block_size * num_gpu_blocks
  230. if self.model_config.max_model_len > max_seq_len:
  231. raise ValueError(
  232. f"The model's max seq len ({self.model_config.max_model_len}) "
  233. "is larger than the maximum number of tokens that can be "
  234. f"stored in KV cache ({max_seq_len}). Try increasing "
  235. "`gpu_memory_utilization` or decreasing `max_model_len` when "
  236. "initializing the engine.")
  237. self.cache_config.num_gpu_blocks = num_gpu_blocks
  238. self.cache_config.num_cpu_blocks = num_cpu_blocks
  239. # Initialize the cache.
  240. self._run_workers("init_cache_engine", cache_config=self.cache_config)
  241. # Warm up the model. This includes capturing the model into CUDA graph
  242. # if enforce_eager is set to False.
  243. self._run_workers("warm_up_model")
  244. @classmethod
  245. def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
  246. """Creates an LLM engine from the engine arguments."""
  247. # Create the engine configs.
  248. engine_configs = engine_args.create_engine_configs()
  249. parallel_config = engine_configs[2]
  250. # Initialize the cluster.
  251. distributed_init_method, placement_group = initialize_cluster(
  252. parallel_config)
  253. # Create the LLM engine.
  254. engine = cls(*engine_configs,
  255. distributed_init_method,
  256. placement_group,
  257. log_stats=not engine_args.disable_log_stats)
  258. return engine
  259. def add_request(
  260. self,
  261. request_id: str,
  262. prompt: Optional[str],
  263. sampling_params: SamplingParams,
  264. prompt_token_ids: Optional[List[int]] = None,
  265. arrival_time: Optional[float] = None,
  266. ) -> None:
  267. """Add a request to the engine's request pool.
  268. The request is added to the request pool and will be processed by the
  269. scheduler as `engine.step()` is called. The exact scheduling policy is
  270. determined by the scheduler.
  271. Args:
  272. request_id: The unique ID of the request.
  273. prompt: The prompt string. Can be None if prompt_token_ids is
  274. provided.
  275. sampling_params: The sampling parameters for text generation.
  276. prompt_token_ids: The token IDs of the prompt. If None, we
  277. use the tokenizer to convert the prompts to token IDs.
  278. arrival_time: The arrival time of the request. If None, we use
  279. the current time.
  280. """
  281. if arrival_time is None:
  282. arrival_time = time.time()
  283. if prompt_token_ids is None:
  284. assert prompt is not None
  285. prompt_token_ids = self.tokenizer.encode(prompt)
  286. # Create the sequences.
  287. block_size = self.cache_config.block_size
  288. seq_id = next(self.seq_counter)
  289. seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
  290. # Create the sequence group.
  291. seq_group = SequenceGroup(request_id, [seq], sampling_params,
  292. arrival_time)
  293. # Add the sequence group to the scheduler.
  294. self.scheduler.add_seq_group(seq_group)
  295. def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
  296. """Aborts a request(s) with the given ID.
  297. Args:
  298. request_id: The ID(s) of the request to abort.
  299. """
  300. self.scheduler.abort_seq_group(request_id)
  301. def get_model_config(self) -> ModelConfig:
  302. """Gets the model configuration."""
  303. return self.model_config
  304. def get_num_unfinished_requests(self) -> int:
  305. """Gets the number of unfinished requests."""
  306. return self.scheduler.get_num_unfinished_seq_groups()
  307. def has_unfinished_requests(self) -> bool:
  308. """Returns True if there are unfinished requests."""
  309. return self.scheduler.has_unfinished_seqs()
  310. def _check_beam_search_early_stopping(
  311. self,
  312. early_stopping: Union[bool, str],
  313. sampling_params: SamplingParams,
  314. best_running_seq: Sequence,
  315. current_worst_seq: Sequence,
  316. ) -> bool:
  317. assert sampling_params.use_beam_search
  318. length_penalty = sampling_params.length_penalty
  319. if early_stopping is True:
  320. return True
  321. current_worst_score = (current_worst_seq.get_beam_search_score(
  322. length_penalty=length_penalty,
  323. eos_token_id=self.tokenizer.eos_token_id))
  324. if early_stopping is False:
  325. highest_attainable_score = (best_running_seq.get_beam_search_score(
  326. length_penalty=length_penalty,
  327. eos_token_id=self.tokenizer.eos_token_id))
  328. else:
  329. assert early_stopping == "never"
  330. if length_penalty > 0.0:
  331. # If length_penalty > 0.0, beam search will prefer longer
  332. # sequences. The highest attainable score calculation is
  333. # based on the longest possible sequence length in this case.
  334. max_possible_length = max(
  335. best_running_seq.get_prompt_len() +
  336. sampling_params.max_tokens,
  337. self.scheduler_config.max_model_len)
  338. highest_attainable_score = (
  339. best_running_seq.get_beam_search_score(
  340. length_penalty=length_penalty,
  341. eos_token_id=self.tokenizer.eos_token_id,
  342. seq_len=max_possible_length))
  343. else:
  344. # Otherwise, beam search will prefer shorter sequences. The
  345. # highest attainable score calculation is based on the current
  346. # sequence length.
  347. highest_attainable_score = (
  348. best_running_seq.get_beam_search_score(
  349. length_penalty=length_penalty,
  350. eos_token_id=self.tokenizer.eos_token_id))
  351. return current_worst_score >= highest_attainable_score
    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:
        """Apply one step of sampler output to a sequence group.

        Steps:
        1. Store any prompt logprobs on the group.
        2. Append each sample's token to the running parent it belongs to.
           A parent with several samples is forked once per extra sample.
           A parent with no samples is aborted, removed from the group and
           freed.
        3. Detokenize and stop-check every resulting sequence.
        4. Update group membership and block-manager state:
           - Without beam search, every child is kept.
           - With beam search, only the best ``best_of`` finished and running
             candidates are kept; the others are removed and freed.

        Args:
            seq_group: The sequence group that was scheduled this step.
            outputs: The sampler output for this group: optional
                ``prompt_logprobs`` and ``samples``, each sample naming its
                parent via ``parent_seq_id``.
        """
        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None:
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Snapshot taken before this step changes any status (statuses are
        # updated by the abort branch and _check_stop below), so it holds
        # only sequences that had already finished in earlier steps.
        existing_finished_seqs = seq_group.get_finished_seqs()
        # Group the samples by the running sequence that produced them.
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            # The forks are not yet in the group or the block manager; that
            # happens (or not) in the selection logic below.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child.persistent_data = child_sample.persistent_data
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            parent.persistent_data = last_child_sample.persistent_data
            child_seqs.append((parent, parent))

        # Detokenize the new token and apply stop conditions; this may mark
        # sequences as finished, which drives all the decisions below.
        for seq, _ in child_seqs:
            self._decode_sequence(seq, seq_group.sampling_params)
            self._check_stop(seq, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        # Both lists hold (seq, parent) pairs.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new); parent is None for sequences that
        # finished in earlier steps.
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.tokenizer.eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=self.tokenizer.eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        # (Forking above happens first so children copy the parent's blocks
        # before they are freed.)
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        # Unselected *new* children (seq is not parent) were never added to
        # the group or forked in the block manager, so nothing is undone for
        # them; they are simply dropped.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)
  511. def _process_model_outputs(
  512. self, output: SamplerOutput,
  513. scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
  514. # Update the scheduled sequence groups with the model outputs.
  515. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
  516. for seq_group, outputs in zip(scheduled_seq_groups, output):
  517. self._process_sequence_group_outputs(seq_group, outputs)
  518. # Free the finished sequence groups.
  519. self.scheduler.free_finished_seq_groups()
  520. # Create the outputs.
  521. request_outputs: List[RequestOutput] = []
  522. for seq_group in (scheduled_seq_groups +
  523. scheduler_outputs.ignored_seq_groups):
  524. request_output = RequestOutput.from_seq_group(seq_group)
  525. request_outputs.append(request_output)
  526. if self.log_stats:
  527. # Log the system stats.
  528. self._log_system_stats(scheduler_outputs.prompt_run,
  529. scheduler_outputs.num_batched_tokens)
  530. return request_outputs
  531. def step(self) -> List[RequestOutput]:
  532. """Performs one decoding iteration and returns newly generated results.
  533. This function performs one decoding iteration of the engine. It first
  534. schedules the sequences to be executed in the next iteration and the
  535. token blocks to be swapped in/out/copy. Then, it executes the model
  536. and updates the scheduler with the model outputs. Finally, it decodes
  537. the sequences and returns the newly generated results.
  538. """
  539. seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
  540. # Execute the model.
  541. output = self._run_workers(
  542. "execute_model",
  543. seq_group_metadata_list=seq_group_metadata_list,
  544. blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
  545. blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
  546. blocks_to_copy=scheduler_outputs.blocks_to_copy,
  547. ) if not scheduler_outputs.is_empty() else []
  548. return self._process_model_outputs(output, scheduler_outputs)
  549. def _log_system_stats(
  550. self,
  551. prompt_run: bool,
  552. num_batched_tokens: int,
  553. ) -> None:
  554. now = time.time()
  555. # Log the number of batched input tokens.
  556. if prompt_run:
  557. self.num_prompt_tokens.append((now, num_batched_tokens))
  558. else:
  559. self.num_generation_tokens.append((now, num_batched_tokens))
  560. should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
  561. if not should_log:
  562. return
  563. # Discard the old stats.
  564. self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
  565. if now - t < _LOGGING_INTERVAL_SEC]
  566. self.num_generation_tokens = [(t, n)
  567. for t, n in self.num_generation_tokens
  568. if now - t < _LOGGING_INTERVAL_SEC]
  569. if len(self.num_prompt_tokens) > 1:
  570. total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
  571. window = now - self.num_prompt_tokens[0][0]
  572. avg_prompt_throughput = total_num_tokens / window
  573. else:
  574. avg_prompt_throughput = 0.0
  575. if len(self.num_generation_tokens) > 1:
  576. total_num_tokens = sum(n
  577. for _, n in self.num_generation_tokens[:-1])
  578. window = now - self.num_generation_tokens[0][0]
  579. avg_generation_throughput = total_num_tokens / window
  580. else:
  581. avg_generation_throughput = 0.0
  582. total_num_gpu_blocks = self.cache_config.num_gpu_blocks
  583. num_free_gpu_blocks = (
  584. self.scheduler.block_manager.get_num_free_gpu_blocks())
  585. num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
  586. gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
  587. total_num_cpu_blocks = self.cache_config.num_cpu_blocks
  588. if total_num_cpu_blocks > 0:
  589. num_free_cpu_blocks = (
  590. self.scheduler.block_manager.get_num_free_cpu_blocks())
  591. num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
  592. cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
  593. else:
  594. cpu_cache_usage = 0.0
  595. record_metrics(
  596. avg_prompt_throughput=avg_prompt_throughput,
  597. avg_generation_throughput=avg_generation_throughput,
  598. scheduler_running=len(self.scheduler.running),
  599. scheduler_swapped=len(self.scheduler.swapped),
  600. scheduler_waiting=len(self.scheduler.waiting),
  601. gpu_cache_usage=gpu_cache_usage,
  602. cpu_cache_usage=cpu_cache_usage,
  603. )
  604. logger.info("Avg prompt throughput: "
  605. f"{avg_prompt_throughput:.1f} tokens/s, "
  606. "Avg generation throughput: "
  607. f"{avg_generation_throughput:.1f} tokens/s, "
  608. f"Running: {len(self.scheduler.running)} reqs, "
  609. f"Swapped: {len(self.scheduler.swapped)} reqs, "
  610. f"Pending: {len(self.scheduler.waiting)} reqs, "
  611. f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
  612. f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
  613. self.last_logging_time = now
  614. def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
  615. """Decodes the new token for a sequence."""
  616. (new_tokens, new_output_text, prefix_offset,
  617. read_offset) = detokenize_incrementally(
  618. self.tokenizer,
  619. all_input_ids=seq.get_token_ids(),
  620. prev_tokens=seq.tokens,
  621. prefix_offset=seq.prefix_offset,
  622. read_offset=seq.read_offset,
  623. skip_special_tokens=prms.skip_special_tokens,
  624. spaces_between_special_tokens=prms.spaces_between_special_tokens)
  625. if seq.tokens is None:
  626. seq.tokens = new_tokens
  627. else:
  628. seq.tokens.extend(new_tokens)
  629. seq.prefix_offset = prefix_offset
  630. seq.read_offset = read_offset
  631. seq.output_text += new_output_text
  632. def _check_stop(self, seq: Sequence,
  633. sampling_params: SamplingParams) -> None:
  634. """Stop the finished sequences."""
  635. for stop_str in sampling_params.stop:
  636. if seq.output_text.endswith(stop_str):
  637. if not sampling_params.include_stop_str_in_output:
  638. # Truncate the output text so that the stop string is
  639. # not included in the output
  640. seq.output_text = seq.output_text[:-len(stop_str)]
  641. seq.status = SequenceStatus.FINISHED_STOPPED
  642. return
  643. if seq.get_last_token_id() in sampling_params.stop_token_ids:
  644. seq.status = SequenceStatus.FINISHED_STOPPED
  645. return
  646. # Check if the sequence has reached max_model_len.
  647. if seq.get_len() > self.scheduler_config.max_model_len:
  648. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  649. return
  650. # Check if the sequence has reached max_tokens.
  651. if seq.get_output_len() == sampling_params.max_tokens:
  652. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  653. return
  654. # Check if the sequence has generated the EOS token.
  655. if ((not sampling_params.ignore_eos)
  656. and seq.get_last_token_id() == self.tokenizer.eos_token_id):
  657. seq.status = SequenceStatus.FINISHED_STOPPED
  658. return
  659. def _run_workers_in_batch(
  660. self,
  661. workers,
  662. method: str,
  663. *args,
  664. **kwargs,
  665. ):
  666. all_outputs = []
  667. for worker in workers:
  668. if self.parallel_config.worker_use_ray:
  669. executor = partial(worker.execute_method.remote, method)
  670. else:
  671. executor = getattr(worker, method)
  672. output = executor(*args, **kwargs)
  673. all_outputs.append(output)
  674. if self.parallel_config.worker_use_ray:
  675. all_outputs = ray.get(all_outputs)
  676. return all_outputs
  677. def _run_workers(
  678. self,
  679. method: str,
  680. *args,
  681. get_all_outputs: bool = False,
  682. max_concurrent_workers: Optional[int] = None,
  683. **kwargs,
  684. ) -> Any:
  685. """Runs a method on all workers."""
  686. all_outputs = []
  687. if max_concurrent_workers:
  688. work_groups = [
  689. self.workers[i:i + max_concurrent_workers]
  690. for i in range(0, len(self.workers), max_concurrent_workers)
  691. ]
  692. else:
  693. work_groups = [self.workers]
  694. for workers in work_groups:
  695. all_outputs.extend(
  696. self._run_workers_in_batch(workers, method, *args, **kwargs))
  697. if get_all_outputs:
  698. return all_outputs
  699. # Make sure all workers have the same results.
  700. output = all_outputs[0]
  701. for other_output in all_outputs[1:]:
  702. assert output == other_output
  703. return output