aphrodite_engine.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867
  1. import functools
  2. import time
  3. from collections import deque
  4. from contextlib import contextmanager
  5. from dataclasses import dataclass, field
  6. from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
  7. Optional)
  8. from typing import Sequence as GenericSequence
  9. from typing import Set, Tuple, Type, Union
  10. import torch
  11. from loguru import logger
  12. from typing_extensions import TypeVar, assert_never
  13. import aphrodite.common.envs as envs
  14. from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
  15. EngineConfig, LoadConfig, LoRAConfig,
  16. ModelConfig, ParallelConfig,
  17. PromptAdapterConfig, SchedulerConfig,
  18. SpeculativeConfig)
  19. from aphrodite.common.logger import setup_logger
  20. from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
  21. RequestOutputFactory)
  22. from aphrodite.common.pooling_params import PoolingParams
  23. from aphrodite.common.sampling_params import SamplingParams
  24. from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
  25. ExecuteModelRequest, Sequence,
  26. SequenceGroup, SequenceGroupMetadata,
  27. SequenceStatus)
  28. from aphrodite.common.utils import Counter, Device
  29. from aphrodite.engine.args_tools import EngineArgs
  30. from aphrodite.engine.metrics_types import StatLoggerBase, Stats
  31. from aphrodite.engine.output_processor.interfaces import (
  32. SequenceGroupOutputProcessor)
  33. from aphrodite.engine.output_processor.stop_checker import StopChecker
  34. from aphrodite.engine.output_processor.util import (
  35. create_output_by_sequence_group)
  36. from aphrodite.executor.executor_base import ExecutorBase
  37. from aphrodite.executor.ray_utils import initialize_ray_cluster
  38. from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
  39. InputRegistry, LLMInputs, PromptInputs,
  40. SingletonPromptInputs)
  41. from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
  42. from aphrodite.lora.request import LoRARequest
  43. from aphrodite.modeling.layers.sampler import SamplerOutput
  44. from aphrodite.multimodal import MultiModalDataDict
  45. from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
  46. SchedulerOutputs)
  47. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  48. from aphrodite.transformers_utils.config import try_get_generation_config
  49. from aphrodite.transformers_utils.detokenizer import Detokenizer
  50. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
  51. from aphrodite.transformers_utils.tokenizer_group import (
  52. BaseTokenizerGroup, init_tokenizer_from_configs)
  53. from aphrodite.version import __version__ as APHRODITE_VERSION
  # Interval in seconds, passed as `local_interval` to the default
  # "logging" and "prometheus" stat loggers created by AphroditeEngine.
  _LOCAL_LOGGING_INTERVAL_SEC: int = 5
  # Value of the APHRODITE_USE_RAY_SPMD_WORKER env flag, captured once when
  # this module is imported.
  APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
  def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
      """Return the model's generation config as a plain dict.

      The config is looked up for ``model_config.model``, honouring
      ``trust_remote_code`` and ``revision``. If no generation config is
      found, an empty dict is returned. Otherwise the result of the config's
      ``to_diff_dict()`` is returned.
      """
      generation_config = try_get_generation_config(
          model_config.model,
          trust_remote_code=model_config.trust_remote_code,
          revision=model_config.revision,
      )
      if generation_config is not None:
          return generation_config.to_diff_dict()
      return {}
  # Tokenizer-group type accepted by get_tokenizer_group(); when not
  # specified, it defaults to BaseTokenizerGroup itself.
  _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
  # Request-output types that validate_output()/validate_outputs() check.
  _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

  # (prompt, prompt_token_ids, multi_modal_data) for one prompt.
  PromptComponents = Tuple[
      Optional[str],
      List[int],
      Optional[MultiModalDataDict],
  ]
  # Same layout, except a decoder prompt's token ids may be absent (None).
  DecoderPromptComponents = Tuple[
      Optional[str],
      Optional[List[int]],
      Optional[MultiModalDataDict],
  ]
  @dataclass
  class SchedulerOutputState:
      """Caches the scheduler outputs for a virtual engine. Used for Multi-Step.

      A fresh instance holds nothing: every optional field is None and
      ``allow_async_output_proc`` is False.
      """

      # Metadata for the scheduled sequence groups, once cached.
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
      # The scheduler's outputs for the cached step, once cached.
      scheduler_outputs: Optional[SchedulerOutputs] = None
      # Whether async output processing was allowed for the cached step.
      allow_async_output_proc: bool = False
      # Most recent sampler output, if any.
      last_output: Optional[SamplerOutput] = None
  @dataclass
  class SchedulerContext:
      """Per-virtual-engine buffers used while processing model outputs.

      Every instance gets its own empty deque and list. ``default_factory``
      creates a new container per instance, so state is never shared
      between virtual engines.
      """

      # Queue of (outputs, seq_group_metadata_list, scheduler_outputs)
      # entries. The outputs element may be None.
      output_queue: Deque[Tuple[Optional[List[SamplerOutput]],
                                List[SequenceGroupMetadata],
                                SchedulerOutputs]] = field(
                                    default_factory=deque)
      # Buffer of request outputs (generation or embedding).
      request_outputs: List[Union[RequestOutput,
                                  EmbeddingRequestOutput]] = field(
                                      default_factory=list)
  class AphroditeEngine:
      """An LLM engine that receives requests and generates texts.

      This is the main class for the Aphrodite engine. It receives requests
      from clients and generates texts from the LLM. It includes a tokenizer,
      a language model (possibly distributed across multiple GPUs), and GPU
      memory space allocated for intermediate states (aka KV cache). This
      class utilizes iteration-level scheduling and efficient memory
      management to maximize the serving throughput.

      The `LLM` class wraps this class for offline batched inference and the
      `AsyncAphrodite` class wraps this class for online serving.

      NOTE: The config arguments are derived from the `EngineArgs` class. For
      the comprehensive list of arguments, see `EngineArgs`.

      Args:
          model_config: The configuration related to the LLM model.
          cache_config: The configuration related to the KV cache memory
              management.
          parallel_config: The configuration related to distributed
              execution.
          scheduler_config: The configuration related to the request
              scheduler.
          device_config: The configuration related to the device.
          load_config: The configuration related to model loading (e.g. the
              load format).
          lora_config (Optional): The configuration related to serving
              multi-LoRA.
          speculative_config (Optional): The configuration related to
              speculative decoding.
          decoding_config (Optional): The configuration related to decoding
              (logged as the guided decoding backend). A default
              `DecodingConfig()` is used when None.
          prompt_adapter_config (Optional): The configuration related to
              serving prompt adapters.
          executor_class: The model executor class for managing distributed
              execution.
          log_stats: Whether to log statistics.
          stat_loggers (Optional): Stat loggers keyed by name. When None and
              `log_stats` is True, default "logging" and "prometheus"
              loggers are created.
          input_registry: Registry used to create the input processor.
              Defaults to `INPUT_REGISTRY`.
          step_return_finished_only: If True, only finished request outputs
              are returned and no intermediate outputs.
      """
      # Class-wide switch read by validate_output()/validate_outputs() and
      # set by the enable_output_validation() context manager.
      DO_VALIDATE_OUTPUT: ClassVar[bool] = False
      """A flag to toggle whether to validate the type of request output."""
  @classmethod
  @contextmanager
  def enable_output_validation(cls):
      """Context manager that enables request-output type validation.

      While the ``with`` body runs, ``cls.DO_VALIDATE_OUTPUT`` is True, so
      ``validate_output``/``validate_outputs`` perform isinstance checks.
      On exit the flag's previous value is restored. This happens even if
      the body raises, and it also makes nested use safe.
      """
      previous = cls.DO_VALIDATE_OUTPUT
      cls.DO_VALIDATE_OUTPUT = True
      try:
          yield
      finally:
          # Without ``finally``, an exception raised inside the ``with`` body
          # would skip the reset and leave validation enabled class-wide.
          cls.DO_VALIDATE_OUTPUT = previous
  @classmethod
  def validate_output(
      cls,
      output: object,
      output_type: Type[_O],
  ) -> _O:
      """Return ``output`` unchanged, optionally checking its type first.

      At runtime ``TYPE_CHECKING`` is False, so the isinstance check runs
      only while ``cls.DO_VALIDATE_OUTPUT`` is set. Including
      ``TYPE_CHECKING`` lets static type checkers see the narrowing.

      Raises:
          TypeError: if the check runs and ``output`` is not an instance of
              ``output_type``.
      """
      validation_on = cls.DO_VALIDATE_OUTPUT
      if TYPE_CHECKING or validation_on:
          if not isinstance(output, output_type):
              raise TypeError(f"Expected output of type {output_type}, "
                              f"but found type {type(output)}")
      return output
  @classmethod
  def validate_outputs(
      cls,
      outputs: GenericSequence[object],
      output_type: Type[_O],
  ) -> List[_O]:
      """Return ``outputs``, optionally checking the type of every element.

      With validation enabled (``cls.DO_VALIDATE_OUTPUT``), a new list with
      the same elements is returned after each element passes an isinstance
      check. With validation disabled, the input sequence object itself is
      returned without copying.

      Raises:
          TypeError: on the first element that is not an ``output_type``,
              when validation is enabled.
      """
      validation_on = cls.DO_VALIDATE_OUTPUT
      validated: List[_O]
      if TYPE_CHECKING or validation_on:
          validated = []
          for item in outputs:
              if isinstance(item, output_type):
                  validated.append(item)
                  continue
              raise TypeError(f"Expected output of type {output_type}, "
                              f"but found type {type(item)}")
      else:
          # Validation disabled: hand back the caller's sequence (no copy).
          validated = outputs
      return validated
  # Tokenizer group; assigned in __init__, and None when
  # model_config.skip_tokenizer_init is set.
  tokenizer: Optional[BaseTokenizerGroup]

  def __init__(
      self,
      model_config: ModelConfig,
      cache_config: CacheConfig,
      parallel_config: ParallelConfig,
      scheduler_config: SchedulerConfig,
      device_config: DeviceConfig,
      load_config: LoadConfig,
      lora_config: Optional[LoRAConfig],
      speculative_config: Optional[SpeculativeConfig],
      decoding_config: Optional[DecodingConfig],
      prompt_adapter_config: Optional[PromptAdapterConfig],
      executor_class: Type[ExecutorBase],
      log_stats: bool,
      stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
      input_registry: InputRegistry = INPUT_REGISTRY,
      # To improve performance, only final requests outputs may be required.
      # If this set to true, then no intermediate outputs will be returned.
      step_return_finished_only: bool = False,
  ) -> None:
      """Create the engine and all of its components.

      In order, this:

      - logs the effective configuration;
      - loads general plugins;
      - creates the tokenizer and detokenizer (unless
        `model_config.skip_tokenizer_init`);
      - creates the input processor and the model executor;
      - initializes the KV cache (skipped in embedding mode);
      - creates one Scheduler per pipeline-parallel virtual engine;
      - creates the stat loggers (only when `log_stats`);
      - creates the sequence output processor;
      - creates per-virtual-engine scheduler state and
        output-processing callbacks.

      Args not covered by EngineArgs:
          stat_loggers: Loggers to use when `log_stats` is True. If None,
              default "logging" and "prometheus" loggers are created.
          input_registry: Registry used to build `self.input_processor`.
          step_return_finished_only: Stored on the engine. When True, only
              finished request outputs are meant to be returned.

      NOTE(review): `_verify_args()` is defined on this class but is not
      called here. Confirm config validation happens elsewhere (e.g. when the
      EngineConfig is built).
      """
      # Detect whether build-time commit info is available for the banner.
      try:
          import aphrodite.commit_id
          commit_id = True
      except ImportError:
          commit_id = False
      # Settings echoed to the log at startup. Entries whose value is None
      # are skipped below, as are "auto" load format / KV cache dtype.
      config_dict = {
          "Model": model_config.model,
          "Speculative Config": speculative_config,
          "DataType": model_config.dtype,
          "Model Load Format": load_config.load_format,
          "Tensor Parallel Size": parallel_config.tensor_parallel_size,
          "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
          "Disable Custom All-Reduce":
          parallel_config.disable_custom_all_reduce,
          "Quantization Format": model_config.quantization,
          "Context Length": model_config.max_model_len,
          "Enforce Eager Mode": model_config.enforce_eager,
          "Prefix Caching": cache_config.enable_prefix_caching,
          "KV Cache DataType": cache_config.cache_dtype,
          "Device": device_config.device,
          "Rope Scaling": model_config.rope_scaling,
          "Guided Decoding Backend": decoding_config,
          "Scheduler Steps": scheduler_config.num_scheduler_steps,
          "Async Output Processing": model_config.use_async_output_proc,
      }
      logger.info("-" * 85)
      if not commit_id:
          logger.info(
              f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
              "with the following config:")
      else:
          # `aphrodite` is the local name bound by the
          # `import aphrodite.commit_id` statement above.
          logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
                      f"commit {aphrodite.__short_commit__}) with the "
                      "following config:")
      for key, value in config_dict.items():
          if value is not None and not ((key == "Model Load Format" or key ==\
                                          "KV Cache DataType") and value == \
                                              "auto"):
              logger.info(f"{key} = {value!r}")
      logger.info("-" * 85)
      # TODO: Print more configs in debug mode.
      from aphrodite.plugins import load_general_plugins
      load_general_plugins()

      self.model_config = model_config
      self.cache_config = cache_config
      self.lora_config = lora_config
      self.parallel_config = parallel_config
      self.scheduler_config = scheduler_config
      self.device_config = device_config
      self.speculative_config = speculative_config
      self.load_config = load_config
      # Fall back to a default DecodingConfig when none was supplied.
      self.decoding_config = decoding_config or DecodingConfig()
      self.prompt_adapter_config = prompt_adapter_config
      self.log_stats = log_stats
      self.step_return_finished_only = step_return_finished_only

      if not self.model_config.skip_tokenizer_init:
          self.tokenizer = self._init_tokenizer()
          self.detokenizer = Detokenizer(self.tokenizer)
          tokenizer_group = self.get_tokenizer_group()
      else:
          self.tokenizer = None
          self.detokenizer = None
          tokenizer_group = None

      # Ensure that the function doesn't contain a reference to self,
      # to avoid engine GC issues
      def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
          # Returns the tokenizer the group provides for the sequence's
          # LoRA request; fails if tokenizer init was skipped.
          assert tokenizer_group, ("tokenizer_group cannot be None, "
                                   "make sure skip_tokenizer_init is False")
          return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

      # Source of sequence ids (consumed via next() when adding requests).
      self.seq_counter = Counter()
      self.generation_config_fields = _load_generation_config_dict(
          model_config)

      self.input_registry = input_registry
      self.input_processor = input_registry.create_input_processor(
          model_config)

      self.model_executor = executor_class(
          model_config=model_config,
          cache_config=cache_config,
          parallel_config=parallel_config,
          scheduler_config=scheduler_config,
          device_config=device_config,
          lora_config=lora_config,
          speculative_config=speculative_config,
          load_config=load_config,
          prompt_adapter_config=prompt_adapter_config,
      )

      # Embedding models skip KV-cache profiling/allocation.
      if not self.model_config.embedding_mode:
          self._initialize_kv_caches()

      if self.tokenizer:
          # Ping the tokenizer to ensure liveness if it runs in a
          # different process.
          self.tokenizer.ping()

      # Create the scheduler.
      # NOTE: the cache_config here have been updated with the numbers of
      # GPU and CPU blocks, which are profiled in the distributed executor.
      # One Scheduler per virtual engine. Each gets an async output callback
      # bound to its v_id only when async output processing is enabled.
      self.scheduler = [
          Scheduler(
              scheduler_config, cache_config, lora_config,
              parallel_config.pipeline_parallel_size,
              functools.partial(self._process_model_outputs,
                                virtual_engine=v_id,
                                is_async=True)
              if model_config.use_async_output_proc else None)
          for v_id in range(parallel_config.pipeline_parallel_size)
      ]

      # Metric Logging.
      # NOTE: self.stat_loggers is only assigned when log_stats is True.
      if self.log_stats:
          if stat_loggers is not None:
              self.stat_loggers = stat_loggers
          else:
              # Lazy import for prometheus multiprocessing.
              # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
              # before prometheus_client is imported.
              # See https://prometheus.github.io/client_python/multiprocess/
              from aphrodite.engine.metrics import (LoggingStatLogger,
                                                    PrometheusStatLogger)

              self.stat_loggers = {
                  "logging":
                  LoggingStatLogger(
                      local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                  "prometheus":
                  PrometheusStatLogger(
                      local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                      labels=dict(model_name=model_config.served_model_name),
                      max_model_len=self.model_config.max_model_len),
              }
              self.stat_loggers["prometheus"].info("cache_config",
                                                   self.cache_config)

      # Create sequence output processor, e.g. for beam search or
      # speculative decoding.
      self.output_processor = (
          SequenceGroupOutputProcessor.create_output_processor(
              self.scheduler_config,
              self.detokenizer,
              self.scheduler,
              self.seq_counter,
              get_tokenizer_for_seq,
              stop_checker=StopChecker(
                  self.scheduler_config.max_model_len,
                  get_tokenizer_for_seq,
              ),
          ))

      # Per-virtual-engine state, one entry per pipeline stage.
      self.cached_scheduler_outputs = [
          SchedulerOutputState()
          for _ in range(self.parallel_config.pipeline_parallel_size)
      ]

      self.scheduler_contexts = [
          SchedulerContext()
          for _ in range(self.parallel_config.pipeline_parallel_size)
      ]

      # _process_model_outputs bound per virtual engine with is_async=True.
      self.async_callback = [
          functools.partial(self._process_model_outputs,
                            virtual_engine=v_id,
                            is_async=True)
          for v_id in range(self.parallel_config.pipeline_parallel_size)
      ]

      # Same binding but with is_async=False (multi-step path).
      self.async_callback_multi_step = [
          functools.partial(self._process_model_outputs,
                            virtual_engine=v_id,
                            is_async=False)
          for v_id in range(self.parallel_config.pipeline_parallel_size)
      ]
  def _initialize_kv_caches(self) -> None:
      """Initialize the KV cache in the worker(s).

      The model executor reports how many GPU and CPU (swap) blocks are
      available. ``cache_config.num_gpu_blocks_override``, when set,
      replaces the GPU count. Both final counts are recorded on
      ``cache_config`` and then passed to the executor's
      ``initialize_cache``.
      """
      available_blocks = self.model_executor.determine_num_available_blocks()
      num_gpu_blocks, num_cpu_blocks = available_blocks

      num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
      if num_gpu_blocks_override is not None:
          # Names are kept: the `=` specifier prints the expression text.
          logger.info(f"Overriding {num_gpu_blocks=} with "
                      f"{num_gpu_blocks_override=}")
          num_gpu_blocks = num_gpu_blocks_override

      self.cache_config.num_gpu_blocks = num_gpu_blocks
      self.cache_config.num_cpu_blocks = num_cpu_blocks
      self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  @classmethod
  def _get_executor_cls(cls,
                        engine_config: EngineConfig) -> Type[ExecutorBase]:
      """Select the executor class for ``engine_config``.

      Resolution order:

      1. ``parallel_config.distributed_executor_backend`` given as a class.
         It must subclass ``ExecutorBase`` and is used as-is. A Ray cluster
         is initialized first when the class sets ``uses_ray``.
      2. Otherwise, dispatch on ``device_config.device_type``. Neuron, tpu,
         cpu, openvino and xpu have dedicated executors; tpu and xpu use
         Ray variants when the backend is "ray".
      3. Otherwise: "ray" selects RayGPUExecutor, "mp" selects
         MultiprocessingGPUExecutor, and anything else selects GPUExecutor.

      Each executor module is imported inside the branch that selects it.

      Raises:
          TypeError: if a class backend does not subclass ExecutorBase.
          ValueError: if the "mp" backend is requested on XPU.
      """
      distributed_executor_backend = (
          engine_config.parallel_config.distributed_executor_backend)
      device_type = engine_config.device_config.device_type
      executor_class: Type[ExecutorBase]

      # Initialize the cluster and specify the executor class.
      if isinstance(distributed_executor_backend, type):
          if not issubclass(distributed_executor_backend, ExecutorBase):
              raise TypeError(
                  "distributed_executor_backend must be a subclass of "
                  f"ExecutorBase. Got {distributed_executor_backend}.")
          if distributed_executor_backend.uses_ray:  # type: ignore
              initialize_ray_cluster(engine_config.parallel_config)
          executor_class = distributed_executor_backend
      elif device_type == "neuron":
          from aphrodite.executor.neuron_executor import NeuronExecutor
          executor_class = NeuronExecutor
      elif device_type == "tpu":
          if distributed_executor_backend == "ray":
              initialize_ray_cluster(engine_config.parallel_config)
              from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
              executor_class = RayTPUExecutor
          else:
              assert distributed_executor_backend is None
              from aphrodite.executor.tpu_executor import TPUExecutor
              executor_class = TPUExecutor
      elif device_type == "cpu":
          from aphrodite.executor.cpu_executor import CPUExecutor
          executor_class = CPUExecutor
      elif device_type == "openvino":
          from aphrodite.executor.openvino_executor import OpenVINOExecutor
          executor_class = OpenVINOExecutor
      elif device_type == "xpu":
          if distributed_executor_backend == "ray":
              initialize_ray_cluster(engine_config.parallel_config)
              from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
              executor_class = RayXPUExecutor
          elif distributed_executor_backend == "mp":
              # Previously this branch only logged, leaving executor_class
              # unbound, so the final `return` crashed with an
              # UnboundLocalError. Fail explicitly with the same message.
              msg = ("Both start methods (spawn and fork) have issues "
                     "on XPU if you use mp backend, Please try ray instead.")
              logger.error(msg)
              raise ValueError(msg)
          else:
              from aphrodite.executor.xpu_executor import XPUExecutor
              executor_class = XPUExecutor
      elif distributed_executor_backend == "ray":
          initialize_ray_cluster(engine_config.parallel_config)
          from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
          executor_class = RayGPUExecutor
      elif distributed_executor_backend == "mp":
          from aphrodite.executor.multiproc_gpu_executor import (
              MultiprocessingGPUExecutor)
          assert not envs.APHRODITE_USE_RAY_SPMD_WORKER, (
              "multiprocessing distributed executor backend does not "
              "support APHRODITE_USE_RAY_SPMD_WORKER=1")
          executor_class = MultiprocessingGPUExecutor
      else:
          from aphrodite.executor.gpu_executor import GPUExecutor
          executor_class = GPUExecutor
      return executor_class
  @classmethod
  def from_engine_args(
      cls,
      engine_args: EngineArgs,
      stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
  ) -> "AphroditeEngine":
      """Creates an Aphrodite engine from the engine arguments.

      Builds the EngineConfig and resolves the executor class first (which
      may initialize a Ray cluster). The engine is then constructed from
      the config's fields. Stats logging is enabled unless
      ``engine_args.disable_log_stats`` is set.
      """
      config = engine_args.create_engine_config()
      # Resolve the executor before expanding the config, as before.
      executor_cls = cls._get_executor_cls(config)
      return cls(
          **config.to_dict(),
          executor_class=executor_cls,
          log_stats=not engine_args.disable_log_stats,
          stat_loggers=stat_loggers,
      )
  def __reduce__(self):
      """Refuse to be pickled.

      Raising here ensures the engine is never captured in the closure used
      to initialize Ray worker actors.
      """
      message = "AphroditeEngine should not be pickled!"
      raise RuntimeError(message)
  def __del__(self):
      """Shut down the model executor when the engine is garbage collected."""
      # __init__ can fail before model_executor is assigned, so fall back to
      # None instead of raising AttributeError during finalization.
      executor = getattr(self, "model_executor", None)
      if not executor:
          return
      executor.shutdown()
  # Default error message used by get_tokenizer_group() when the engine was
  # built without a tokenizer.
  MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                 "skip_tokenizer_init is True")

  def get_tokenizer_group(
      self,
      group_type: Type[_G] = BaseTokenizerGroup,
      *,
      missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
  ) -> _G:
      """Return ``self.tokenizer`` after checking it is a ``group_type``.

      Args:
          group_type: Expected tokenizer-group class (base class by default).
          missing_msg: Message for the error raised when no tokenizer exists.

      Raises:
          ValueError: with ``missing_msg`` if the tokenizer is None.
          TypeError: if the tokenizer is not an instance of ``group_type``.
      """
      group = self.tokenizer
      if group is None:
          raise ValueError(missing_msg)
      if isinstance(group, group_type):
          return group
      raise TypeError("Invalid type of tokenizer group. "
                      f"Expected type: {group_type}, but "
                      f"found type: {type(group)}")
  def get_tokenizer(
      self,
      lora_request: Optional[LoRARequest] = None,
  ) -> AnyTokenizer:
      """Return the tokenizer the tokenizer group provides for ``lora_request``.

      Raises:
          ValueError: if the engine has no tokenizer (see
              ``get_tokenizer_group``).
      """
      group = self.get_tokenizer_group()
      return group.get_lora_tokenizer(lora_request)
  def _init_tokenizer(self) -> BaseTokenizerGroup:
      """Build the tokenizer group from the engine's configs.

      LoRA support is requested whenever a LoRA config is present.
      """
      lora_enabled = bool(self.lora_config)
      return init_tokenizer_from_configs(
          model_config=self.model_config,
          scheduler_config=self.scheduler_config,
          parallel_config=self.parallel_config,
          enable_lora=lora_enabled,
      )
  def _verify_args(self) -> None:
      """Cross-check the engine's configs against each other.

      Model and cache configs are checked against the parallel config. The
      LoRA config (if any) is checked against the model and scheduler
      configs, and the prompt-adapter config (if any) against the model
      config.
      """
      parallel = self.parallel_config
      self.model_config.verify_with_parallel_config(parallel)
      self.cache_config.verify_with_parallel_config(parallel)

      lora = self.lora_config
      if lora:
          lora.verify_with_model_config(self.model_config)
          lora.verify_with_scheduler_config(self.scheduler_config)

      adapter = self.prompt_adapter_config
      if adapter:
          adapter.verify_with_model_config(self.model_config)
  def _get_bos_token_id(self,
                        lora_request: Optional[LoRARequest] = None
                        ) -> Optional[int]:
      """Return the BOS token id of the tokenizer for ``lora_request``.

      When no tokenizer was initialized, a warning is logged and None is
      returned.
      """
      tokenizer_group = self.tokenizer
      if tokenizer_group is not None:
          tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
          return tokenizer.bos_token_id
      logger.warning("Using None for BOS token id because tokenizer "
                     "is not initialized")
      return None
  def _get_eos_token_id(self,
                        lora_request: Optional[LoRARequest] = None
                        ) -> Optional[int]:
      """Return the EOS token id of the tokenizer for ``lora_request``.

      When no tokenizer was initialized, a warning is logged and None is
      returned.
      """
      tokenizer_group = self.tokenizer
      if tokenizer_group is not None:
          tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
          return tokenizer.eos_token_id
      logger.warning("Using None for EOS token id because tokenizer "
                     "is not initialized")
      return None
  def _get_decoder_start_token_id(self) -> Optional[int]:
      """Return the decoder start token id of an encoder/decoder model.

      Reads ``hf_config.decoder_start_token_id``. If that attribute is
      missing or None, it falls back to the BOS token id. Returns None,
      with a warning, for non-encoder/decoder models or when the model/HF
      config is unavailable.
      """
      if not self.is_encoder_decoder_model():
          logger.warning("Using None for decoder start token id because "
                         "this is not an encoder/decoder model.")
          return None

      hf_config = (None if self.model_config is None else
                   self.model_config.hf_config)
      if hf_config is None:
          logger.warning("Using None for decoder start token id because "
                         "model config is not available.")
          return None

      start_token_id = getattr(hf_config, 'decoder_start_token_id', None)
      if start_token_id is not None:
          return start_token_id

      logger.warning("Falling back on <BOS> for decoder start token id "
                     "because decoder start token id is not available.")
      return self._get_bos_token_id()
  def _add_processed_request(
      self,
      request_id: str,
      processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
      params: Union[SamplingParams, PoolingParams],
      arrival_time: float,
      lora_request: Optional[LoRARequest],
      prompt_adapter_request: Optional[PromptAdapterRequest],
  ) -> None:
      """Wrap already-processed inputs in a SequenceGroup and enqueue it.

      The steps are:

      1. Validate the inputs.
      2. Create the decoder Sequence. When ``encoder_prompt_token_ids`` is
         present, also create an encoder Sequence with the same seq_id.
      3. Build a sampling or pooling SequenceGroup, depending on the type
         of ``params``.
      4. Hand the group to the scheduler with the fewest unfinished
         sequence groups.

      Raises:
          ValueError: if ``params`` is neither SamplingParams nor
              PoolingParams.
      """
      self._validate_model_inputs(processed_inputs)

      # Create the sequences. Decoder and encoder share the same arguments.
      block_size = self.cache_config.block_size
      seq_id = next(self.seq_counter)
      eos_token_id = self._get_eos_token_id(lora_request)
      seq_args = (seq_id, processed_inputs, block_size, eos_token_id,
                  lora_request, prompt_adapter_request)
      seq = Sequence(*seq_args)
      encoder_seq = (Sequence(*seq_args, from_decoder_prompt=False)
                     if 'encoder_prompt_token_ids' in processed_inputs else
                     None)

      # Pick the SequenceGroup factory from the params type.
      if isinstance(params, SamplingParams):
          build_group = self._create_sequence_group_with_sampling
      elif isinstance(params, PoolingParams):
          build_group = self._create_sequence_group_with_pooling
      else:
          raise ValueError(
              "Either SamplingParams or PoolingParams must be provided.")
      seq_group = build_group(
          request_id,
          seq,
          params,
          arrival_time=arrival_time,
          lora_request=lora_request,
          prompt_adapter_request=prompt_adapter_request,
          encoder_seq=encoder_seq)

      # Route to the scheduler with the fewest unfinished sequence groups.
      # min() returns the first minimal element, so ties go to the lowest
      # virtual-engine index.
      least_loaded = min(
          self.scheduler,
          key=lambda sched: sched.get_num_unfinished_seq_groups())
      least_loaded.add_seq_group(seq_group)
  def stop_remote_worker_execution_loop(self) -> None:
      """Delegate to the model executor's stop_remote_worker_execution_loop()."""
      executor = self.model_executor
      executor.stop_remote_worker_execution_loop()

  # Type alias for a (str, token-id list) pair; not referenced elsewhere in
  # the visible code.
  _LLMInputComponentsType = Tuple[str, List[int]]
  def _prepare_decoder_input_ids_for_generation(
      self,
      decoder_input_ids: Optional[List[int]],
  ) -> List[int]:
      """Prepare ``decoder_input_ids`` for encoder/decoder generation.

      Mirrors ``GenerationMixin._prepare_decoder_input_ids_for_generation()``
      from HF transformers:
      https://github.com/huggingface/transformers/blob/
      4037a2b5b1278736e566aec12e169100275545ea/
      src/transformers/generation/utils.py

      If ``decoder_input_ids`` is None, the default encoder/decoder decoder
      prompt is used instead. The decoder start token id is then prepended
      unless the list already begins with it. In that case the list object
      is returned unchanged.

      Args:
          decoder_input_ids: Input token ids to preprocess, or None.

      Returns:
          The processed token list.
      """
      start_token_id = self._get_decoder_start_token_id()
      assert start_token_id is not None

      token_ids = (self._get_default_enc_dec_decoder_prompt()
                   if decoder_input_ids is None else decoder_input_ids)
      if len(token_ids) > 0 and token_ids[0] == start_token_id:
          return token_ids
      return [start_token_id] + token_ids
  def _tokenize_prompt(
      self,
      prompt: str,
      request_id: str,
      lora_request: Optional[LoRARequest],
  ) -> List[int]:
      '''
      Encode a text prompt with the model's tokenizer group.

      Arguments:

      * prompt: text to tokenize
      * request_id: forwarded to the tokenizer group's ``encode``
      * lora_request: forwarded to the tokenizer group's ``encode``

      Returns:

      * prompt token ids
      '''
      group = self.get_tokenizer_group(
          missing_msg="prompts must be None if skip_tokenizer_init is True")
      token_ids = group.encode(prompt=prompt,
                               request_id=request_id,
                               lora_request=lora_request)
      return token_ids
  def _extract_prompt_components(
      self,
      inputs: SingletonPromptInputs,
      request_id: str,
      lora_request: Optional[LoRARequest] = None,
  ) -> PromptComponents:
      '''
      Split a single encoder or decoder prompt into its components.

      * A plain string is treated as prompt text and tokenized.
      * A dict that carries ``prompt_token_ids`` is used as-is; its text
        prompt is reported as None.
      * Any other dict must carry ``prompt`` text, which is tokenized.

      Only dict inputs can supply ``multi_modal_data``.

      Arguments:

      * inputs: single encoder or decoder input prompt
      * request_id
      * lora_request: this is only valid for decoder prompts

      Returns:

      * (prompt, prompt_token_ids, multi_modal_data)
      '''
      if isinstance(inputs, str):
          token_ids = self._tokenize_prompt(
              inputs,
              request_id=request_id,
              lora_request=lora_request,
          )
          return inputs, token_ids, None

      if isinstance(inputs, dict):
          text: Optional[str]
          if "prompt_token_ids" not in inputs:
              # Binding through an extra local keeps mypy's narrowing happy.
              text = prompt_text = inputs["prompt"]
              token_ids = self._tokenize_prompt(
                  prompt_text,
                  request_id=request_id,
                  lora_request=lora_request,
              )
          else:
              text = None
              token_ids = inputs["prompt_token_ids"]
          return text, token_ids, inputs.get("multi_modal_data")

      assert_never(inputs)
  def _apply_prompt_adapter(
      self,
      prompt_token_ids: List[int],
      prompt_adapter_request: Optional[PromptAdapterRequest],
  ) -> List[int]:
      """Prepend prompt-adapter virtual-token slots to a prompt.

      When ``prompt_adapter_request`` is truthy, a new list is returned with
      ``prompt_adapter_num_virtual_tokens`` zero token ids in front of
      ``prompt_token_ids``. Otherwise the input list is returned unchanged.
      """
      if not prompt_adapter_request:
          return prompt_token_ids
      num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
      return [0] * num_virtual + prompt_token_ids
  665. def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
  666. '''
  667. Specifically for encoder/decoder models:
  668. generate a default decoder prompt for when
  669. the user specifies only the encoder prompt.
  670. Encoder/decoder models utilize the decoder
  671. prompt in different ways; as new models are
  672. added, it is intended that this function
  673. will be extended to produce differing
  674. default decoder prompts, depending on the
  675. model variety.
  676. Absent a special case, the default behavior
  677. of this method is to mirror the behavior of
  678. the HuggingFace (HF) GenerationMixin for a None
  679. decoder prompt, which is to employ a logit processor
  680. setting to force the first decoded token to be <BOS>.
  681. Here, this behavior is approximated by having the
  682. "default" decoder prompt be <BOS>.
  683. However, it is possible that in the future
  684. other models may have different or more
  685. complex logic for the default decoder prompt.
  686. This motivates having a special helper method
  687. for default decoder prompts.
  688. Returns:
  689. * prompt_token_ids
  690. '''
  691. bos_token_id = self._get_bos_token_id()
  692. assert bos_token_id is not None
  693. return [bos_token_id]
  def _build_enc_dec_llm_inputs(
      self,
      encoder_comps: PromptComponents,
      decoder_comps: DecoderPromptComponents,
  ) -> EncoderDecoderLLMInputs:
      """Combine encoder and decoder prompt components into model inputs.

      The decoder token ids are normalized by
      ``_prepare_decoder_input_ids_for_generation``, which fills in the
      default decoder prompt when they are None and ensures the decoder
      start token is present.

      Raises:
          ValueError: If either side carries multi-modal data.
      """
      enc_text, enc_ids, enc_mm = encoder_comps
      dec_text, dec_ids, dec_mm = decoder_comps

      if not (enc_mm is None and dec_mm is None):
          raise ValueError("Multi-modal encoder-decoder models are "
                           "not supported yet")

      return EncoderDecoderLLMInputs(
          prompt_token_ids=self._prepare_decoder_input_ids_for_generation(
              dec_ids),
          prompt=dec_text,
          encoder_prompt_token_ids=enc_ids,
          encoder_prompt=enc_text,
      )
  def _process_encoder_decoder_prompt(
      self,
      inputs: PromptInputs,
      request_id: str,
  ) -> EncoderDecoderLLMInputs:
      '''
      For encoder/decoder models only: turn an input prompt into an
      :class:`EncoderDecoderLLMInputs` instance.

      Two prompt shapes are accepted:

      * A singleton prompt, which is treated as the encoder prompt. The
        decoder components are left as (None, None, None), so the default
        decoder prompt is filled in when the inputs are built.
      * An explicit encoder/decoder prompt. Its ``encoder_prompt`` and
        ``decoder_prompt`` members may each be any singleton prompt type.
        A ``decoder_prompt`` of None is handled like the singleton case.

      Arguments:

      * inputs: an input prompt
      * request_id

      Returns:

      * :class:`EncoderDecoderLLMInputs` instance
      '''
      encoder_comps: PromptComponents
      decoder_comps: DecoderPromptComponents
      no_decoder = (None, None, None)

      if is_explicit_encoder_decoder_prompt(inputs):
          encoder_comps = self._extract_prompt_components(
              inputs["encoder_prompt"],
              request_id=request_id,
          )
          decoder_input = inputs["decoder_prompt"]
          if decoder_input is not None:
              decoder_comps = self._extract_prompt_components(
                  decoder_input,
                  request_id=request_id,
              )
          else:
              decoder_comps = no_decoder
      else:
          encoder_comps = self._extract_prompt_components(
              inputs,
              request_id=request_id,
          )
          decoder_comps = no_decoder

      return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
  def _build_decoder_only_llm_inputs(
      self,
      prompt_comps: PromptComponents,
      prompt_adapter_request: Optional[PromptAdapterRequest],
  ) -> LLMInputs:
      """Package decoder-only prompt components as :class:`LLMInputs`.

      The token ids are first passed through ``_apply_prompt_adapter``.
      """
      text, token_ids, mm_data = prompt_comps
      adapted_ids = self._apply_prompt_adapter(
          token_ids, prompt_adapter_request=prompt_adapter_request)
      return LLMInputs(
          prompt_token_ids=adapted_ids,
          prompt=text,
          multi_modal_data=mm_data,
      )
  def _process_decoder_only_prompt(
      self,
      inputs: SingletonPromptInputs,
      request_id: str,
      lora_request: Optional[LoRARequest] = None,
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  ) -> LLMInputs:
      '''
      For decoder-only models: turn a singleton prompt into an
      :class:`LLMInputs` instance.

      Arguments:

      * inputs: input prompt
      * request_id
      * lora_request: forwarded to tokenization
      * prompt_adapter_request: used to prepend adapter token slots

      Returns:

      * :class:`LLMInputs` instance
      '''
      components = self._extract_prompt_components(
          inputs, request_id=request_id, lora_request=lora_request)
      return self._build_decoder_only_llm_inputs(
          components, prompt_adapter_request=prompt_adapter_request)
  def process_model_inputs(
      self,
      inputs: PromptInputs,
      request_id: str,
      lora_request: Optional[LoRARequest] = None,
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
      """Convert a user prompt into the inputs the model expects.

      Encoder-decoder models go through ``_process_encoder_decoder_prompt``.
      Note that ``lora_request`` and ``prompt_adapter_request`` are not
      forwarded on that path.

      Decoder-only models reject explicit encoder/decoder prompts with a
      ValueError. Otherwise they go through ``_process_decoder_only_prompt``.

      In both cases the result is passed through ``self.input_processor``
      before being returned.
      """
      if not self.is_encoder_decoder_model():
          # Decoder-only operation
          if is_explicit_encoder_decoder_prompt(inputs):
              raise ValueError("Cannot pass encoder-decoder prompt "
                               "to decoder-only models")
          model_inputs = self._process_decoder_only_prompt(
              inputs,
              request_id=request_id,
              lora_request=lora_request,
              prompt_adapter_request=prompt_adapter_request,
          )
      else:
          # Encoder-decoder models need the prompt mapped onto both sides.
          model_inputs = self._process_encoder_decoder_prompt(
              inputs,
              request_id=request_id,
          )

      return self.input_processor(model_inputs)
  def add_request(
      self,
      request_id: str,
      inputs: PromptInputs,
      params: Union[SamplingParams, PoolingParams],
      arrival_time: Optional[float] = None,
      lora_request: Optional[LoRARequest] = None,
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  ) -> None:
      """Add a request to the engine's request pool.

      The request is added to the request pool and will be processed by the
      scheduler as `engine.step()` is called. The exact scheduling policy is
      determined by the scheduler.

      Args:
          request_id: The unique ID of the request.
          inputs: The prompt to process. This is prompt text, a prompt
              dict carrying either ``prompt`` text or ``prompt_token_ids``,
              or (encoder-decoder models only) an explicit
              encoder/decoder prompt. Text is tokenized by the engine.
          params: Parameters for sampling or pooling. SamplingParams
              for text generation. PoolingParams for pooling.
          arrival_time: The arrival time of the request. If None, the
              current wall-clock time (``time.time()``) is used.
          lora_request: Optional LoRA request for this prompt. Only
              allowed when LoRA is enabled.
          prompt_adapter_request: Optional prompt adapter request for this
              prompt.

      Raises:
          ValueError: If ``lora_request`` is given while LoRA is not
              enabled, if an explicit encoder/decoder prompt is passed to
              a decoder-only model, or if ``params`` is neither
              SamplingParams nor PoolingParams.

      Details:
          - Set arrival_time to the current time if it is None.
          - Convert ``inputs`` into model inputs via
            :meth:`process_model_inputs` (tokenizing text prompts).
          - Create a :class:`~aphrodite.common.sequence.Sequence` and wrap
            it in a :class:`~aphrodite.common.sequence.SequenceGroup`.
          - Add the SequenceGroup to the scheduler that currently has the
            fewest unfinished sequence groups.

      Example:
          >>> # initialize engine
          >>> engine = AphroditeEngine.from_engine_args(engine_args)
          >>> # set request arguments
          >>> example_prompt = "Who is the president of the United States?"
          >>> sampling_params = SamplingParams(temperature=0.0)
          >>> request_id = 0
          >>>
          >>> # add the request to the engine
          >>> engine.add_request(
          >>>    str(request_id),
          >>>    example_prompt,
          >>>    sampling_params)
          >>> # continue the request processing
          >>> ...
      """
      if lora_request is not None and not self.lora_config:
          raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                           "not enabled!")
      if arrival_time is None:
          # Wall-clock time, not a monotonic clock.
          arrival_time = time.time()

      processed_inputs = self.process_model_inputs(
          inputs,
          request_id=request_id,
          lora_request=lora_request,
          prompt_adapter_request=prompt_adapter_request,
      )

      self._add_processed_request(
          request_id=request_id,
          processed_inputs=processed_inputs,
          params=params,
          arrival_time=arrival_time,
          lora_request=lora_request,
          prompt_adapter_request=prompt_adapter_request,
      )
  def _create_sequence_group_with_sampling(
      self,
      request_id: str,
      seq: Sequence,
      sampling_params: SamplingParams,
      arrival_time: float,
      lora_request: Optional[LoRARequest],
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
      encoder_seq: Optional[Sequence] = None,
  ) -> SequenceGroup:
      """Creates a SequenceGroup with SamplingParams.

      First rejects requests for more (prompt) logprobs than the model
      config's ``max_logprobs`` allows. It then works on a clone of
      ``sampling_params``, so the caller's object is left untouched by the
      generation-config update.

      Raises:
          ValueError: If ``logprobs`` or ``prompt_logprobs`` exceeds
              ``max_logprobs``.
      """
      limit = self.get_model_config().max_logprobs
      requested_counts = (sampling_params.logprobs,
                          sampling_params.prompt_logprobs)
      if any(count and count > limit for count in requested_counts):
          raise ValueError(f"Cannot request more than "
                           f"{limit} logprobs.")

      # Defensive copy: the sampler consumes these params and they are
      # updated in place below. clone() does not deep-copy LogitsProcessor
      # objects.
      params_copy = sampling_params.clone()
      params_copy.update_from_generation_config(
          self.generation_config_fields, seq.eos_token_id)
      params_copy._verify_with_scheduler_config(self.scheduler_config)

      return SequenceGroup(
          request_id=request_id,
          seqs=[seq],
          arrival_time=arrival_time,
          sampling_params=params_copy,
          lora_request=lora_request,
          prompt_adapter_request=prompt_adapter_request,
          encoder_seq=encoder_seq)
  def _create_sequence_group_with_pooling(
      self,
      request_id: str,
      seq: Sequence,
      pooling_params: PoolingParams,
      arrival_time: float,
      lora_request: Optional[LoRARequest],
      prompt_adapter_request: Optional[PromptAdapterRequest] = None,
      encoder_seq: Optional[Sequence] = None,
  ) -> SequenceGroup:
      """Creates a SequenceGroup with PoolingParams.

      ``prompt_adapter_request`` defaults to None, matching
      ``_create_sequence_group_with_sampling``. Existing callers that pass
      it explicitly are unaffected.

      Args:
          request_id: ID of the request the group belongs to.
          seq: The single (decoder) sequence of the group.
          pooling_params: Pooling parameters; a clone is stored.
          arrival_time: Arrival time recorded on the group.
          lora_request: Optional LoRA request.
          prompt_adapter_request: Optional prompt adapter request.
          encoder_seq: Optional encoder sequence (encoder-decoder models).

      Returns:
          The newly created SequenceGroup.
      """
      # Defensive copy of PoolingParams, which are used by the pooler
      pooling_params = pooling_params.clone()
      # Create the sequence group.
      seq_group = SequenceGroup(
          request_id=request_id,
          seqs=[seq],
          arrival_time=arrival_time,
          lora_request=lora_request,
          pooling_params=pooling_params,
          prompt_adapter_request=prompt_adapter_request,
          encoder_seq=encoder_seq)
      return seq_group
  def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
      """Aborts a request(s) with the given ID.

      The abort is forwarded to every scheduler owned by this engine, in
      order, so it reaches whichever one holds the request.

      Args:
          request_id: The ID(s) of the request to abort.

      Details:
          - Refer to the
            :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
            from class :class:`~aphrodite.processing.scheduler.Scheduler`.

      Example:
          >>> # initialize engine and add a request with request_id
          >>> request_id = str(0)
          >>> # abort the request
          >>> engine.abort_request(request_id)
      """
      for per_engine_scheduler in self.scheduler:
          per_engine_scheduler.abort_seq_group(request_id)
  def get_model_config(self) -> ModelConfig:
      """Return the model configuration this engine holds."""
      config = self.model_config
      return config
  def get_parallel_config(self) -> ParallelConfig:
      """Return the parallelism configuration this engine holds."""
      config = self.parallel_config
      return config
  def get_decoding_config(self) -> DecodingConfig:
      """Return the decoding configuration this engine holds."""
      config = self.decoding_config
      return config
  def get_scheduler_config(self) -> SchedulerConfig:
      """Return the scheduler configuration this engine holds."""
      config = self.scheduler_config
      return config
  def get_lora_config(self) -> Optional[LoRAConfig]:
      """Gets the LoRA configuration.

      Returns:
          The engine's ``lora_config``. It may be unset (None) when LoRA is
          not enabled; ``add_request`` rejects LoRA requests in that case.
          The return annotation is therefore Optional.
      """
      return self.lora_config
  def get_num_unfinished_requests(self) -> int:
      """Return the total count of unfinished sequence groups.

      The count is summed across every scheduler.
      """
      total = 0
      for sched in self.scheduler:
          total += sched.get_num_unfinished_seq_groups()
      return total
  def has_unfinished_requests(self) -> bool:
      """Return True as soon as any scheduler reports unfinished sequences.

      Schedulers after the first positive answer are not queried.
      """
      for sched in self.scheduler:
          if sched.has_unfinished_seqs():
              return True
      return False
  def has_unfinished_requests_for_virtual_engine(
          self, virtual_engine: int) -> bool:
      """
      Report whether the scheduler of ``virtual_engine`` still has
      unfinished sequences.
      """
      target = self.scheduler[virtual_engine]
      return target.has_unfinished_seqs()
  def _process_sequence_group_outputs(
      self,
      seq_group: SequenceGroup,
      outputs: List[EmbeddingSequenceGroupOutput],
  ) -> None:
      """Store embedding output on ``seq_group`` and finish its sequences.

      Only ``outputs[0].embeddings`` is used. Every sequence in the group is
      then marked ``SequenceStatus.FINISHED_STOPPED``.
      """
      first_output = outputs[0]
      seq_group.embeddings = first_output.embeddings
      for member in seq_group.get_seqs():
          member.status = SequenceStatus.FINISHED_STOPPED
  def _process_model_outputs(self,
                             virtual_engine: int,
                             is_async: bool,
                             sampler_output: Optional[SamplerOutput] = None,
                             is_last_output: bool = False) -> None:
      """Apply the model output to the sequences in the scheduled seq groups.

      Consumes the oldest entry of the virtual engine's ``output_queue``:
      it is popped in the standard case, or only peeked at in the multi-step
      case. Its outputs are fed through the output processor. Finished
      sequence groups are then freed, and ``RequestOutput`` objects are
      appended to the context's ``request_outputs``.

      Args:
          virtual_engine: The engine id to operate on.
          is_async: Indicates whether this postprocessor runs in
              parallel with the GPU forward pass and is processing
              tokens from the previous step. If this is true, then
              no tokens need to be appended since it is already done
              externally (before the next schedule() call).
          sampler_output: Used with multi-step execution to provide
              sampler_output of each step. Passing it marks the call as
              multi-step.
          is_last_output: Used with multi-step execution to indicate
              the last step (of each multi-step group).

      Returns:
          None. Request outputs are accumulated in
          ``self.scheduler_contexts[virtual_engine].request_outputs``
          instead of being returned.
      """
      now = time.time()

      is_multi_step = sampler_output is not None

      ctx: SchedulerContext = self.scheduler_contexts[virtual_engine]

      # Nothing pending for this virtual engine.
      if len(ctx.output_queue) == 0:
          return None

      if is_multi_step:
          # Async + multi-step case: the queue entry was seeded with
          # outputs=None when scheduling, and is only peeked at (not popped)
          # so later steps can reuse its metadata/scheduler outputs.
          (outputs, seq_group_metadata_list,
           scheduler_outputs) = ctx.output_queue[0]
          assert outputs is None
          outputs = [sampler_output]
      else:
          # Async standard case
          (outputs, seq_group_metadata_list,
           scheduler_outputs) = ctx.output_queue.popleft()

      assert outputs is not None

      # Sanity check
      assert len(seq_group_metadata_list) == len(
          scheduler_outputs.scheduled_seq_groups)

      # Organize outputs by [step][sequence group] instead of
      # [sequence group][step].
      if len(outputs) > 1:
          outputs_by_sequence_group = create_output_by_sequence_group(
              outputs, num_seq_groups=len(seq_group_metadata_list))
      else:
          outputs_by_sequence_group = outputs

      # Indices of groups that were already finished before this call;
      # they are skipped here and (outside multi-step) in output creation.
      finished_before: List[int] = []
      for i, seq_group_meta in enumerate(seq_group_metadata_list):
          scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

          seq_group = scheduled_seq_group.seq_group

          if seq_group.is_finished():
              finished_before.append(i)
              continue

          if len(outputs) > 1:
              output = outputs_by_sequence_group[i]
          else:
              output = [outputs_by_sequence_group[0][i]]

          # In async mode the computed-token count was already advanced by
          # _advance_to_next_step, so only do it here on the sync path.
          if not is_async:
              seq_group.update_num_computed_tokens(
                  scheduled_seq_group.token_chunk_size)

          if self.model_config.embedding_mode:
              self._process_sequence_group_outputs(seq_group, output)
              continue

          self.output_processor.process_prompt_logprob(seq_group, output)
          if seq_group_meta.do_sample:
              self.output_processor.process_outputs(seq_group, output,
                                                    is_async)

      # For async + multi-step, free finished seqs and create outputs
      # only on the final step.
      if is_multi_step and not is_last_output:
          return

      for scheduler in self.scheduler:
          scheduler.free_finished_seq_groups()

      # Create the outputs.
      for i, _ in enumerate(seq_group_metadata_list):
          scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

          if not is_multi_step and i in finished_before:
              continue  # Avoids double processing

          seq_group = scheduled_seq_group.seq_group
          seq_group.maybe_set_first_token_time(now)
          # With step_return_finished_only set, emit only finished groups;
          # otherwise emit every processed group.
          if (seq_group.is_finished()
                  if self.step_return_finished_only else True):
              request_output = RequestOutputFactory.create(seq_group)
              ctx.request_outputs.append(request_output)

      # Groups the scheduler ignored still get a RequestOutput.
      for seq_group in scheduler_outputs.ignored_seq_groups:
          request_output = RequestOutputFactory.create(seq_group)
          ctx.request_outputs.append(request_output)

      # For async + multi-step, do stats only on the last output.
      # Otherwise, do stats if the execution is async
      # (the sync path logs stats in step() instead).
      do_stats = is_multi_step or is_async

      if do_stats:
          # Log stats.
          self.do_log_stats(scheduler_outputs, outputs, finished_before)

      return None
  def _advance_to_next_step(
          self, output: SamplerOutput,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
      """Given model output from a single run, append the tokens to the
      sequences. This is normally done inside output processor, but it is
      required if the worker is to perform async forward pass to next step.

      Args:
          output: The sampler output of ONE step. ``step()`` passes
              ``output[0]``. Iterating it yields one output per scheduled
              sequence group, which is why the annotation is
              ``SamplerOutput`` rather than ``List[SamplerOutput]``.
          seq_group_metadata_list: Metadata of the scheduled groups, in the
              same order as ``output``.
          scheduled_seq_groups: The scheduled groups, in the same order.
      """
      for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
              zip(seq_group_metadata_list, output, scheduled_seq_groups):
          seq_group = scheduled_seq_group.seq_group

          if seq_group.is_finished():
              continue

          seq_group.update_num_computed_tokens(
              seq_group_metadata.token_chunk_size)

          if seq_group_metadata.do_sample:
              assert len(sequence_group_outputs.samples) == 1, (
                  "Async output processor expects a single sample"
                  " (i.e sampling_params.n == 1 and no "
                  "sampling_params.best_of > 1)")
              sample = sequence_group_outputs.samples[0]

              assert len(seq_group.seqs) == 1
              seq = seq_group.seqs[0]
              seq.append_token_id(sample.output_token, sample.logprobs)
  def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
      """Performs one decoding iteration and returns newly generated results.

      .. figure:: https://i.imgur.com/sv2HssD.png
          :alt: Overview of the step function
          :align: center

          Overview of the step function.

      Details:
          - Step 1: Schedules the sequences to be executed in the next
            iteration and the token blocks to be swapped in/out/copy.

              - Depending on the scheduling policy,
                sequences may be `preempted/reordered`.
              - A Sequence Group (SG) refer to a group of sequences
                that are generated from the same prompt.

          - Step 2: Calls the distributed executor to execute the model.
          - Step 3: Processes the model output. This mainly includes:

              - Decodes the relevant outputs.
              - Updates the scheduled sequence groups with model outputs
                based on its `sampling parameters` (`use_beam_search` or not).
              - Frees the finished sequence groups.

          - Finally, it creates and returns the newly generated results.

      Raises:
          NotImplementedError: If pipeline parallelism is configured; it is
              only supported through AsyncAphrodite.

      Example:
          >>> # Please see the example/ folder for more detailed examples.
          >>>
          >>> # initialize engine and request arguments
          >>> engine = AphroditeEngine.from_engine_args(engine_args)
          >>> example_inputs = [(0, "What is LLM?",
          >>>    SamplingParams(temperature=0.0))]
          >>>
          >>> # Start the engine with an event loop
          >>> while True:
          >>>     if example_inputs:
          >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
          >>>         engine.add_request(str(req_id),prompt,sampling_params)
          >>>
          >>>     # continue the request processing
          >>>     request_outputs = engine.step()
          >>>     for request_output in request_outputs:
          >>>         if request_output.finished:
          >>>             # return or show the request output
          >>>
          >>>     if not (engine.has_unfinished_requests() or example_inputs):
          >>>         break
      """
      if self.parallel_config.pipeline_parallel_size > 1:
          raise NotImplementedError(
              "Pipeline parallelism is only supported through AsyncAphrodite "
              "as performance will be severely degraded otherwise.")

      # For llm_engine, there is no pipeline parallel support, so the engine
      # used is always 0
      virtual_engine = 0

      # These are cached outputs from previous iterations. None if on first
      # iteration
      cached_outputs = self.cached_scheduler_outputs[virtual_engine]
      seq_group_metadata_list = cached_outputs.seq_group_metadata_list
      scheduler_outputs = cached_outputs.scheduler_outputs
      allow_async_output_proc = cached_outputs.allow_async_output_proc

      # Detect async + multi-step
      use_async_and_multi_step = (self.scheduler_config.is_multi_step
                                  and allow_async_output_proc)

      ctx = self.scheduler_contexts[virtual_engine]

      # Skip the scheduler if there are any remaining steps in the seq groups.
      # This ensures that the scheduler is only called again when the current
      # batch has completed.
      if not self._has_remaining_steps(seq_group_metadata_list):

          # Clear outputs on scheduler iteration start
          ctx.request_outputs.clear()

          # Schedule iteration
          (seq_group_metadata_list, scheduler_outputs,
           allow_async_output_proc
           ) = self.scheduler[virtual_engine].schedule()

          # Detect async + multi-step (recomputed from the fresh schedule)
          use_async_and_multi_step = (self.scheduler_config.is_multi_step
                                      and allow_async_output_proc)

          # Maybe switch from async mode to sync mode: drain any pending
          # async postprocessing before running synchronously.
          if not allow_async_output_proc and len(ctx.output_queue) > 0:
              self._process_model_outputs(virtual_engine=virtual_engine,
                                          is_async=True)

          # For async + multi-step, init the queue with a placeholder entry
          # (outputs=None); _process_model_outputs peeks at it per step.
          if use_async_and_multi_step:
              assert len(ctx.output_queue) == 0
              assert seq_group_metadata_list is not None
              ctx.output_queue.append(
                  (None, seq_group_metadata_list, scheduler_outputs))

          if (self.scheduler_config.is_multi_step
                  and scheduler_outputs.num_lookahead_slots > 0):
              # cache the scheduler outputs for the next iteration if we have
              # lookahead slots
              self._cache_scheduler_outputs_for_multi_step(
                  virtual_engine, seq_group_metadata_list, scheduler_outputs,
                  allow_async_output_proc)

      assert seq_group_metadata_list is not None
      assert scheduler_outputs is not None

      if not scheduler_outputs.is_empty():
          finished_requests_ids = self.scheduler[
              virtual_engine].get_and_reset_finished_requests_ids()

          # Check if we have a cached last_output from the previous iteration.
          # For supporting PP this is probably the best way to pass the
          # sampled_token_ids, as a separate broadcast over all the PP stages
          # will cause one virtual engine's microbatch to block the pipeline.
          last_sampled_token_ids = \
              self._get_last_sampled_token_ids(virtual_engine)

          execute_model_req = ExecuteModelRequest(
              seq_group_metadata_list=seq_group_metadata_list,
              blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
              blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
              blocks_to_copy=scheduler_outputs.blocks_to_copy,
              num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
              running_queue_size=scheduler_outputs.running_queue_size,
              finished_requests_ids=finished_requests_ids,
              # We use ExecuteModelRequest to pass the last sampled_token_ids
              # to each of the non-last PP stages for in-place prepare_input.
              last_sampled_token_ids=last_sampled_token_ids)

          if allow_async_output_proc:
              # Pick the multi-step or single-step postprocessing callback.
              async_callback = self.async_callback_multi_step[
                  virtual_engine] if use_async_and_multi_step \
                  else self.async_callback[virtual_engine]

              execute_model_req.async_callback = async_callback
              execute_model_req.use_async_and_multi_step = \
                  use_async_and_multi_step

          output = self.model_executor.execute_model(
              execute_model_req=execute_model_req)

          # We need to do this here so that last step's sampled_token_ids can
          # be passed to the next iteration for PP.
          if self.scheduler_config.is_multi_step:
              self._update_cached_scheduler_output(virtual_engine, output)
      else:
          # Nothing scheduled => If there is pending async postprocessor,
          # then finish it here.
          if not use_async_and_multi_step and len(ctx.output_queue) > 0:
              assert not self.scheduler_config.is_multi_step
              self._process_model_outputs(virtual_engine=virtual_engine,
                                          is_async=True)
          # No outputs in this case
          output = []

      # Finish the current step for all the sequence groups.
      if self.scheduler_config.is_multi_step:
          for seq_group in seq_group_metadata_list:
              seq_group.finish_step()

      if not self._has_remaining_steps(seq_group_metadata_list):
          # clear the cache if we have finished all the steps.
          # NOTE(review): indexes with literal 0 rather than virtual_engine;
          # equivalent here only because virtual_engine is always 0 in step().
          if self.scheduler_config.is_multi_step:
              self.cached_scheduler_outputs[0] = SchedulerOutputState()

          if use_async_and_multi_step:
              # For async + multi-step, clear the queue
              ctx.output_queue.clear()
          else:
              # Add results to the output_queue
              # (for async or non-async postprocessing)
              ctx.output_queue.append(
                  (output, seq_group_metadata_list, scheduler_outputs))

              # Async mode: append this step's sampled tokens now so the
              # next forward pass can start before postprocessing runs.
              if output and allow_async_output_proc:
                  assert len(output) == 1, (
                      "Multi step decoding does not work "
                      "with async output processing.")
                  self._advance_to_next_step(
                      output[0], seq_group_metadata_list,
                      scheduler_outputs.scheduled_seq_groups)

              # Check if need to run the usual non-async path
              if not allow_async_output_proc:
                  self._process_model_outputs(virtual_engine=virtual_engine,
                                              is_async=False)

                  # Log stats. (_process_model_outputs skips stats on this
                  # sync, non-multi-step path, so they are logged here.)
                  self.do_log_stats(scheduler_outputs, output)
      else:
          # Multi-step case
          if use_async_and_multi_step:
              return []
          else:
              ctx.request_outputs = []

      if not self.has_unfinished_requests():
          # Drain async postprocessor (if exists)
          if len(ctx.output_queue) > 0:
              assert not self.scheduler_config.is_multi_step
              self._process_model_outputs(virtual_engine=virtual_engine,
                                          is_async=True)
          assert len(ctx.output_queue) == 0

          # Stop the execute model loop in parallel workers until there are
          # more requests to process. This avoids waiting indefinitely in
          # torch.distributed ops which may otherwise timeout, and unblocks
          # the RPC thread in the workers so that they can process any other
          # queued control plane messages, such as add/remove lora adapters.
          self.model_executor.stop_remote_worker_execution_loop()

      return ctx.request_outputs
  def _has_remaining_steps(
      self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
  ) -> bool:
      """Return whether the current multi-step batch still has steps to run.

      Always False when multi-step scheduling is disabled or the list is
      None/empty. Otherwise every group must report the same
      ``state.remaining_steps``, and the result is whether that count is
      greater than zero.

      Raises:
          AssertionError: If the groups disagree on their remaining steps.
      """
      if (not self.scheduler_config.is_multi_step
              or not seq_group_metadata_list):
          return False

      # TODO: this is a sanity check for now to make sure that all the
      # seqs are on the same steps. Eventually we will want to do some sort of
      # dynamic scheduling when doing multi-step decoding.
      ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
      # Generator (not a list) so any() can stop at the first mismatch.
      if any(seq_group.state.remaining_steps != ref_remaining_steps
             for seq_group in seq_group_metadata_list[1:]):
          raise AssertionError(("All running sequence groups should "
                                "have the same remaining steps."))

      return ref_remaining_steps > 0
  def _cache_scheduler_outputs_for_multi_step(
          self, virtual_engine: int,
          seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
          scheduler_outputs: SchedulerOutputs,
          allow_async_output_proc: bool) -> None:
      """Store this iteration's scheduling results for later multi-step use.

      Overwrites the cached state of ``virtual_engine`` and resets its
      ``last_output`` to None.
      """
      state = self.cached_scheduler_outputs[virtual_engine]
      state.last_output = None
      state.allow_async_output_proc = allow_async_output_proc
      state.scheduler_outputs = scheduler_outputs
      state.seq_group_metadata_list = seq_group_metadata_list
  1331. def _update_cached_scheduler_output(
  1332. self, virtual_engine: int,
  1333. output: List[Optional[SamplerOutput]]) -> None:
  1334. if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
  1335. and output[0] is not None):
  1336. last_output = output[-1]
  1337. assert last_output is not None
  1338. assert last_output.sampled_token_ids_cpu is not None
  1339. assert last_output.sampled_token_ids is None
  1340. assert last_output.sampled_token_probs is None
  1341. self.cached_scheduler_outputs[
  1342. virtual_engine].last_output = last_output
  1343. def _get_last_sampled_token_ids(
  1344. self, virtual_engine: int) -> Optional[torch.Tensor]:
  1345. cached_last_output = self.cached_scheduler_outputs[
  1346. virtual_engine].last_output
  1347. if (self.scheduler_config.is_multi_step
  1348. and self.parallel_config.pipeline_parallel_size > 1
  1349. and cached_last_output is not None
  1350. and cached_last_output.sampled_token_ids_cpu is not None):
  1351. return cached_last_output.sampled_token_ids_cpu
  1352. return None
  def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
      """Register ``logger`` under ``logger_name`` in ``self.stat_loggers``.

      Raises:
          KeyError: If ``logger_name`` is already registered. The existing
              entry is left untouched.
      """
      registry = self.stat_loggers
      if logger_name not in registry:
          registry[logger_name] = logger
          return
      raise KeyError(f"Logger with name {logger_name} already exists.")
  def remove_logger(self, logger_name: str) -> None:
      """Unregister the stat logger stored under ``logger_name``.

      Raises:
          KeyError: If no logger is registered under that name.
      """
      registry = self.stat_loggers
      if logger_name in registry:
          del registry[logger_name]
          return
      raise KeyError(f"Logger with name {logger_name} does not exist.")
  def do_log_stats(self,
                   scheduler_outputs: Optional[SchedulerOutputs] = None,
                   model_output: Optional[List[SamplerOutput]] = None,
                   finished_before: Optional[List[int]] = None) -> None:
      """Build a stats snapshot and hand it to every registered stat logger.

      Does nothing when ``self.log_stats`` is falsy. The arguments are
      forwarded unchanged to ``_get_stats``. Because all of them default to
      ``None``, this can also be called with no arguments to force a log
      while no requests are active.
      """
      if not self.log_stats:
          return
      snapshot = self._get_stats(scheduler_outputs, model_output,
                                 finished_before)
      for stat_logger in self.stat_loggers.values():
          stat_logger.log(snapshot)
  def _get_stats(self,
                 scheduler_outputs: Optional[SchedulerOutputs],
                 model_output: Optional[List[SamplerOutput]] = None,
                 finished_before: Optional[List[int]] = None) -> Stats:
      """Get Stats to be Logged to Prometheus.

      Args:
          scheduler_outputs: Optional, used to populate metrics related to
              the scheduled batch.
          model_output: Optional, used to emit speculative decoding metrics
              which are created by the workers. They are read from the first
              element's ``spec_decode_worker_metrics``.
          finished_before: Optional indices into
              ``scheduler_outputs.scheduled_seq_groups`` of groups to skip
              because they were already logged (async output processing).
              Each skipped group also subtracts one from the batched-token
              count.

      Returns:
          A ``Stats`` object with system state (scheduler queue sizes,
          KV-cache usage, prefix-cache hit rates), per-iteration token and
          latency stats, and stats for requests that finished in this
          iteration.
      """
      now = time.time()

      # System State
      #   Scheduler State
      num_running_sys = sum(
          len(scheduler.running) for scheduler in self.scheduler)
      num_swapped_sys = sum(
          len(scheduler.swapped) for scheduler in self.scheduler)
      num_waiting_sys = sum(
          len(scheduler.waiting) for scheduler in self.scheduler)

      #   KV Cache Usage, computed as 1 - free_blocks / total_blocks.
      num_total_gpu = self.cache_config.num_gpu_blocks
      gpu_cache_usage_sys = 0.
      if num_total_gpu:  # Guard against both None and 0
          num_free_gpu = sum(
              scheduler.block_manager.get_num_free_gpu_blocks()
              for scheduler in self.scheduler)
          gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

      num_total_cpu = self.cache_config.num_cpu_blocks
      cpu_cache_usage_sys = 0.
      if num_total_cpu:  # Guard against both None and 0
          num_free_cpu = sum(
              scheduler.block_manager.get_num_free_cpu_blocks()
              for scheduler in self.scheduler)
          cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

      #   Prefix Cache Hit Rate. Note that we always use
      #   the cache hit rate of the first virtual engine.
      cpu_prefix_cache_hit_rate = self.scheduler[
          0].get_prefix_cache_hit_rate(Device.CPU)
      gpu_prefix_cache_hit_rate = self.scheduler[
          0].get_prefix_cache_hit_rate(Device.GPU)

      # Iteration stats
      num_prompt_tokens_iter = 0
      num_generation_tokens_iter = 0
      time_to_first_tokens_iter: List[float] = []
      time_per_output_tokens_iter: List[float] = []
      num_preemption_iter = (0 if scheduler_outputs is None else
                             scheduler_outputs.preempted)

      # Request stats
      #   Latency
      time_e2e_requests: List[float] = []
      #   Metadata
      num_prompt_tokens_requests: List[int] = []
      num_generation_tokens_requests: List[int] = []
      best_of_requests: List[int] = []
      n_requests: List[int] = []
      finished_reason_requests: List[str] = []

      # NOTE: This loop assumes prefill seq_groups are before
      # decode seq_groups in scheduled_seq_groups.
      if scheduler_outputs is not None:
          # For async postprocessor, already finished sequences need to be
          # not counted (to avoid double counting)
          actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore
          # Integer, not float: this is added into num_generation_tokens_iter,
          # which is a token count.
          num_generation_tokens_from_prefill_groups = 0
          # Build the lookup set once so the per-group membership test below
          # is O(1) instead of a list scan for every scheduled group.
          already_logged = set(finished_before) if finished_before else set()

          # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
          # the len of scheduler_outputs.scheduled_seq_groups is !=
          # scheduler_outputs.num_prefill_groups, this means that
          # chunked prefills have been detected.
          for idx, scheduled_seq_group in enumerate(
                  scheduler_outputs.scheduled_seq_groups):
              # Skip double logging when using async output proc
              if idx in already_logged:
                  actual_num_batched_tokens -= 1
                  continue

              group_was_prefill = idx < scheduler_outputs.num_prefill_groups
              seq_group = scheduled_seq_group.seq_group

              # NOTE: a seq_group that completed all of its prefill tokens
              # in the last iteration will have seq_group.is_prefill() = False
              # with group_was_prefill = True
              if group_was_prefill:
                  # Number of prompt tokens.
                  num_prompt_tokens_iter += (
                      scheduled_seq_group.token_chunk_size)

                  # If the seq_group just finished the prefill state
                  # get TTFT.
                  if not seq_group.is_prefill():
                      latency = seq_group.get_last_latency(now)
                      time_to_first_tokens_iter.append(latency)

                      # One generation token per finished prefill.
                      num_generation_tokens_from_prefill_groups += (
                          seq_group.num_seqs())
              else:
                  # TPOTs.
                  latency = seq_group.get_last_latency(now)
                  time_per_output_tokens_iter.append(latency)

              # Because of chunked prefill, we can have a single sequence
              # group that does multiple prompt_runs. To prevent logging
              # the same metadata more than once per request, we standardize
              # on logging request level information for finished requests,
              # which can only happen once.
              if seq_group.is_finished():
                  # Latency timings
                  time_e2e_requests.append(now -
                                           seq_group.metrics.arrival_time)
                  # Metadata
                  num_prompt_tokens_requests.append(
                      len(seq_group.prompt_token_ids))
                  num_generation_tokens_requests.extend([
                      seq.get_output_len()
                      for seq in seq_group.get_finished_seqs()
                  ])
                  if seq_group.sampling_params is not None:
                      best_of_requests.append(
                          seq_group.sampling_params.best_of)
                      n_requests.append(seq_group.sampling_params.n)
                  finished_reason_requests.extend([
                      SequenceStatus.get_finished_reason(seq.status)
                      for seq in seq_group.get_finished_seqs()
                  ])

          # Number of generation tokens.
          #   num_batched_tokens equals the number of prompt_tokens plus the
          #   number of decode_tokens in a single iteration. So,
          #   num_generation_tokens = num_batched_tokens - num_prompt_tokens
          #   + num_generation_tokens_from_prefill_groups (since we generate
          #   one token on prefills on iters where the prefill finishes).
          num_generation_tokens_iter = (
              actual_num_batched_tokens - num_prompt_tokens_iter +
              num_generation_tokens_from_prefill_groups)

      # Spec decode, if enabled, emits specialized metrics from the worker in
      # sampler output.
      if model_output and (model_output[0].spec_decode_worker_metrics
                           is not None):
          spec_decode_metrics = model_output[0].spec_decode_worker_metrics
      else:
          spec_decode_metrics = None

      return Stats(
          now=now,
          # System stats
          #   Scheduler State
          num_running_sys=num_running_sys,
          num_swapped_sys=num_swapped_sys,
          num_waiting_sys=num_waiting_sys,
          #   KV Cache Usage in %
          gpu_cache_usage_sys=gpu_cache_usage_sys,
          cpu_cache_usage_sys=cpu_cache_usage_sys,
          #   Prefix Cache Hit Rate
          cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
          gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,

          # Iteration stats
          num_prompt_tokens_iter=num_prompt_tokens_iter,
          num_generation_tokens_iter=num_generation_tokens_iter,
          time_to_first_tokens_iter=time_to_first_tokens_iter,
          time_per_output_tokens_iter=time_per_output_tokens_iter,
          spec_decode_metrics=spec_decode_metrics,
          num_preemption_iter=num_preemption_iter,

          # Request stats
          #   Latency
          time_e2e_requests=time_e2e_requests,
          #   Metadata
          num_prompt_tokens_requests=num_prompt_tokens_requests,
          num_generation_tokens_requests=num_generation_tokens_requests,
          best_of_requests=best_of_requests,
          n_requests=n_requests,
          finished_reason_requests=finished_reason_requests,
      )
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Ask the model executor to load ``lora_request``.

      Returns:
          The executor's ``add_lora`` result, passed through unchanged.
      """
      executor = self.model_executor
      return executor.add_lora(lora_request)
  def remove_lora(self, lora_id: int) -> bool:
      """Ask the model executor to drop the LoRA identified by ``lora_id``.

      Returns:
          The executor's ``remove_lora`` result, passed through unchanged.
      """
      executor = self.model_executor
      return executor.remove_lora(lora_id)
  def list_loras(self) -> Set[int]:
      """Return the executor's ``list_loras()`` result unchanged."""
      executor = self.model_executor
      return executor.list_loras()
  def pin_lora(self, lora_id: int) -> bool:
      """Forward a pin request for ``lora_id`` to the model executor.

      Returns:
          The executor's ``pin_lora`` result, passed through unchanged.
      """
      executor = self.model_executor
      return executor.pin_lora(lora_id)
  def add_prompt_adapter(
          self, prompt_adapter_request: PromptAdapterRequest) -> bool:
      """Ask the model executor to load ``prompt_adapter_request``.

      Returns:
          The executor's ``add_prompt_adapter`` result, passed through
          unchanged.
      """
      executor = self.model_executor
      return executor.add_prompt_adapter(prompt_adapter_request)
  def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
      """Ask the model executor to drop adapter ``prompt_adapter_id``.

      Returns:
          The executor's ``remove_prompt_adapter`` result, passed through
          unchanged.
      """
      executor = self.model_executor
      return executor.remove_prompt_adapter(prompt_adapter_id)
  def list_prompt_adapters(self) -> List[int]:
      """Return the executor's ``list_prompt_adapters()`` result unchanged."""
      executor = self.model_executor
      return executor.list_prompt_adapters()
  def check_health(self) -> None:
      """Run the tokenizer's health check (when a tokenizer is set), then
      the model executor's.

      Exceptions raised by either check are not caught here.
      """
      tokenizer = self.tokenizer
      if tokenizer:
          tokenizer.check_health()
      self.model_executor.check_health()
  def is_encoder_decoder_model(self):
      """Return ``self.model_config.is_encoder_decoder_model`` unchanged."""
      config = self.model_config
      return config.is_encoder_decoder_model
  def is_embedding_model(self):
      """Return ``self.model_config.is_embedding_model`` unchanged."""
      config = self.model_config
      return config.is_embedding_model
  1559. def _validate_model_inputs(self, inputs: Union[LLMInputs,
  1560. EncoderDecoderLLMInputs]):
  1561. if self.is_encoder_decoder_model():
  1562. prompt_ids = inputs.get("encoder_prompt_token_ids")
  1563. else:
  1564. prompt_ids = inputs.get("prompt_token_ids")
  1565. if prompt_ids is None or len(prompt_ids) == 0:
  1566. raise ValueError("Prompt cannot be empty")
  1567. if self.model_config.is_multimodal_model:
  1568. max_prompt_len = self.model_config.max_model_len
  1569. if len(prompt_ids) > max_prompt_len:
  1570. raise ValueError(
  1571. f"The prompt (total length {len(prompt_ids)}) is too long "
  1572. f"to fit into the model (context length {max_prompt_len}). "
  1573. "Make sure that `max_model_len` is no smaller than the "
  1574. "number of text tokens plus multimodal tokens. For image "
  1575. "inputs, the number of image tokens depends on the number "
  1576. "of images, and possibly their aspect ratios as well.")
  1577. # TODO: Find out how many placeholder tokens are there so we can
  1578. # check that chunked prefill does not truncate them
  1579. # max_batch_len = self.scheduler_config.max_num_batched_tokens
  # NOTE(review): this appears to be a module-level call, so it would run
  # when the module is imported. `setup_logger` is defined or imported
  # elsewhere; confirm there what logging configuration it applies.
  setup_logger()