aphrodite_engine.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105
  1. import os
  2. import time
  3. from contextlib import contextmanager
  4. from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
  5. from typing import Sequence as GenericSequence
  6. from typing import Type, TypeVar, Union
  7. from loguru import logger
  8. from transformers import PreTrainedTokenizer
  9. from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
  10. EngineConfig, LoadConfig, LoRAConfig,
  11. ModelConfig, MultiModalConfig,
  12. ParallelConfig, PromptAdapterConfig,
  13. SchedulerConfig, SpeculativeConfig)
  14. from aphrodite.common.logger import setup_logger
  15. from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
  16. RequestOutputFactory)
  17. from aphrodite.common.pooling_params import PoolingParams
  18. from aphrodite.common.sampling_params import SamplingParams
  19. from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
  20. ExecuteModelRequest, PoolerOutput,
  21. SamplerOutput, Sequence, SequenceGroup,
  22. SequenceGroupMetadata, SequenceStatus)
  23. from aphrodite.common.utils import Counter
  24. from aphrodite.engine.args_tools import EngineArgs
  25. from aphrodite.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
  26. StatLoggerBase, Stats)
  27. from aphrodite.engine.output_processor.interfaces import (
  28. SequenceGroupOutputProcessor)
  29. from aphrodite.engine.output_processor.stop_checker import StopChecker
  30. from aphrodite.engine.output_processor.util import (
  31. create_output_by_sequence_group)
  32. from aphrodite.executor.executor_base import ExecutorBase
  33. from aphrodite.executor.ray_utils import initialize_ray_cluster
  34. from aphrodite.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
  35. from aphrodite.lora.request import LoRARequest
  36. from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
  37. SchedulerOutputs)
  38. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  39. from aphrodite.transformers_utils.config import try_get_generation_config
  40. from aphrodite.transformers_utils.detokenizer import Detokenizer
  41. from aphrodite.transformers_utils.tokenizer_group import (
  42. BaseTokenizerGroup, init_tokenizer_from_configs)
  43. from aphrodite.version import __version__ as APHRODITE_VERSION
  44. _LOCAL_LOGGING_INTERVAL_SEC = 5
  45. APHRODITE_USE_RAY_SPMD_WORKER = bool(
  46. os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))
  47. def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
  48. config = try_get_generation_config(
  49. model_config.model,
  50. trust_remote_code=model_config.trust_remote_code,
  51. revision=model_config.revision,
  52. )
  53. if config is None:
  54. return {}
  55. return config.to_diff_dict()
# Request-output type checked by AphroditeEngine.validate_output(s);
# constrained to RequestOutput or EmbeddingRequestOutput.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
  57. class AphroditeEngine:
  58. """An LLM engine that receives requests and generates texts.
  59. This is the main class for the Aphrodite engine. It receives requests
  60. from clients and generates texts from the LLM. It includes a tokenizer, a
  61. language model (possibly distributed across multiple GPUs), and GPU memory
  62. space allocated for intermediate states (aka KV cache). This class utilizes
  63. iteration-level scheduling and efficient memory management to maximize the
  64. serving throughput.
  65. The `LLM` class wraps this class for offline batched inference and the
  66. `AsyncAphrodite` class wraps this class for online serving.
  67. NOTE: The config arguments are derived from the `EngineArgs` class. For the
  68. comprehensive list of arguments, see `EngineArgs`.
  69. Args:
  70. model_config: The configuration related to the LLM model.
  71. cache_config: The configuration related to the KV cache memory
  72. management.
  73. parallel_config: The configuration related to distributed execution.
  74. scheduler_config: The configuration related to the request scheduler.
  75. device_config: The configuration related to the device.
  76. lora_config (Optional): The configuration related to serving multi-LoRA.
  77. multimodal_config (Optional): The configuration related to multimodal
  78. models.
  79. speculative_config (Optional): The configuration related to speculative
  80. decoding.
  81. executor_class: The model executor class for managing distributed
  82. execution.
  83. prompt_adapter_config (Optional): The configuration related to serving
  84. prompt adapters.
  85. log_stats: Whether to log statistics.
  86. """
    # Class-wide switch read by validate_output() / validate_outputs();
    # turned on temporarily by the enable_output_validation() context manager.
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""
  89. @classmethod
  90. @contextmanager
  91. def enable_output_validation(cls):
  92. cls.DO_VALIDATE_OUTPUT = True
  93. yield
  94. cls.DO_VALIDATE_OUTPUT = False
  95. @classmethod
  96. def validate_output(
  97. cls,
  98. output: object,
  99. output_type: Type[_O],
  100. ) -> _O:
  101. do_validate = cls.DO_VALIDATE_OUTPUT
  102. if ((TYPE_CHECKING or do_validate)
  103. and not isinstance(output, output_type)):
  104. raise TypeError(f"Expected output of type {output_type}, "
  105. f"but found type {type(output)}")
  106. return output
  107. @classmethod
  108. def validate_outputs(
  109. cls,
  110. outputs: GenericSequence[object],
  111. output_type: Type[_O],
  112. ) -> List[_O]:
  113. do_validate = cls.DO_VALIDATE_OUTPUT
  114. outputs_: List[_O]
  115. if TYPE_CHECKING or do_validate:
  116. outputs_ = []
  117. for output in outputs:
  118. if not isinstance(output, output_type):
  119. raise TypeError(f"Expected output of type {output_type}, "
  120. f"but found type {type(output)}")
  121. outputs_.append(output)
  122. else:
  123. outputs_ = outputs
  124. return outputs_
    # Tokenizer group used to encode prompts and to build the Detokenizer;
    # None when model_config.skip_tokenizer_init is set (see __init__).
    tokenizer: Optional[BaseTokenizerGroup]
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> None:
        """Set up the engine from already-built configs.

        In order, this logs the effective configuration, stores the configs,
        creates the tokenizer/detokenizer (unless
        ``model_config.skip_tokenizer_init``), instantiates
        ``executor_class``, profiles and allocates the KV cache (skipped in
        embedding mode), creates one scheduler per pipeline-parallel stage,
        sets up stat loggers when ``log_stats`` is true, and builds the
        sequence output processor.

        Args:
            load_config: Model-loading configuration (its ``load_format`` is
                logged and it is forwarded to the executor).
            decoding_config: Decoding configuration; ``DecodingConfig()`` is
                used when None.
            stat_loggers: Pre-built stat loggers keyed by name. When None and
                ``log_stats`` is true, a ``"logging"`` and a
                ``"prometheus"`` logger are created.

        The remaining arguments are described in the class docstring.
        """
        # The optional ``aphrodite.commit_id`` module only changes the
        # startup banner (with or without a commit hash).
        try:
            import aphrodite.commit_id
            commit_id = True
        except ImportError:
            commit_id = False
        # Values printed in the startup banner below.
        config_dict = {
            "Model": model_config.model,
            "Speculative Config": speculative_config,
            "DataType": model_config.dtype,
            "Model Load Format": load_config.load_format,
            "Tensor Parallel Size": parallel_config.tensor_parallel_size,
            "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
            "Disable Custom All-Reduce":
            parallel_config.disable_custom_all_reduce,
            "Quantization Format": model_config.quantization,
            "Context Length": model_config.max_model_len,
            "Enforce Eager Mode": model_config.enforce_eager,
            "Prefix Caching": cache_config.enable_prefix_caching,
            "KV Cache DataType": cache_config.cache_dtype,
            "Device": device_config.device,
            "Rope Scaling": model_config.rope_scaling,
            # NOTE(review): this logs the raw decoding_config argument
            # (skipped when None), not the DecodingConfig() default that is
            # assigned to self.decoding_config below.
            "Guided Decoding Backend": decoding_config
        }
        logger.info("-" * 85)
        if not commit_id:
            logger.info(
                f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
                "with the following config:")
        else:
            # NOTE(review): __short_commit__ is read from the top-level
            # ``aphrodite`` package, not from ``aphrodite.commit_id``;
            # confirm the package re-exports it.
            logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
                        f"commit {aphrodite.__short_commit__}) with the "
                        "following config:")
        # Skip unset (None) entries, as well as the "auto" value of the load
        # format and KV-cache dtype, to keep the banner short.
        for key, value in config_dict.items():
            if value is not None and not ((key == "Model Load Format" or key ==\
                                           "KV Cache DataType") and value == \
                                          "auto"):
                logger.info(f"{key} = {value!r}")
        logger.info("-" * 85)
        # TODO: Print more configs in debug mode.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.log_stats = log_stats
        # With skip_tokenizer_init there is no tokenizer or detokenizer;
        # get_tokenizer_group() raises ValueError in that case.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.tokenizer = None
            self.detokenizer = None
        # Source of sequence ids (consumed in _add_processed_request).
        self.seq_counter = Counter()
        # Generation-config fields applied to each request's SamplingParams
        # (see _create_sequence_group_with_sampling).
        self.generation_config_fields = _load_generation_config_dict(
            model_config)
        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)
        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            multimodal_config=multimodal_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
        )
        # KV-cache profiling/allocation is skipped in embedding mode.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()
        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()
        # Create the scheduler.
        # NOTE: the cache_config here has been updated with the number of
        # GPU and CPU blocks, which are profiled in the distributed executor
        # (by _initialize_kv_caches above, when not in embedding mode).
        # One scheduler per pipeline-parallel stage ("virtual engine"); new
        # requests go to the least-loaded one (_add_processed_request).
        self.scheduler = [
            Scheduler(scheduler_config, cache_config, lora_config,
                      parallel_config.pipeline_parallel_size)
            for _ in range(parallel_config.pipeline_parallel_size)
        ]
        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                # Only the default logger set is known to contain a
                # "prometheus" entry, so the cache config is reported here.
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)
        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))
  262. def _initialize_kv_caches(self) -> None:
  263. """Initialize the KV cache in the worker(s).
  264. The workers will determine the number of blocks in both the GPU cache
  265. and the swap CPU cache.
  266. """
  267. num_gpu_blocks, num_cpu_blocks = (
  268. self.model_executor.determine_num_available_blocks())
  269. if self.cache_config.num_gpu_blocks_override is not None:
  270. num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
  271. logger.info(f"Overriding {num_gpu_blocks=} with "
  272. f"{num_gpu_blocks_override=}")
  273. num_gpu_blocks = num_gpu_blocks_override
  274. self.cache_config.num_gpu_blocks = num_gpu_blocks
  275. self.cache_config.num_cpu_blocks = num_cpu_blocks
  276. self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  277. @classmethod
  278. def _get_executor_cls(cls,
  279. engine_config: EngineConfig) -> Type[ExecutorBase]:
  280. distributed_executor_backend = (
  281. engine_config.parallel_config.distributed_executor_backend)
  282. # Initialize the cluster and specify the executor class.
  283. if isinstance(distributed_executor_backend, type):
  284. if not issubclass(distributed_executor_backend, ExecutorBase):
  285. raise TypeError(
  286. "distributed_executor_backend must be a subclass of "
  287. f"ExecutorBase. Got {distributed_executor_backend}.")
  288. if distributed_executor_backend.uses_ray: # type: ignore
  289. initialize_ray_cluster(engine_config.parallel_config)
  290. executor_class = distributed_executor_backend
  291. elif engine_config.device_config.device_type == "neuron":
  292. from aphrodite.executor.neuron_executor import NeuronExecutor
  293. executor_class = NeuronExecutor
  294. elif engine_config.device_config.device_type == "tpu":
  295. if distributed_executor_backend == "ray":
  296. initialize_ray_cluster(engine_config.parallel_config)
  297. from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
  298. executor_class = RayTPUExecutor
  299. else:
  300. assert distributed_executor_backend is None
  301. from aphrodite.executor.tpu_executor import TPUExecutor
  302. executor_class = TPUExecutor
  303. elif engine_config.device_config.device_type == "cpu":
  304. from aphrodite.executor.cpu_executor import CPUExecutor
  305. executor_class = CPUExecutor
  306. elif engine_config.device_config.device_type == "openvino":
  307. from aphrodite.executor.openvino_executor import OpenVINOExecutor
  308. executor_class = OpenVINOExecutor
  309. elif engine_config.device_config.device_type == "xpu":
  310. if distributed_executor_backend == "ray":
  311. initialize_ray_cluster(engine_config.parallel_config)
  312. from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
  313. executor_class = RayXPUExecutor
  314. else:
  315. from aphrodite.executor.xpu_executor import XPUExecutor
  316. executor_class = XPUExecutor
  317. elif distributed_executor_backend == "ray":
  318. initialize_ray_cluster(engine_config.parallel_config)
  319. from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
  320. executor_class = RayGPUExecutor
  321. elif distributed_executor_backend == "mp":
  322. from aphrodite.executor.multiproc_gpu_executor import (
  323. MultiprocessingGPUExecutor)
  324. assert not APHRODITE_USE_RAY_SPMD_WORKER, (
  325. "multiprocessing distributed executor backend does not "
  326. "support APHRODITE_USE_RAY_SPMD_WORKER=1")
  327. executor_class = MultiprocessingGPUExecutor
  328. else:
  329. from aphrodite.executor.gpu_executor import GPUExecutor
  330. executor_class = GPUExecutor
  331. return executor_class
  332. @classmethod
  333. def from_engine_args(
  334. cls,
  335. engine_args: EngineArgs,
  336. stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
  337. ) -> "AphroditeEngine":
  338. """Creates an LLM engine from the engine arguments."""
  339. # Create the engine configs.
  340. engine_config = engine_args.create_engine_config()
  341. executor_class = cls._get_executor_cls(engine_config)
  342. # Create the LLM engine.
  343. engine = cls(
  344. **engine_config.to_dict(),
  345. executor_class=executor_class,
  346. log_stats=not engine_args.disable_log_stats,
  347. stat_loggers=stat_loggers,
  348. )
  349. return engine
  350. def __reduce__(self):
  351. # This is to ensure that the AphroditeEngine is not referenced in
  352. # the closure used to initialize Ray worker actors
  353. raise RuntimeError("AphroditeEngine should not be pickled!")
  354. def __del__(self):
  355. # Shutdown the model executor when engine is garbage collected.
  356. # Use getattr since __init__ can fail before the field is set
  357. if model_executor := getattr(self, "model_executor", None):
  358. model_executor.shutdown()
  359. MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
  360. "skip_tokenizer_init is True")
  361. def get_tokenizer_group(
  362. self,
  363. fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
  364. if self.tokenizer is None:
  365. raise ValueError(fail_msg)
  366. return self.tokenizer
  367. def get_tokenizer(
  368. self,
  369. lora_request: Optional[LoRARequest] = None
  370. ) -> "PreTrainedTokenizer":
  371. return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
  372. def get_tokenizer_for_seq(self,
  373. sequence: Sequence) -> "PreTrainedTokenizer":
  374. return self.get_tokenizer_group().get_lora_tokenizer(
  375. sequence.lora_request)
  376. def _init_tokenizer(self) -> BaseTokenizerGroup:
  377. return init_tokenizer_from_configs(
  378. model_config=self.model_config,
  379. scheduler_config=self.scheduler_config,
  380. parallel_config=self.parallel_config,
  381. enable_lora=bool(self.lora_config))
  382. def _verify_args(self) -> None:
  383. self.model_config.verify_with_parallel_config(self.parallel_config)
  384. self.cache_config.verify_with_parallel_config(self.parallel_config)
  385. if self.lora_config:
  386. self.lora_config.verify_with_model_config(self.model_config)
  387. self.lora_config.verify_with_scheduler_config(
  388. self.scheduler_config)
  389. if self.prompt_adapter_config:
  390. self.prompt_adapter_config.verify_with_model_config(
  391. self.model_config)
  392. def _get_eos_token_id(
  393. self, lora_request: Optional[LoRARequest]) -> Optional[int]:
  394. if self.tokenizer is None:
  395. logger.warning("Using None for EOS token id because tokenizer "
  396. "is not initialized")
  397. return None
  398. return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
  399. def _add_processed_request(
  400. self,
  401. request_id: str,
  402. processed_inputs: LLMInputs,
  403. params: Union[SamplingParams, PoolingParams],
  404. arrival_time: float,
  405. lora_request: Optional[LoRARequest],
  406. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  407. ) -> None:
  408. # Create the sequences.
  409. block_size = self.cache_config.block_size
  410. seq_id = next(self.seq_counter)
  411. eos_token_id = self._get_eos_token_id(lora_request)
  412. seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
  413. lora_request, prompt_adapter_request)
  414. # Create a SequenceGroup based on SamplingParams or PoolingParams
  415. if isinstance(params, SamplingParams):
  416. seq_group = self._create_sequence_group_with_sampling(
  417. request_id,
  418. seq,
  419. params,
  420. arrival_time=arrival_time,
  421. lora_request=lora_request,
  422. prompt_adapter_request=prompt_adapter_request,
  423. )
  424. elif isinstance(params, PoolingParams):
  425. seq_group = self._create_sequence_group_with_pooling(
  426. request_id,
  427. seq,
  428. params,
  429. arrival_time=arrival_time,
  430. lora_request=lora_request,
  431. prompt_adapter_request=prompt_adapter_request,
  432. )
  433. else:
  434. raise ValueError(
  435. "Either SamplingParams or PoolingParams must be provided.")
  436. # Add the sequence group to the scheduler with least unfinished seqs.
  437. costs = [
  438. scheduler.get_num_unfinished_seq_groups()
  439. for scheduler in self.scheduler
  440. ]
  441. min_cost_scheduler = self.scheduler[costs.index(min(costs))]
  442. min_cost_scheduler.add_seq_group(seq_group)
  443. def stop_remote_worker_execution_loop(self) -> None:
  444. self.model_executor.stop_remote_worker_execution_loop()
  445. def process_model_inputs(
  446. self,
  447. request_id: str,
  448. inputs: PromptInputs,
  449. lora_request: Optional[LoRARequest] = None,
  450. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  451. ) -> LLMInputs:
  452. if isinstance(inputs, str):
  453. inputs = {"prompt": inputs}
  454. if "prompt_token_ids" not in inputs:
  455. tokenizer = self.get_tokenizer_group("prompts must be None if "
  456. "skip_tokenizer_init is True")
  457. prompt_token_ids = tokenizer.encode(request_id=request_id,
  458. prompt=inputs["prompt"],
  459. lora_request=lora_request)
  460. else:
  461. prompt_token_ids = inputs["prompt_token_ids"]
  462. if prompt_adapter_request:
  463. prompt_token_ids = \
  464. [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
  465. + prompt_token_ids
  466. llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
  467. prompt=inputs.get("prompt"),
  468. multi_modal_data=inputs.get("multi_modal_data"))
  469. return self.input_processor(llm_inputs)
  470. def add_request(
  471. self,
  472. request_id: str,
  473. inputs: PromptInputs,
  474. params: Union[SamplingParams, PoolingParams],
  475. arrival_time: Optional[float] = None,
  476. lora_request: Optional[LoRARequest] = None,
  477. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  478. ) -> None:
  479. """Add a request to the engine's request pool.
  480. The request is added to the request pool and will be processed by the
  481. scheduler as `engine.step()` is called. The exact scheduling policy is
  482. determined by the scheduler.
  483. Args:
  484. request_id: The unique ID of the request.
  485. prompt: The prompt string. Can be None if prompt_token_ids is
  486. provided.
  487. params: Parameters for sampling or pooling. SamplingParams
  488. for text generation. PoolingParams for pooling.
  489. prompt_token_ids: The token IDs of the prompt. If None, we
  490. use the tokenizer to convert the prompts to token IDs.
  491. arrival_time: The arrival time of the request. If None, we use
  492. the current monotonic time.
  493. multi_modal_data: Multi modal data per request.
  494. Details:
  495. - Set arrival_time to the current time if it is None.
  496. - Set prompt_token_ids to the encoded prompt if it is None.
  497. - Create `best_of` number of :class:`~aphrodite.common.sequence`
  498. objects.
  499. - Create a :class:`~aphrodite.common.sequenceGroup` object
  500. from the list of :class:`~aphrodite.common.sequence`.
  501. - Add the :class:`~aphrodite.common.sequenceGroup` object to the
  502. scheduler.
  503. Example:
  504. >>> # initialize engine
  505. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  506. >>> # set request arguments
  507. >>> example_prompt = "Who is the president of the United States?"
  508. >>> sampling_params = SamplingParams(temperature=0.0)
  509. >>> request_id = 0
  510. >>>
  511. >>> # add the request to the engine
  512. >>> engine.add_request(
  513. >>> str(request_id),
  514. >>> example_prompt,
  515. >>> SamplingParams(temperature=0.0))
  516. >>> # continue the request processing
  517. >>> ...
  518. """
  519. if lora_request is not None and not self.lora_config:
  520. raise ValueError(f"Got lora_request {lora_request} but LoRA is "
  521. "not enabled!")
  522. if arrival_time is None:
  523. arrival_time = time.time()
  524. processed_inputs = self.process_model_inputs(
  525. request_id=request_id,
  526. inputs=inputs,
  527. lora_request=lora_request,
  528. prompt_adapter_request=prompt_adapter_request)
  529. self._add_processed_request(
  530. request_id=request_id,
  531. processed_inputs=processed_inputs,
  532. params=params,
  533. arrival_time=arrival_time,
  534. lora_request=lora_request,
  535. prompt_adapter_request=prompt_adapter_request,
  536. )
  537. def _create_sequence_group_with_sampling(
  538. self,
  539. request_id: str,
  540. seq: Sequence,
  541. sampling_params: SamplingParams,
  542. arrival_time: float,
  543. lora_request: Optional[LoRARequest],
  544. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  545. ) -> SequenceGroup:
  546. """Creates a SequenceGroup with SamplingParams."""
  547. max_logprobs = self.get_model_config().max_logprobs
  548. if (sampling_params.logprobs
  549. and sampling_params.logprobs > max_logprobs) or (
  550. sampling_params.prompt_logprobs
  551. and sampling_params.prompt_logprobs > max_logprobs):
  552. raise ValueError(f"Cannot request more than "
  553. f"{max_logprobs} logprobs.")
  554. # Defensive copy of SamplingParams, which are used by the sampler,
  555. # this doesn't deep-copy LogitsProcessor objects
  556. sampling_params = sampling_params.clone()
  557. sampling_params.update_from_generation_config(
  558. self.generation_config_fields, seq.eos_token_id)
  559. # Create the sequence group.
  560. seq_group = SequenceGroup(
  561. request_id=request_id,
  562. seqs=[seq],
  563. arrival_time=arrival_time,
  564. sampling_params=sampling_params,
  565. lora_request=lora_request,
  566. prompt_adapter_request=prompt_adapter_request)
  567. return seq_group
  568. def _create_sequence_group_with_pooling(
  569. self,
  570. request_id: str,
  571. seq: Sequence,
  572. pooling_params: PoolingParams,
  573. arrival_time: float,
  574. lora_request: Optional[LoRARequest],
  575. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  576. ) -> SequenceGroup:
  577. """Creates a SequenceGroup with PoolingParams."""
  578. # Defensive copy of PoolingParams, which are used by the pooler
  579. pooling_params = pooling_params.clone()
  580. # Create the sequence group.
  581. seq_group = SequenceGroup(
  582. request_id=request_id,
  583. seqs=[seq],
  584. arrival_time=arrival_time,
  585. lora_request=lora_request,
  586. pooling_params=pooling_params,
  587. prompt_adapter_request=prompt_adapter_request)
  588. return seq_group
  589. def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
  590. """Aborts a request(s) with the given ID.
  591. Args:
  592. request_id: The ID(s) of the request to abort.
  593. Details:
  594. - Refer to the
  595. :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
  596. from class :class:`~aphrodite.processing.scheduler.Scheduler`.
  597. Example:
  598. >>> # initialize engine and add a request with request_id
  599. >>> request_id = str(0)
  600. >>> # abort the request
  601. >>> engine.abort_request(request_id)
  602. """
  603. for scheduler in self.scheduler:
  604. scheduler.abort_seq_group(request_id)
  605. def get_model_config(self) -> ModelConfig:
  606. """Gets the model configuration."""
  607. return self.model_config
  608. def get_parallel_config(self) -> ParallelConfig:
  609. """Gets the parallel configuration."""
  610. return self.parallel_config
  611. def get_decoding_config(self) -> DecodingConfig:
  612. """Gets the decoding configuration."""
  613. return self.decoding_config
  614. def get_scheduler_config(self) -> SchedulerConfig:
  615. """Gets the scheduler configuration."""
  616. return self.scheduler_config
  617. def get_lora_config(self) -> LoRAConfig:
  618. """Gets the LoRA configuration."""
  619. return self.lora_config
  620. def get_num_unfinished_requests(self) -> int:
  621. """Gets the number of unfinished requests."""
  622. return sum(scheduler.get_num_unfinished_seq_groups()
  623. for scheduler in self.scheduler)
  624. def has_unfinished_requests(self) -> bool:
  625. """Returns True if there are unfinished requests."""
  626. return any(scheduler.has_unfinished_seqs()
  627. for scheduler in self.scheduler)
  628. def has_unfinished_requests_for_virtual_engine(
  629. self, virtual_engine: int) -> bool:
  630. """
  631. Returns True if there are unfinished requests for the virtual engine.
  632. """
  633. return self.scheduler[virtual_engine].has_unfinished_seqs()
  634. def _process_sequence_group_outputs(
  635. self,
  636. seq_group: SequenceGroup,
  637. outputs: List[EmbeddingSequenceGroupOutput],
  638. ) -> None:
  639. seq_group.embeddings = outputs[0].embeddings
  640. for seq in seq_group.get_seqs():
  641. seq.status = SequenceStatus.FINISHED_STOPPED
  642. return
  643. def _process_model_outputs(
  644. self,
  645. output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
  646. scheduled_seq_groups: List[ScheduledSequenceGroup],
  647. ignored_seq_groups: List[SequenceGroup],
  648. seq_group_metadata_list: List[SequenceGroupMetadata],
  649. ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
  650. """Apply the model output to the sequences in the scheduled seq groups.
  651. Returns RequestOutputs that can be returned to the client.
  652. """
  653. now = time.time()
  654. # Organize outputs by [sequence group][step] instead of
  655. # [step][sequence group].
  656. output_by_sequence_group = create_output_by_sequence_group(
  657. output, num_seq_groups=len(scheduled_seq_groups))
  658. # Update the scheduled sequence groups with the model outputs.
  659. for scheduled_seq_group, outputs, seq_group_meta in zip(
  660. scheduled_seq_groups, output_by_sequence_group,
  661. seq_group_metadata_list):
  662. seq_group = scheduled_seq_group.seq_group
  663. seq_group.update_num_computed_tokens(
  664. scheduled_seq_group.token_chunk_size)
  665. if self.model_config.embedding_mode:
  666. self._process_sequence_group_outputs(seq_group, outputs)
  667. continue
  668. self.output_processor.process_prompt_logprob(seq_group, outputs)
  669. if seq_group_meta.do_sample:
  670. self.output_processor.process_outputs(seq_group, outputs)
  671. # Free the finished sequence groups.
  672. for scheduler in self.scheduler:
  673. scheduler.free_finished_seq_groups()
  674. # Create the outputs.
  675. request_outputs: List[Union[RequestOutput,
  676. EmbeddingRequestOutput]] = []
  677. for scheduled_seq_group in scheduled_seq_groups:
  678. seq_group = scheduled_seq_group.seq_group
  679. seq_group.maybe_set_first_token_time(now)
  680. request_output = RequestOutputFactory.create(seq_group)
  681. request_outputs.append(request_output)
  682. for seq_group in ignored_seq_groups:
  683. request_output = RequestOutputFactory.create(seq_group)
  684. request_outputs.append(request_output)
  685. return request_outputs
  686. def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
  687. """Performs one decoding iteration and returns newly generated results.
  688. .. figure:: https://i.imgur.com/sv2HssD.png
  689. :alt: Overview of the step function
  690. :align: center
  691. Overview of the step function.
  692. Details:
  693. - Step 1: Schedules the sequences to be executed in the next
  694. iteration and the token blocks to be swapped in/out/copy.
  695. - Depending on the scheduling policy,
  696. sequences may be `preempted/reordered`.
  697. - A Sequence Group (SG) refer to a group of sequences
  698. that are generated from the same prompt.
  699. - Step 2: Calls the distributed executor to execute the model.
  700. - Step 3: Processes the model output. This mainly includes:
  701. - Decodes the relevant outputs.
  702. - Updates the scheduled sequence groups with model outputs
  703. based on its `sampling parameters` (`use_beam_search` or not).
  704. - Frees the finished sequence groups.
  705. - Finally, it creates and returns the newly generated results.
  706. Example:
  707. >>> # Please see the example/ folder for more detailed examples.
  708. >>>
  709. >>> # initialize engine and request arguments
  710. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  711. >>> example_inputs = [(0, "What is LLM?",
  712. >>> SamplingParams(temperature=0.0))]
  713. >>>
  714. >>> # Start the engine with an event loop
  715. >>> while True:
  716. >>> if example_inputs:
  717. >>> req_id, prompt, sampling_params = example_inputs.pop(0)
  718. >>> engine.add_request(str(req_id), prompt, sampling_params)
  719. >>>
  720. >>> # continue the request processing
  721. >>> request_outputs = engine.step()
  722. >>> for request_output in request_outputs:
  723. >>> if request_output.finished:
  724. >>> # return or show the request output
  725. >>>
  726. >>> if not (engine.has_unfinished_requests() or example_inputs):
  727. >>> break
  728. """
  729. if self.parallel_config.pipeline_parallel_size > 1:
  730. raise NotImplementedError(
  731. "Pipeline parallelism is only supported through AsyncAphrodite "
  732. "as performance will be severely degraded otherwise.")
  733. seq_group_metadata_list, scheduler_outputs = self.scheduler[
  734. 0].schedule()
  735. if not scheduler_outputs.is_empty():
  736. finished_requests_ids = self.scheduler[
  737. 0].get_and_reset_finished_requests_ids()
  738. execute_model_req = ExecuteModelRequest(
  739. seq_group_metadata_list=seq_group_metadata_list,
  740. blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
  741. blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
  742. blocks_to_copy=scheduler_outputs.blocks_to_copy,
  743. num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
  744. running_queue_size=scheduler_outputs.running_queue_size,
  745. finished_requests_ids=finished_requests_ids,
  746. )
  747. output = self.model_executor.execute_model(
  748. execute_model_req=execute_model_req)
  749. else:
  750. output = []
  751. request_outputs = self._process_model_outputs(
  752. output, scheduler_outputs.scheduled_seq_groups,
  753. scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
  754. # Log stats.
  755. self.do_log_stats(scheduler_outputs, output)
  756. if not self.has_unfinished_requests():
  757. # Stop the execute model loop in parallel workers until there are
  758. # more requests to process. This avoids waiting indefinitely in
  759. # torch.distributed ops which may otherwise timeout, and unblocks
  760. # the RPC thread in the workers so that they can process any other
  761. # queued control plane messages, such as add/remove lora adapters.
  762. self.model_executor.stop_remote_worker_execution_loop()
  763. return request_outputs
  764. def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
  765. if logger_name in self.stat_loggers:
  766. raise KeyError(f"Logger with name {logger_name} already exists.")
  767. self.stat_loggers[logger_name] = logger
  768. def remove_logger(self, logger_name: str) -> None:
  769. if logger_name not in self.stat_loggers:
  770. raise KeyError(f"Logger with name {logger_name} does not exist.")
  771. del self.stat_loggers[logger_name]
  772. def do_log_stats(
  773. self,
  774. scheduler_outputs: Optional[SchedulerOutputs] = None,
  775. model_output: Optional[List[SamplerOutput]] = None) -> None:
  776. """Forced log when no requests active."""
  777. if self.log_stats:
  778. stats = self._get_stats(scheduler_outputs, model_output)
  779. for loggers in self.stat_loggers.values():
  780. loggers.log(stats)
  781. def _get_stats(
  782. self,
  783. scheduler_outputs: Optional[SchedulerOutputs],
  784. model_output: Optional[List[SamplerOutput]] = None) -> Stats:
  785. """Get Stats to be Logged to Prometheus.
  786. Args:
  787. scheduler_outputs: Optional, used to populate metrics related to
  788. the scheduled batch,
  789. model_output: Optional, used to emit speculative decoding metrics
  790. which are created by the workers.
  791. """
  792. now = time.time()
  793. # System State
  794. # Scheduler State
  795. num_running_sys = sum(
  796. len(scheduler.running) for scheduler in self.scheduler)
  797. num_swapped_sys = sum(
  798. len(scheduler.swapped) for scheduler in self.scheduler)
  799. num_waiting_sys = sum(
  800. len(scheduler.waiting) for scheduler in self.scheduler)
  801. # KV Cache Usage in %
  802. num_total_gpu = self.cache_config.num_gpu_blocks
  803. gpu_cache_usage_sys = 0.
  804. if num_total_gpu is not None:
  805. num_free_gpu = sum(
  806. scheduler.block_manager.get_num_free_gpu_blocks()
  807. for scheduler in self.scheduler)
  808. gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
  809. num_total_cpu = self.cache_config.num_cpu_blocks
  810. cpu_cache_usage_sys = 0.
  811. if num_total_cpu is not None and num_total_cpu > 0:
  812. num_free_cpu = sum(
  813. scheduler.block_manager.get_num_free_cpu_blocks()
  814. for scheduler in self.scheduler)
  815. cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
  816. # Iteration stats
  817. num_prompt_tokens_iter = 0
  818. num_generation_tokens_iter = 0
  819. time_to_first_tokens_iter: List[float] = []
  820. time_per_output_tokens_iter: List[float] = []
  821. num_preemption_iter = (0 if scheduler_outputs is None else
  822. scheduler_outputs.preempted)
  823. # Request stats
  824. # Latency
  825. time_e2e_requests: List[float] = []
  826. # Metadata
  827. num_prompt_tokens_requests: List[int] = []
  828. num_generation_tokens_requests: List[int] = []
  829. best_of_requests: List[int] = []
  830. n_requests: List[int] = []
  831. finished_reason_requests: List[str] = []
  832. # NOTE: This loop assumes prefill seq_groups are before
  833. # decode seq_groups in scheduled_seq_groups.
  834. if scheduler_outputs is not None:
  835. num_generation_tokens_from_prefill_groups = 0.
  836. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
  837. # the len of scheduler_outputs.scheduled_seq_groups is !=
  838. # scheduler_outputs.num_prefill_groups, this means that
  839. # chunked prefills have been detected.
  840. for idx, scheduled_seq_group in enumerate(
  841. scheduler_outputs.scheduled_seq_groups):
  842. group_was_prefill = idx < scheduler_outputs.num_prefill_groups
  843. seq_group = scheduled_seq_group.seq_group
  844. # NOTE: a seq_group that completed all of its prefill tokens
  845. # in the last iteration will have seq_group.is_prefill() = False
  846. # with group_was_prefill = True
  847. if group_was_prefill:
  848. # Number of prompt tokens.
  849. num_prompt_tokens_iter += (
  850. scheduled_seq_group.token_chunk_size)
  851. # If the seq_group just finished the prefill state
  852. # get TTFT.
  853. if not seq_group.is_prefill():
  854. latency = seq_group.get_last_latency(now)
  855. time_to_first_tokens_iter.append(latency)
  856. # One generation token per finished prefill.
  857. num_generation_tokens_from_prefill_groups += (
  858. seq_group.num_seqs())
  859. else:
  860. # TPOTs.
  861. latency = seq_group.get_last_latency(now)
  862. time_per_output_tokens_iter.append(latency)
  863. # Because of chunked prefill, we can have a single sequence
  864. # group that does multiple prompt_runs. To prevent logging
  865. # the same metadata more than once per request, we standardize
  866. # on logging request level information for finished requests,
  867. # which can only happen once.
  868. if seq_group.is_finished():
  869. # Latency timings
  870. time_e2e_requests.append(now -
  871. seq_group.metrics.arrival_time)
  872. # Metadata
  873. num_prompt_tokens_requests.append(
  874. len(seq_group.prompt_token_ids))
  875. num_generation_tokens_requests.extend([
  876. seq.get_output_len()
  877. for seq in seq_group.get_finished_seqs()
  878. ])
  879. if seq_group.sampling_params is not None:
  880. best_of_requests.append(
  881. seq_group.sampling_params.best_of)
  882. n_requests.append(seq_group.sampling_params.n)
  883. finished_reason_requests.extend([
  884. SequenceStatus.get_finished_reason(seq.status)
  885. for seq in seq_group.get_finished_seqs()
  886. ])
  887. # Number of generation tokens.
  888. # num_batched_tokens equals the number of prompt_tokens plus the
  889. # number of decode_tokens in a single iteration. So,
  890. # num_generation_tokens = num_batched_tokens - num_prompt_tokens
  891. # + num_generation_tokens_from_prefill_groups (since we generate
  892. # one token on prefills on iters where the prefill finishes).
  893. num_generation_tokens_iter = (
  894. scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
  895. num_generation_tokens_from_prefill_groups)
  896. # Spec decode, if enabled, emits specialized metrics from the worker in
  897. # sampler output.
  898. if model_output and (model_output[0].spec_decode_worker_metrics
  899. is not None):
  900. spec_decode_metrics = model_output[0].spec_decode_worker_metrics
  901. else:
  902. spec_decode_metrics = None
  903. return Stats(
  904. now=now,
  905. # System stats
  906. # Scheduler State
  907. num_running_sys=num_running_sys,
  908. num_swapped_sys=num_swapped_sys,
  909. num_waiting_sys=num_waiting_sys,
  910. # KV Cache Usage in %
  911. gpu_cache_usage_sys=gpu_cache_usage_sys,
  912. cpu_cache_usage_sys=cpu_cache_usage_sys,
  913. # Iteration stats
  914. num_prompt_tokens_iter=num_prompt_tokens_iter,
  915. num_generation_tokens_iter=num_generation_tokens_iter,
  916. time_to_first_tokens_iter=time_to_first_tokens_iter,
  917. time_per_output_tokens_iter=time_per_output_tokens_iter,
  918. spec_decode_metrics=spec_decode_metrics,
  919. num_preemption_iter=num_preemption_iter,
  920. # Request stats
  921. # Latency
  922. time_e2e_requests=time_e2e_requests,
  923. # Metadata
  924. num_prompt_tokens_requests=num_prompt_tokens_requests,
  925. num_generation_tokens_requests=num_generation_tokens_requests,
  926. best_of_requests=best_of_requests,
  927. n_requests=n_requests,
  928. finished_reason_requests=finished_reason_requests,
  929. )
  930. def add_lora(self, lora_request: LoRARequest) -> bool:
  931. return self.model_executor.add_lora(lora_request)
  932. def remove_lora(self, lora_id: int) -> bool:
  933. return self.model_executor.remove_lora(lora_id)
  934. def list_loras(self) -> List[int]:
  935. return self.model_executor.list_loras()
  936. def pin_lora(self, lora_id: int) -> bool:
  937. return self.model_executor.pin_lora(lora_id)
  938. def add_prompt_adapter(
  939. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  940. return self.model_executor.add_prompt_adapter(prompt_adapter_request)
  941. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  942. return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
  943. def list_prompt_adapters(self) -> List[int]:
  944. return self.model_executor.list_prompt_adapters()
  945. def check_health(self) -> None:
  946. if self.tokenizer:
  947. self.tokenizer.check_health()
  948. self.model_executor.check_health()
# NOTE(review): assumed to be a module-level statement, i.e. it runs once
# when this module is imported. setup_logger is defined/imported elsewhere
# and presumably configures the logging backend — confirm against its
# definition.
setup_logger()