aphrodite_engine.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565
  1. import time
  2. from collections import deque
  3. from contextlib import contextmanager
  4. from dataclasses import dataclass
  5. from functools import partial
  6. from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
  7. Iterable, List, NamedTuple, Optional)
  8. from typing import Sequence as GenericSequence
  9. from typing import Set, Type, Union
  10. import torch
  11. from loguru import logger
  12. from typing_extensions import TypeVar
  13. import aphrodite.common.envs as envs
  14. from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
  15. EngineConfig, LoadConfig, LoRAConfig,
  16. ModelConfig, ParallelConfig,
  17. PromptAdapterConfig, SchedulerConfig,
  18. SpeculativeConfig)
  19. from aphrodite.common.logger import setup_logger
  20. from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
  21. RequestOutputFactory)
  22. from aphrodite.common.pooling_params import PoolingParams
  23. from aphrodite.common.sampling_params import RequestOutputKind, SamplingParams
  24. from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
  25. ExecuteModelRequest, Sequence,
  26. SequenceGroup, SequenceGroupMetadata,
  27. SequenceStatus)
  28. from aphrodite.common.utils import Counter, Device, weak_bind
  29. from aphrodite.engine.args_tools import EngineArgs
  30. from aphrodite.engine.metrics_types import StatLoggerBase, Stats
  31. from aphrodite.engine.output_processor.interfaces import (
  32. SequenceGroupOutputProcessor)
  33. from aphrodite.engine.output_processor.stop_checker import StopChecker
  34. from aphrodite.engine.output_processor.util import (
  35. create_output_by_sequence_group)
  36. from aphrodite.executor.executor_base import ExecutorBase
  37. from aphrodite.executor.ray_utils import initialize_ray_cluster
  38. from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
  39. InputRegistry, LLMInputs, PromptType)
  40. from aphrodite.inputs.preprocess import InputPreprocessor
  41. from aphrodite.lora.request import LoRARequest
  42. from aphrodite.modeling.layers.sampler import SamplerOutput
  43. from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
  44. SchedulerOutputs)
  45. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  46. from aphrodite.transformers_utils.config import try_get_generation_config
  47. from aphrodite.transformers_utils.detokenizer import Detokenizer
  48. from aphrodite.transformers_utils.tokenizer import AnyTokenizer
  49. from aphrodite.transformers_utils.tokenizer_group import (
  50. BaseTokenizerGroup, init_tokenizer_from_configs)
  51. from aphrodite.version import __version__ as APHRODITE_VERSION
  52. _LOCAL_LOGGING_INTERVAL_SEC = 5
  53. APHRODITE_USE_RAY_SPMD_WORKER = envs.APHRODITE_USE_RAY_SPMD_WORKER
  54. def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
  55. config = try_get_generation_config(
  56. model_config.model,
  57. trust_remote_code=model_config.trust_remote_code,
  58. revision=model_config.revision,
  59. )
  60. if config is None:
  61. return {}
  62. return config.to_diff_dict()
# Tokenizer-group type accepted/returned by `get_tokenizer_group`; bounded by
# and defaulting to BaseTokenizerGroup.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Request-output type checked by `validate_output` / `validate_outputs`;
# constrained to exactly these two output classes.
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine.

    Used for multi-step scheduling. The engine keeps one instance per
    pipeline-parallel virtual engine in ``cached_scheduler_outputs``.
    """
    # Cached sequence-group metadata; None when nothing is cached.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Cached scheduler outputs; None when nothing is cached.
    scheduler_outputs: Optional[SchedulerOutputs] = None
    # Whether async output processing was allowed for the cached step
    # (assumed from the name — TODO confirm against the step loop).
    allow_async_output_proc: bool = False
    # Most recent sampler output, if any.
    last_output: Optional[SamplerOutput] = None
class OutputData(NamedTuple):
    """One entry of ``SchedulerContext.output_queue``.

    Created by ``SchedulerContext.append_output``, which always starts
    ``skip`` as an empty list.
    """
    outputs: List[SamplerOutput]
    seq_group_metadata_list: List[SequenceGroupMetadata]
    scheduler_outputs: SchedulerOutputs
    # Whether these outputs come from an async-output-processing step
    # (assumed from the name — TODO confirm).
    is_async: bool
    # Whether this is the last step of a multi-step run (assumed from the
    # name — TODO confirm).
    is_last_step: bool
    # Presumably indices of sequence groups to skip during processing;
    # initialized empty by append_output — verify against the consumer.
    skip: List[int]
  79. class SchedulerContext:
  80. def __init__(self, multi_step_stream_outputs: bool = False):
  81. self.output_queue: Deque[OutputData] = deque()
  82. self.request_outputs: List[Union[RequestOutput,
  83. EmbeddingRequestOutput]] = []
  84. self.seq_group_metadata_list: Optional[
  85. List[SequenceGroupMetadata]] = None
  86. self.scheduler_outputs: Optional[SchedulerOutputs] = None
  87. self.multi_step_stream_outputs: bool = multi_step_stream_outputs
  88. def append_output(self, outputs: List[SamplerOutput],
  89. seq_group_metadata_list: List[SequenceGroupMetadata],
  90. scheduler_outputs: SchedulerOutputs, is_async: bool,
  91. is_last_step: bool):
  92. self.output_queue.append(
  93. OutputData(outputs=outputs,
  94. seq_group_metadata_list=seq_group_metadata_list,
  95. scheduler_outputs=scheduler_outputs,
  96. is_async=is_async,
  97. is_last_step=is_last_step,
  98. skip=[]))
class AphroditeEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the Aphrodite engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncAphrodite` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        load_config: The configuration related to model loading; forwarded
            to the model executor.
        lora_config (Optional): The configuration related to serving multi-LoRA.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        decoding_config (Optional): The configuration related to decoding;
            a default ``DecodingConfig()`` is used when None.
        prompt_adapter_config (Optional): The configuration related to serving
            prompt adapters.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        stat_loggers (Optional): Stat loggers used instead of the default
            logging/Prometheus loggers when ``log_stats`` is True.
        input_registry: Registry used to create the model's input processor.
        use_cached_outputs: Stored as ``self.use_cached_outputs`` for later
            use (it is not read during construction).
    """

    # Class-wide switch read by `validate_output` / `validate_outputs` and
    # set by the `enable_output_validation` context manager.
    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
    """A flag to toggle whether to validate the type of request output."""
  129. @classmethod
  130. @contextmanager
  131. def enable_output_validation(cls):
  132. cls.DO_VALIDATE_OUTPUT = True
  133. yield
  134. cls.DO_VALIDATE_OUTPUT = False
  135. @classmethod
  136. def validate_output(
  137. cls,
  138. output: object,
  139. output_type: Type[_O],
  140. ) -> _O:
  141. do_validate = cls.DO_VALIDATE_OUTPUT
  142. if ((TYPE_CHECKING or do_validate)
  143. and not isinstance(output, output_type)):
  144. raise TypeError(f"Expected output of type {output_type}, "
  145. f"but found type {type(output)}")
  146. return output
  147. @classmethod
  148. def validate_outputs(
  149. cls,
  150. outputs: GenericSequence[object],
  151. output_type: Type[_O],
  152. ) -> List[_O]:
  153. do_validate = cls.DO_VALIDATE_OUTPUT
  154. outputs_: List[_O]
  155. if TYPE_CHECKING or do_validate:
  156. outputs_ = []
  157. for output in outputs:
  158. if not isinstance(output, output_type):
  159. raise TypeError(f"Expected output of type {output_type}, "
  160. f"but found type {type(output)}")
  161. outputs_.append(output)
  162. else:
  163. outputs_ = outputs
  164. return outputs_
    # Tokenizer group; __init__ sets it to None when
    # model_config.skip_tokenizer_init is true.
    tokenizer: Optional[BaseTokenizerGroup]

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine from already-created configs.

        In order: logs the effective configuration, loads general plugins,
        stores the configs, creates the tokenizer/detokenizer (unless
        ``skip_tokenizer_init``), the input preprocessor/processor and the
        model executor, initializes KV caches (non-embedding models only),
        then creates per-virtual-engine scheduler state, schedulers, stat
        loggers (when ``log_stats``) and the output processor.

        ``from_engine_args`` builds these arguments from ``EngineArgs``.
        """
        # Only used to decide whether the startup banner shows a commit id.
        try:
            import aphrodite.commit_id
            commit_id = True
        except ImportError:
            commit_id = False

        config_dict = {
            "Model": model_config.model,
            "Speculative Config": speculative_config,
            "DataType": model_config.dtype,
            "Model Load Format": load_config.load_format,
            "Tensor Parallel Size": parallel_config.tensor_parallel_size,
            "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
            "Disable Custom All-Reduce":
            parallel_config.disable_custom_all_reduce,
            "Quantization Format": model_config.quantization,
            "Context Length": model_config.max_model_len,
            "Enforce Eager Mode": model_config.enforce_eager,
            "Prefix Caching": cache_config.enable_prefix_caching,
            "KV Cache DataType": cache_config.cache_dtype,
            "Device": device_config.device,
            "Rope Scaling": model_config.rope_scaling,
            "Guided Decoding Backend": decoding_config,
            "Scheduler Steps": scheduler_config.num_scheduler_steps,
            "Async Output Processing": model_config.use_async_output_proc,
        }

        logger.info("-" * 85)
        if not commit_id:
            logger.info(
                f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
                "with the following config:")
        else:
            logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
                        f"commit {aphrodite.__short_commit__}) with the "
                        "following config:")

        # Log every setting that is not None, except "Model Load Format" and
        # "KV Cache DataType" when they are just "auto".
        for key, value in config_dict.items():
            if value is not None and not ((key == "Model Load Format" or key ==\
                "KV Cache DataType") and value == \
                    "auto"):
                logger.info(f"{key} = {value!r}")

        logger.info("-" * 85)
        # TODO: Print more configs in debug mode.
        from aphrodite.plugins import load_general_plugins
        load_general_plugins()

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_preprocessor = InputPreprocessor(model_config,
                                                    self.tokenizer)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
        )

        # Embedding models skip KV-cache profiling/allocation entirely.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # One cached-output slot and one scheduler context per
        # pipeline-parallel virtual engine.
        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if model_config.use_async_output_proc:
            # weak_bind presumably avoids a strong reference from the
            # callbacks back to the engine (same GC concern as above) —
            # TODO confirm against aphrodite.common.utils.weak_bind.
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = [
            Scheduler(
                scheduler_config, cache_config, lora_config,
                parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if model_config.use_async_output_proc else None)
            for v_id in range(parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from aphrodite.engine.metrics import (LoggingStatLogger,
                                                      PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))
  347. def _initialize_kv_caches(self) -> None:
  348. """Initialize the KV cache in the worker(s).
  349. The workers will determine the number of blocks in both the GPU cache
  350. and the swap CPU cache.
  351. """
  352. num_gpu_blocks, num_cpu_blocks = (
  353. self.model_executor.determine_num_available_blocks())
  354. if self.cache_config.num_gpu_blocks_override is not None:
  355. num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
  356. logger.info(f"Overriding {num_gpu_blocks=} with "
  357. f"{num_gpu_blocks_override=}")
  358. num_gpu_blocks = num_gpu_blocks_override
  359. self.cache_config.num_gpu_blocks = num_gpu_blocks
  360. self.cache_config.num_cpu_blocks = num_cpu_blocks
  361. self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  362. @classmethod
  363. def _get_executor_cls(cls,
  364. engine_config: EngineConfig) -> Type[ExecutorBase]:
  365. distributed_executor_backend = (
  366. engine_config.parallel_config.distributed_executor_backend)
  367. # Initialize the cluster and specify the executor class.
  368. if isinstance(distributed_executor_backend, type):
  369. if not issubclass(distributed_executor_backend, ExecutorBase):
  370. raise TypeError(
  371. "distributed_executor_backend must be a subclass of "
  372. f"ExecutorBase. Got {distributed_executor_backend}.")
  373. if distributed_executor_backend.uses_ray: # type: ignore
  374. initialize_ray_cluster(engine_config.parallel_config)
  375. executor_class = distributed_executor_backend
  376. elif engine_config.device_config.device_type == "neuron":
  377. from aphrodite.executor.neuron_executor import NeuronExecutor
  378. executor_class = NeuronExecutor
  379. elif engine_config.device_config.device_type == "tpu":
  380. if distributed_executor_backend == "ray":
  381. initialize_ray_cluster(engine_config.parallel_config)
  382. from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
  383. executor_class = RayTPUExecutor
  384. else:
  385. assert distributed_executor_backend is None
  386. from aphrodite.executor.tpu_executor import TPUExecutor
  387. executor_class = TPUExecutor
  388. elif engine_config.device_config.device_type == "cpu":
  389. from aphrodite.executor.cpu_executor import CPUExecutor
  390. executor_class = CPUExecutor
  391. elif engine_config.device_config.device_type == "openvino":
  392. from aphrodite.executor.openvino_executor import OpenVINOExecutor
  393. executor_class = OpenVINOExecutor
  394. elif engine_config.device_config.device_type == "xpu":
  395. if distributed_executor_backend == "ray":
  396. initialize_ray_cluster(engine_config.parallel_config)
  397. from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
  398. executor_class = RayXPUExecutor
  399. elif distributed_executor_backend == "mp":
  400. logger.error(
  401. "Both start methods (spawn and fork) have issues "
  402. "on XPU if you use mp backend, Please try ray instead.")
  403. else:
  404. from aphrodite.executor.xpu_executor import XPUExecutor
  405. executor_class = XPUExecutor
  406. elif distributed_executor_backend == "ray":
  407. initialize_ray_cluster(engine_config.parallel_config)
  408. from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
  409. executor_class = RayGPUExecutor
  410. elif distributed_executor_backend == "mp":
  411. from aphrodite.executor.multiproc_gpu_executor import (
  412. MultiprocessingGPUExecutor)
  413. assert not envs.APHRODITE_USE_RAY_SPMD_WORKER, (
  414. "multiprocessing distributed executor backend does not "
  415. "support APHRODITE_USE_RAY_SPMD_WORKER=1")
  416. executor_class = MultiprocessingGPUExecutor
  417. else:
  418. from aphrodite.executor.gpu_executor import GPUExecutor
  419. executor_class = GPUExecutor
  420. return executor_class
  421. @classmethod
  422. def from_engine_args(
  423. cls,
  424. engine_args: EngineArgs,
  425. stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
  426. ) -> "AphroditeEngine":
  427. """Creates an Aphrodite engine from the engine arguments."""
  428. # Create the engine configs.
  429. engine_config = engine_args.create_engine_config()
  430. executor_class = cls._get_executor_cls(engine_config)
  431. # Create the LLM engine.
  432. engine = cls(
  433. **engine_config.to_dict(),
  434. executor_class=executor_class,
  435. log_stats=not engine_args.disable_log_stats,
  436. stat_loggers=stat_loggers,
  437. )
  438. return engine
  439. def __reduce__(self):
  440. # This is to ensure that the AphroditeEngine is not referenced in
  441. # the closure used to initialize Ray worker actors
  442. raise RuntimeError("AphroditeEngine should not be pickled!")
  443. def __del__(self):
  444. # Shutdown model executor when engine is garbage collected
  445. # Use getattr since __init__ can fail before the field is set
  446. if model_executor := getattr(self, "model_executor", None):
  447. model_executor.shutdown()
  448. def get_tokenizer_group(
  449. self,
  450. group_type: Type[_G] = BaseTokenizerGroup,
  451. ) -> _G:
  452. tokenizer_group = self.tokenizer
  453. if tokenizer_group is None:
  454. raise ValueError("Unable to get tokenizer because "
  455. "skip_tokenizer_init is True")
  456. if not isinstance(tokenizer_group, group_type):
  457. raise TypeError("Invalid type of tokenizer group. "
  458. f"Expected type: {group_type}, but "
  459. f"found type: {type(tokenizer_group)}")
  460. return tokenizer_group
  461. def get_tokenizer(
  462. self,
  463. lora_request: Optional[LoRARequest] = None,
  464. ) -> AnyTokenizer:
  465. return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
  466. def _init_tokenizer(self) -> BaseTokenizerGroup:
  467. return init_tokenizer_from_configs(
  468. model_config=self.model_config,
  469. scheduler_config=self.scheduler_config,
  470. parallel_config=self.parallel_config,
  471. enable_lora=bool(self.lora_config))
  472. def _verify_args(self) -> None:
  473. self.model_config.verify_with_parallel_config(self.parallel_config)
  474. self.cache_config.verify_with_parallel_config(self.parallel_config)
  475. if self.lora_config:
  476. self.lora_config.verify_with_model_config(self.model_config)
  477. self.lora_config.verify_with_scheduler_config(
  478. self.scheduler_config)
  479. if self.prompt_adapter_config:
  480. self.prompt_adapter_config.verify_with_model_config(
  481. self.model_config)
  482. def _add_processed_request(
  483. self,
  484. request_id: str,
  485. processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
  486. params: Union[SamplingParams, PoolingParams],
  487. arrival_time: float,
  488. lora_request: Optional[LoRARequest],
  489. prompt_adapter_request: Optional[PromptAdapterRequest],
  490. ) -> None:
  491. self._validate_model_inputs(processed_inputs)
  492. # Create the sequences.
  493. block_size = self.cache_config.block_size
  494. seq_id = next(self.seq_counter)
  495. eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
  496. seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
  497. lora_request, prompt_adapter_request)
  498. encoder_seq = None
  499. if 'encoder_prompt_token_ids' in processed_inputs:
  500. encoder_seq = Sequence(seq_id,
  501. processed_inputs,
  502. block_size,
  503. eos_token_id,
  504. lora_request,
  505. prompt_adapter_request,
  506. from_decoder_prompt=False)
  507. # Create a SequenceGroup based on SamplingParams or PoolingParams
  508. if isinstance(params, SamplingParams):
  509. seq_group = self._create_sequence_group_with_sampling(
  510. request_id,
  511. seq,
  512. params,
  513. arrival_time=arrival_time,
  514. lora_request=lora_request,
  515. prompt_adapter_request=prompt_adapter_request,
  516. encoder_seq=encoder_seq)
  517. elif isinstance(params, PoolingParams):
  518. seq_group = self._create_sequence_group_with_pooling(
  519. request_id,
  520. seq,
  521. params,
  522. arrival_time=arrival_time,
  523. lora_request=lora_request,
  524. prompt_adapter_request=prompt_adapter_request,
  525. encoder_seq=encoder_seq)
  526. else:
  527. raise ValueError(
  528. "Either SamplingParams or PoolingParams must be provided.")
  529. # Add the sequence group to the scheduler with least unfinished seqs.
  530. costs = [
  531. scheduler.get_num_unfinished_seq_groups()
  532. for scheduler in self.scheduler
  533. ]
  534. min_cost_scheduler = self.scheduler[costs.index(min(costs))]
  535. min_cost_scheduler.add_seq_group(seq_group)
  536. def stop_remote_worker_execution_loop(self) -> None:
  537. self.model_executor.stop_remote_worker_execution_loop()
  538. def add_request(
  539. self,
  540. request_id: str,
  541. prompt: PromptType,
  542. params: Union[SamplingParams, PoolingParams],
  543. arrival_time: Optional[float] = None,
  544. lora_request: Optional[LoRARequest] = None,
  545. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  546. ) -> None:
  547. """Add a request to the engine's request pool.
  548. The request is added to the request pool and will be processed by the
  549. scheduler as `engine.step()` is called. The exact scheduling policy is
  550. determined by the scheduler.
  551. Args:
  552. request_id: The unique ID of the request.
  553. prompt: The prompt to the LLM. See
  554. :class:`~aphrodite.common.inputs.PromptType`
  555. for more details about the format of each input.
  556. params: Parameters for sampling or pooling. SamplingParams
  557. for text generation. PoolingParams for pooling.
  558. prompt_token_ids: The token IDs of the prompt. If None, we
  559. use the tokenizer to convert the prompts to token IDs.
  560. arrival_time: The arrival time of the request. If None, we use
  561. the current monotonic time.
  562. Details:
  563. - Set arrival_time to the current time if it is None.
  564. - Set prompt_token_ids to the encoded prompt if it is None.
  565. - Create `best_of` number of :class:`~aphrodite.common.sequence`
  566. objects.
  567. - Create a :class:`~aphrodite.common.sequenceGroup` object
  568. from the list of :class:`~aphrodite.common.sequence`.
  569. - Add the :class:`~aphrodite.common.sequenceGroup` object to the
  570. scheduler.
  571. Example:
  572. >>> # initialize engine
  573. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  574. >>> # set request arguments
  575. >>> example_prompt = "Who is the president of the United States?"
  576. >>> sampling_params = SamplingParams(temperature=0.0)
  577. >>> request_id = 0
  578. >>>
  579. >>> # add the request to the engine
  580. >>> engine.add_request(
  581. >>> str(request_id),
  582. >>> example_prompt,
  583. >>> SamplingParams(temperature=0.0))
  584. >>> # continue the request processing
  585. >>> ...
  586. """
  587. if lora_request is not None and not self.lora_config:
  588. raise ValueError(f"Got lora_request {lora_request} but LoRA is "
  589. "not enabled!")
  590. if arrival_time is None:
  591. arrival_time = time.time()
  592. preprocessed_inputs = self.input_preprocessor.preprocess(
  593. prompt,
  594. request_id=request_id,
  595. lora_request=lora_request,
  596. prompt_adapter_request=prompt_adapter_request,
  597. )
  598. processed_inputs = self.input_processor(preprocessed_inputs)
  599. self._add_processed_request(
  600. request_id=request_id,
  601. processed_inputs=processed_inputs,
  602. params=params,
  603. arrival_time=arrival_time,
  604. lora_request=lora_request,
  605. prompt_adapter_request=prompt_adapter_request,
  606. )
  607. def _create_sequence_group_with_sampling(
  608. self,
  609. request_id: str,
  610. seq: Sequence,
  611. sampling_params: SamplingParams,
  612. arrival_time: float,
  613. lora_request: Optional[LoRARequest],
  614. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  615. encoder_seq: Optional[Sequence] = None,
  616. ) -> SequenceGroup:
  617. """Creates a SequenceGroup with SamplingParams."""
  618. max_logprobs = self.get_model_config().max_logprobs
  619. if (sampling_params.logprobs
  620. and sampling_params.logprobs > max_logprobs) or (
  621. sampling_params.prompt_logprobs
  622. and sampling_params.prompt_logprobs > max_logprobs):
  623. raise ValueError(f"Cannot request more than "
  624. f"{max_logprobs} logprobs.")
  625. # Defensive copy of SamplingParams, which are used by the sampler,
  626. # this doesn't deep-copy LogitsProcessor objects
  627. sampling_params = sampling_params.clone()
  628. sampling_params.update_from_generation_config(
  629. self.generation_config_fields, seq.eos_token_id)
  630. sampling_params._verify_with_scheduler_config(self.scheduler_config)
  631. # Create the sequence group.
  632. seq_group = SequenceGroup(
  633. request_id=request_id,
  634. seqs=[seq],
  635. arrival_time=arrival_time,
  636. sampling_params=sampling_params,
  637. lora_request=lora_request,
  638. prompt_adapter_request=prompt_adapter_request,
  639. encoder_seq=encoder_seq)
  640. return seq_group
  641. def _create_sequence_group_with_pooling(
  642. self,
  643. request_id: str,
  644. seq: Sequence,
  645. pooling_params: PoolingParams,
  646. arrival_time: float,
  647. lora_request: Optional[LoRARequest],
  648. prompt_adapter_request: Optional[PromptAdapterRequest],
  649. encoder_seq: Optional[Sequence] = None,
  650. ) -> SequenceGroup:
  651. """Creates a SequenceGroup with PoolingParams."""
  652. # Defensive copy of PoolingParams, which are used by the pooler
  653. pooling_params = pooling_params.clone()
  654. # Create the sequence group.
  655. seq_group = SequenceGroup(
  656. request_id=request_id,
  657. seqs=[seq],
  658. arrival_time=arrival_time,
  659. lora_request=lora_request,
  660. pooling_params=pooling_params,
  661. prompt_adapter_request=prompt_adapter_request,
  662. encoder_seq=encoder_seq)
  663. return seq_group
  664. def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
  665. """Aborts a request(s) with the given ID.
  666. Args:
  667. request_id: The ID(s) of the request to abort.
  668. Details:
  669. - Refer to the
  670. :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
  671. from class :class:`~aphrodite.processing.scheduler.Scheduler`.
  672. Example:
  673. >>> # initialize engine and add a request with request_id
  674. >>> request_id = str(0)
  675. >>> # abort the request
  676. >>> engine.abort_request(request_id)
  677. """
  678. for scheduler in self.scheduler:
  679. scheduler.abort_seq_group(request_id)
  680. def get_model_config(self) -> ModelConfig:
  681. """Gets the model configuration."""
  682. return self.model_config
  683. def get_parallel_config(self) -> ParallelConfig:
  684. """Gets the parallel configuration."""
  685. return self.parallel_config
  686. def get_decoding_config(self) -> DecodingConfig:
  687. """Gets the decoding configuration."""
  688. return self.decoding_config
  689. def get_scheduler_config(self) -> SchedulerConfig:
  690. """Gets the scheduler configuration."""
  691. return self.scheduler_config
  692. def get_lora_config(self) -> LoRAConfig:
  693. """Gets the LoRA configuration."""
  694. return self.lora_config
  695. def get_num_unfinished_requests(self) -> int:
  696. """Gets the number of unfinished requests."""
  697. return sum(scheduler.get_num_unfinished_seq_groups()
  698. for scheduler in self.scheduler)
  699. def has_unfinished_requests(self) -> bool:
  700. """Returns True if there are unfinished requests."""
  701. return any(scheduler.has_unfinished_seqs()
  702. for scheduler in self.scheduler)
  703. def has_unfinished_requests_for_virtual_engine(
  704. self, virtual_engine: int) -> bool:
  705. """
  706. Returns True if there are unfinished requests for the virtual engine.
  707. """
  708. return self.scheduler[virtual_engine].has_unfinished_seqs()
  709. @staticmethod
  710. def _process_sequence_group_outputs(
  711. seq_group: SequenceGroup,
  712. outputs: List[EmbeddingSequenceGroupOutput],
  713. ) -> None:
  714. seq_group.embeddings = outputs[0].embeddings
  715. for seq in seq_group.get_seqs():
  716. seq.status = SequenceStatus.FINISHED_STOPPED
  717. return
    def _process_model_outputs(self,
                               ctx: SchedulerContext,
                               request_id: Optional[str] = None) -> None:
        """Apply the model output to the sequences in the scheduled seq groups
        and return responses.

        Works on one entry of ``ctx.output_queue``, which unpacks to
        ``(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
        is_last_step, skip)``. New request outputs are appended to
        ``ctx.request_outputs``. When
        ``self.process_request_outputs_callback`` is set, they may be handed
        to that callback and the list cleared.

        Args:
            ctx: The virtual engine context to work on.
            request_id: If provided, then only this request is going to be
                processed. The queue head is read without being popped, and
                the request's index is appended to that entry's ``skip`` list
                so the later full pass does not process it a second time.

        Returns:
            None; results are delivered through ``ctx.request_outputs``.
        """
        now = time.time()
        if len(ctx.output_queue) == 0:
            return None
        # Get pending async postprocessor
        if request_id:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, skip) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, skip) = ctx.output_queue.popleft()
        # Sanity check
        assert len(seq_group_metadata_list) == len(
            scheduler_outputs.scheduled_seq_groups)
        # ``outputs`` is indexed [step][sequence group]. With several steps,
        # regroup it as [sequence group][step] so that
        # outputs_by_sequence_group[i] holds group i's output for every step.
        if len(outputs) > 1:
            outputs_by_sequence_group = create_output_by_sequence_group(
                outputs, num_seq_groups=len(seq_group_metadata_list))
        else:
            outputs_by_sequence_group = outputs
        # Determine the requests we need to operate on
        if request_id:
            indices = []
            for i, seq_group_meta in enumerate(seq_group_metadata_list):
                if seq_group_meta.request_id == request_id:
                    assert i not in skip  # Cannot be called twice
                    indices.append(i)
                    break
            # If the request_id was not found, then it means that
            # this is a new request that has no pending async
            # postprocessor
            if not indices:
                return
        else:
            indices = range(len(seq_group_metadata_list))  # type: ignore
        # Indices of groups that were already finished on entry, and of
        # groups that finish during this call.
        finished_before: List[int] = []
        finished_now: List[int] = []
        for i in indices:
            # Already handled by an earlier single-request call.
            if i in skip:
                continue
            seq_group_meta = seq_group_metadata_list[i]
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                finished_before.append(i)
                continue
            if len(outputs) > 1:
                output = outputs_by_sequence_group[i]
            else:
                # Single step: wrap this group's output in a one-element list.
                output = [outputs_by_sequence_group[0][i]]
            # In async mode the computed-token count was already advanced
            # by _advance_to_next_step (called from step()).
            if not is_async:
                seq_group.update_num_computed_tokens(
                    scheduled_seq_group.token_chunk_size)
            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, output)
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                        seq_group, output, is_async)
            if seq_group.is_finished():
                finished_now.append(i)
        # Generate outputs for the requests that finished this iteration
        for i in finished_now:
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)
        # When we process a single request, we skip it for the next time,
        # and invoke the request output callback (if there was final output)
        if request_id:
            assert len(indices) == 1
            # ``skip`` belongs to the queue entry that was not popped above,
            # so the later full pass over this entry will see it.
            skip.append(indices[0])
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return
        # Free currently finished requests
        if finished_now:
            for scheduler in self.scheduler:
                scheduler.free_finished_seq_groups()
        # For multi-step without streaming, don't create outputs each iteration
        if not is_last_step and not ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if (finished_now
                    and self.process_request_outputs_callback is not None):
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return
        # Create the outputs
        for i in indices:
            if i in skip or i in finished_before or i in finished_now:
                continue  # Avoids double processing
            scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
            seq_group = scheduled_seq_group.seq_group
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)
        # For multi-step with streaming, create outputs each iteration
        if not is_last_step and ctx.multi_step_stream_outputs:
            # Immediately process request outputs here (if callback is given)
            if self.process_request_outputs_callback is not None:
                self.process_request_outputs_callback(ctx.request_outputs)
                ctx.request_outputs.clear()
            return
        # Ignored groups also produce outputs, except unfinished ones that
        # requested DELTA output.
        for seq_group in scheduler_outputs.ignored_seq_groups:
            params = seq_group.sampling_params
            if params is not None and params.output_kind == (
                    RequestOutputKind.DELTA) and not seq_group.is_finished():
                continue
            request_output = RequestOutputFactory.create(
                seq_group, use_cache=self.use_cached_outputs)
            if request_output:
                ctx.request_outputs.append(request_output)
        # Immediately process request outputs here (if callback is given)
        if (ctx.request_outputs
                and self.process_request_outputs_callback is not None):
            self.process_request_outputs_callback(ctx.request_outputs)
            ctx.request_outputs.clear()
        # For async case, we need to record the stats here.
        # For non-async case, the stats are done in the
        # LLMEngine/AsyncLLMEngine directly
        if is_async:
            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs, finished_before,
                              skip)
        return None
  861. def _advance_to_next_step(
  862. self, output: List[SamplerOutput],
  863. seq_group_metadata_list: List[SequenceGroupMetadata],
  864. scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
  865. """Given model output from a single run, append the tokens to the
  866. sequences. This is normally done inside output processor, but it is
  867. required if the worker is to perform async forward pass to next step.
  868. """
  869. for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
  870. zip(seq_group_metadata_list, output, scheduled_seq_groups):
  871. seq_group = scheduled_seq_group.seq_group
  872. if seq_group.is_finished():
  873. continue
  874. seq_group.update_num_computed_tokens(
  875. seq_group_metadata.token_chunk_size)
  876. if seq_group_metadata.do_sample:
  877. assert len(sequence_group_outputs.samples) == 1, (
  878. "Async output processor expects a single sample"
  879. " (i.e sampling_params.n == 1 and no "
  880. "sampling_params.best_of > 1)")
  881. sample = sequence_group_outputs.samples[0]
  882. assert len(seq_group.seqs) == 1
  883. seq = seq_group.seqs[0]
  884. seq.append_token_id(sample.output_token, sample.logprobs)
  885. def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
  886. """Performs one decoding iteration and returns newly generated results.
  887. .. figure:: https://i.imgur.com/sv2HssD.png
  888. :alt: Overview of the step function
  889. :align: center
  890. Overview of the step function.
  891. Details:
  892. - Step 1: Schedules the sequences to be executed in the next
  893. iteration and the token blocks to be swapped in/out/copy.
  894. - Depending on the scheduling policy,
  895. sequences may be `preempted/reordered`.
  896. - A Sequence Group (SG) refer to a group of sequences
  897. that are generated from the same prompt.
  898. - Step 2: Calls the distributed executor to execute the model.
  899. - Step 3: Processes the model output. This mainly includes:
  900. - Decodes the relevant outputs.
  901. - Updates the scheduled sequence groups with model outputs
  902. based on its `sampling parameters` (`use_beam_search` or not).
  903. - Frees the finished sequence groups.
  904. - Finally, it creates and returns the newly generated results.
  905. Example:
  906. >>> # Please see the example/ folder for more detailed examples.
  907. >>>
  908. >>> # initialize engine and request arguments
  909. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  910. >>> example_inputs = [(0, "What is LLM?",
  911. >>> SamplingParams(temperature=0.0))]
  912. >>>
  913. >>> # Start the engine with an event loop
  914. >>> while True:
  915. >>> if example_inputs:
  916. >>> req_id, prompt, sampling_params = example_inputs.pop(0)
  917. >>> engine.add_request(str(req_id),prompt,sampling_params)
  918. >>>
  919. >>> # continue the request processing
  920. >>> request_outputs = engine.step()
  921. >>> for request_output in request_outputs:
  922. >>> if request_output.finished:
  923. >>> # return or show the request output
  924. >>>
  925. >>> if not (engine.has_unfinished_requests() or example_inputs):
  926. >>> break
  927. """
  928. if self.parallel_config.pipeline_parallel_size > 1:
  929. raise NotImplementedError(
  930. "Pipeline parallelism is only supported through AsyncAphrodite "
  931. "as performance will be severely degraded otherwise.")
  932. # For llm_engine, there is no pipeline parallel support, so the engine
  933. # used is always 0
  934. virtual_engine = 0
  935. # These are cached outputs from previous iterations. None if on first
  936. # iteration
  937. cached_outputs = self.cached_scheduler_outputs[virtual_engine]
  938. seq_group_metadata_list = cached_outputs.seq_group_metadata_list
  939. scheduler_outputs = cached_outputs.scheduler_outputs
  940. allow_async_output_proc = cached_outputs.allow_async_output_proc
  941. ctx = self.scheduler_contexts[virtual_engine]
  942. # Clear outputs for each new scheduler iteration
  943. ctx.request_outputs.clear()
  944. # Skip the scheduler if there are any remaining steps in the seq groups.
  945. # This ensures that the scheduler is only called again when the current
  946. # batch has completed.
  947. if not self._has_remaining_steps(seq_group_metadata_list):
  948. # Schedule iteration
  949. (seq_group_metadata_list, scheduler_outputs,
  950. allow_async_output_proc
  951. ) = self.scheduler[virtual_engine].schedule()
  952. ctx.seq_group_metadata_list = seq_group_metadata_list
  953. ctx.scheduler_outputs = scheduler_outputs
  954. # Maybe switch from async mode to sync mode
  955. if not allow_async_output_proc and len(ctx.output_queue) > 0:
  956. self._process_model_outputs(ctx=ctx)
  957. if (self.scheduler_config.is_multi_step
  958. and scheduler_outputs.num_lookahead_slots > 0):
  959. # cache the scheduler outputs for the next iteration if we have
  960. # lookahead slots
  961. self._cache_scheduler_outputs_for_multi_step(
  962. virtual_engine, seq_group_metadata_list, scheduler_outputs,
  963. allow_async_output_proc)
  964. assert seq_group_metadata_list is not None
  965. assert scheduler_outputs is not None
  966. if not scheduler_outputs.is_empty():
  967. finished_requests_ids = self.scheduler[
  968. virtual_engine].get_and_reset_finished_requests_ids()
  969. # Check if we have a cached last_output from the previous iteration.
  970. # For supporting PP this is probably the best way to pass the
  971. # sampled_token_ids, as a separate broadcast over all the PP stages
  972. # will cause one virtual engine's microbatch to block the pipeline.
  973. last_sampled_token_ids = \
  974. self._get_last_sampled_token_ids(virtual_engine)
  975. execute_model_req = ExecuteModelRequest(
  976. seq_group_metadata_list=seq_group_metadata_list,
  977. blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
  978. blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
  979. blocks_to_copy=scheduler_outputs.blocks_to_copy,
  980. num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
  981. running_queue_size=scheduler_outputs.running_queue_size,
  982. finished_requests_ids=finished_requests_ids,
  983. # We use ExecuteModelRequest to pass the last sampled_token_ids
  984. # to each of the non-last PP stages for in-place prepare_input.
  985. last_sampled_token_ids=last_sampled_token_ids)
  986. if allow_async_output_proc:
  987. execute_model_req.async_callback = self.async_callbacks[
  988. virtual_engine]
  989. outputs = self.model_executor.execute_model(
  990. execute_model_req=execute_model_req)
  991. # We need to do this here so that last step's sampled_token_ids can
  992. # be passed to the next iteration for PP.
  993. if self.scheduler_config.is_multi_step:
  994. self._update_cached_scheduler_output(virtual_engine, outputs)
  995. else:
  996. # Nothing scheduled => If there is pending async postprocessor,
  997. # then finish it here.
  998. if len(ctx.output_queue) > 0:
  999. self._process_model_outputs(ctx=ctx)
  1000. # No outputs in this case
  1001. outputs = []
  1002. # Finish the current step for all the sequence groups.
  1003. if self.scheduler_config.is_multi_step:
  1004. for seq_group in seq_group_metadata_list:
  1005. seq_group.finish_step()
  1006. if not self._has_remaining_steps(seq_group_metadata_list):
  1007. # clear the cache if we have finished all the steps.
  1008. if self.scheduler_config.is_multi_step:
  1009. self.cached_scheduler_outputs[0] = SchedulerOutputState()
  1010. # Add results to the output_queue
  1011. ctx.append_output(outputs=outputs,
  1012. seq_group_metadata_list=seq_group_metadata_list,
  1013. scheduler_outputs=scheduler_outputs,
  1014. is_async=allow_async_output_proc,
  1015. is_last_step=True)
  1016. if outputs and allow_async_output_proc:
  1017. assert len(outputs) == 1, (
  1018. "Async postprocessor expects only a single output set")
  1019. self._advance_to_next_step(
  1020. outputs[0], seq_group_metadata_list,
  1021. scheduler_outputs.scheduled_seq_groups)
  1022. # Check if need to run the usual non-async path
  1023. if not allow_async_output_proc:
  1024. self._process_model_outputs(ctx=ctx)
  1025. # Log stats.
  1026. self.do_log_stats(scheduler_outputs, outputs)
  1027. else:
  1028. # Multi-step case
  1029. return ctx.request_outputs
  1030. if not self.has_unfinished_requests():
  1031. # Drain async postprocessor (if exists)
  1032. if len(ctx.output_queue) > 0:
  1033. self._process_model_outputs(ctx=ctx)
  1034. assert len(ctx.output_queue) == 0
  1035. # Stop the execute model loop in parallel workers until there are
  1036. # more requests to process. This avoids waiting indefinitely in
  1037. # torch.distributed ops which may otherwise timeout, and unblocks
  1038. # the RPC thread in the workers so that they can process any other
  1039. # queued control plane messages, such as add/remove lora adapters.
  1040. logger.debug("Stopping remote worker execution loop.")
  1041. self.model_executor.stop_remote_worker_execution_loop()
  1042. return ctx.request_outputs
  1043. def _has_remaining_steps(
  1044. self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
  1045. ) -> bool:
  1046. if (not self.scheduler_config.is_multi_step
  1047. or not seq_group_metadata_list):
  1048. return False
  1049. # TODO: this is a sanity check for nowto make sure that all the
  1050. # seqs are on the same steps. Eventually we will want to do some sort of
  1051. # dynamic scheduling when doing multi-step decoding.
  1052. ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
  1053. if any([
  1054. seq_group.state.remaining_steps != ref_remaining_steps
  1055. for seq_group in seq_group_metadata_list[1:]
  1056. ]):
  1057. raise AssertionError(("All running sequence groups should "
  1058. "have the same remaining steps."))
  1059. return ref_remaining_steps > 0
  1060. def _cache_scheduler_outputs_for_multi_step(
  1061. self, virtual_engine: int,
  1062. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  1063. scheduler_outputs: SchedulerOutputs,
  1064. allow_async_output_proc: bool) -> None:
  1065. co = self.cached_scheduler_outputs[virtual_engine]
  1066. co.seq_group_metadata_list = seq_group_metadata_list
  1067. co.scheduler_outputs = scheduler_outputs
  1068. co.allow_async_output_proc = allow_async_output_proc
  1069. co.last_output = None
  1070. def _update_cached_scheduler_output(
  1071. self, virtual_engine: int,
  1072. output: List[Optional[SamplerOutput]]) -> None:
  1073. if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
  1074. and output[0] is not None):
  1075. last_output = output[-1]
  1076. assert last_output is not None
  1077. assert last_output.sampled_token_ids_cpu is not None
  1078. assert last_output.sampled_token_ids is None
  1079. assert last_output.sampled_token_probs is None
  1080. self.cached_scheduler_outputs[
  1081. virtual_engine].last_output = last_output
  1082. def _get_last_sampled_token_ids(
  1083. self, virtual_engine: int) -> Optional[torch.Tensor]:
  1084. cached_last_output = self.cached_scheduler_outputs[
  1085. virtual_engine].last_output
  1086. if (self.scheduler_config.is_multi_step
  1087. and self.parallel_config.pipeline_parallel_size > 1
  1088. and cached_last_output is not None
  1089. and cached_last_output.sampled_token_ids_cpu is not None):
  1090. return cached_last_output.sampled_token_ids_cpu
  1091. return None
  1092. def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
  1093. if not self.log_stats:
  1094. raise RuntimeError(
  1095. "Stat logging is disabled. Set `disable_log_stats=False` "
  1096. "argument to enable.")
  1097. if logger_name in self.stat_loggers:
  1098. raise KeyError(f"Logger with name {logger_name} already exists.")
  1099. self.stat_loggers[logger_name] = logger
  1100. def remove_logger(self, logger_name: str) -> None:
  1101. if not self.log_stats:
  1102. raise RuntimeError(
  1103. "Stat logging is disabled. Set `disable_log_stats=False` "
  1104. "argument to enable.")
  1105. if logger_name not in self.stat_loggers:
  1106. raise KeyError(f"Logger with name {logger_name} does not exist.")
  1107. del self.stat_loggers[logger_name]
  1108. def do_log_stats(self,
  1109. scheduler_outputs: Optional[SchedulerOutputs] = None,
  1110. model_output: Optional[List[SamplerOutput]] = None,
  1111. finished_before: Optional[List[int]] = None,
  1112. skip: Optional[List[int]] = None) -> None:
  1113. """Forced log when no requests active."""
  1114. if self.log_stats:
  1115. stats = self._get_stats(scheduler_outputs, model_output,
  1116. finished_before, skip)
  1117. for loggers in self.stat_loggers.values():
  1118. loggers.log(stats)
  1119. def _get_stats(self,
  1120. scheduler_outputs: Optional[SchedulerOutputs],
  1121. model_output: Optional[List[SamplerOutput]] = None,
  1122. finished_before: Optional[List[int]] = None,
  1123. skip: Optional[List[int]] = None) -> Stats:
  1124. """Get Stats to be Logged to Prometheus.
  1125. Args:
  1126. scheduler_outputs: Optional, used to populate metrics related to
  1127. the scheduled batch,
  1128. model_output: Optional, used to emit speculative decoding metrics
  1129. which are created by the workers.
  1130. finished_before: Optional, indices of sequences that were finished
  1131. before. These sequences will be ignored.
  1132. skip: Optional, indices of sequences that were preempted. These
  1133. sequences will be ignored.
  1134. """
  1135. now = time.time()
  1136. # System State
  1137. # Scheduler State
  1138. num_running_sys = sum(
  1139. len(scheduler.running) for scheduler in self.scheduler)
  1140. num_swapped_sys = sum(
  1141. len(scheduler.swapped) for scheduler in self.scheduler)
  1142. num_waiting_sys = sum(
  1143. len(scheduler.waiting) for scheduler in self.scheduler)
  1144. # KV Cache Usage in %
  1145. num_total_gpu = self.cache_config.num_gpu_blocks
  1146. gpu_cache_usage_sys = 0.
  1147. if num_total_gpu: # Guard against both None and 0
  1148. num_free_gpu = sum(
  1149. scheduler.block_manager.get_num_free_gpu_blocks()
  1150. for scheduler in self.scheduler)
  1151. gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
  1152. num_total_cpu = self.cache_config.num_cpu_blocks
  1153. cpu_cache_usage_sys = 0.
  1154. if num_total_cpu: # Guard against both None and 0
  1155. num_free_cpu = sum(
  1156. scheduler.block_manager.get_num_free_cpu_blocks()
  1157. for scheduler in self.scheduler)
  1158. cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
  1159. # Prefix Cache Hit Rate. Note that we always use
  1160. # the cache hit rate of the first virtual engine.
  1161. cpu_prefix_cache_hit_rate = self.scheduler[
  1162. 0].get_prefix_cache_hit_rate(Device.CPU)
  1163. gpu_prefix_cache_hit_rate = self.scheduler[
  1164. 0].get_prefix_cache_hit_rate(Device.GPU)
  1165. # Iteration stats
  1166. num_prompt_tokens_iter = 0
  1167. num_generation_tokens_iter = 0
  1168. time_to_first_tokens_iter: List[float] = []
  1169. time_per_output_tokens_iter: List[float] = []
  1170. num_preemption_iter = (0 if scheduler_outputs is None else
  1171. scheduler_outputs.preempted)
  1172. # Request stats
  1173. # Latency
  1174. time_e2e_requests: List[float] = []
  1175. # Metadata
  1176. num_prompt_tokens_requests: List[int] = []
  1177. num_generation_tokens_requests: List[int] = []
  1178. best_of_requests: List[int] = []
  1179. n_requests: List[int] = []
  1180. finished_reason_requests: List[str] = []
  1181. # NOTE: This loop assumes prefill seq_groups are before
  1182. # decode seq_groups in scheduled_seq_groups.
  1183. if scheduler_outputs is not None:
  1184. # For async postprocessor, already finished sequences need to be
  1185. # not counted (to avoid double counting)
  1186. actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
  1187. num_generation_tokens_from_prefill_groups = 0.
  1188. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
  1189. # the len of scheduler_outputs.scheduled_seq_groups is !=
  1190. # scheduler_outputs.num_prefill_groups, this means that
  1191. # chunked prefills have been detected.
  1192. for idx, scheduled_seq_group in enumerate(
  1193. scheduler_outputs.scheduled_seq_groups):
  1194. # Skip double logging when using async output proc
  1195. if finished_before and idx in finished_before:
  1196. actual_num_batched_tokens -= 1
  1197. continue
  1198. # Currently, skip == preempted sequences, so we need to skip
  1199. # their log stats
  1200. if skip and idx in skip:
  1201. continue
  1202. group_was_prefill = idx < scheduler_outputs.num_prefill_groups
  1203. seq_group = scheduled_seq_group.seq_group
  1204. # NOTE: a seq_group that completed all of its prefill tokens
  1205. # in the last iteration will have seq_group.is_prefill() = False
  1206. # with group_was_prefill = True
  1207. if group_was_prefill:
  1208. # Number of prompt tokens.
  1209. num_prompt_tokens_iter += (
  1210. scheduled_seq_group.token_chunk_size)
  1211. # If the seq_group just finished the prefill state
  1212. # get TTFT.
  1213. if not seq_group.is_prefill():
  1214. latency = seq_group.get_last_latency(now)
  1215. time_to_first_tokens_iter.append(latency)
  1216. # One generation token per finished prefill.
  1217. num_generation_tokens_from_prefill_groups += (
  1218. seq_group.num_seqs())
  1219. else:
  1220. # TPOTs.
  1221. latency = seq_group.get_last_latency(now)
  1222. time_per_output_tokens_iter.append(latency)
  1223. # Because of chunked prefill, we can have a single sequence
  1224. # group that does multiple prompt_runs. To prevent logging
  1225. # the same metadata more than once per request, we standardize
  1226. # on logging request level information for finished requests,
  1227. # which can only happen once.
  1228. if seq_group.is_finished():
  1229. # Latency timings
  1230. time_e2e_requests.append(now -
  1231. seq_group.metrics.arrival_time)
  1232. # Metadata
  1233. num_prompt_tokens_requests.append(
  1234. len(seq_group.prompt_token_ids))
  1235. num_generation_tokens_requests.extend([
  1236. seq.get_output_len()
  1237. for seq in seq_group.get_finished_seqs()
  1238. ])
  1239. if seq_group.sampling_params is not None:
  1240. best_of_requests.append(
  1241. seq_group.sampling_params.best_of)
  1242. n_requests.append(seq_group.sampling_params.n)
  1243. finished_reason_requests.extend([
  1244. SequenceStatus.get_finished_reason(seq.status)
  1245. for seq in seq_group.get_finished_seqs()
  1246. ])
  1247. # Number of generation tokens.
  1248. # num_batched_tokens equals the number of prompt_tokens plus the
  1249. # number of decode_tokens in a single iteration. So,
  1250. # num_generation_tokens = num_batched_tokens - num_prompt_tokens
  1251. # + num_generation_tokens_from_prefill_groups (since we generate
  1252. # one token on prefills on iters where the prefill finishes).
  1253. num_generation_tokens_iter = (
  1254. actual_num_batched_tokens - num_prompt_tokens_iter +
  1255. num_generation_tokens_from_prefill_groups)
  1256. # Spec decode, if enabled, emits specialized metrics from the worker in
  1257. # sampler output.
  1258. if model_output and (model_output[0].spec_decode_worker_metrics
  1259. is not None):
  1260. spec_decode_metrics = model_output[0].spec_decode_worker_metrics
  1261. else:
  1262. spec_decode_metrics = None
  1263. return Stats(
  1264. now=now,
  1265. # System stats
  1266. # Scheduler State
  1267. num_running_sys=num_running_sys,
  1268. num_swapped_sys=num_swapped_sys,
  1269. num_waiting_sys=num_waiting_sys,
  1270. # KV Cache Usage in %
  1271. gpu_cache_usage_sys=gpu_cache_usage_sys,
  1272. cpu_cache_usage_sys=cpu_cache_usage_sys,
  1273. # Prefix Cache Hit Rate
  1274. cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
  1275. gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
  1276. # Iteration stats
  1277. num_prompt_tokens_iter=num_prompt_tokens_iter,
  1278. num_generation_tokens_iter=num_generation_tokens_iter,
  1279. time_to_first_tokens_iter=time_to_first_tokens_iter,
  1280. time_per_output_tokens_iter=time_per_output_tokens_iter,
  1281. spec_decode_metrics=spec_decode_metrics,
  1282. num_preemption_iter=num_preemption_iter,
  1283. # Request stats
  1284. # Latency
  1285. time_e2e_requests=time_e2e_requests,
  1286. # Metadata
  1287. num_prompt_tokens_requests=num_prompt_tokens_requests,
  1288. num_generation_tokens_requests=num_generation_tokens_requests,
  1289. best_of_requests=best_of_requests,
  1290. n_requests=n_requests,
  1291. finished_reason_requests=finished_reason_requests,
  1292. )
  1293. def add_lora(self, lora_request: LoRARequest) -> bool:
  1294. return self.model_executor.add_lora(lora_request)
  1295. def remove_lora(self, lora_id: int) -> bool:
  1296. return self.model_executor.remove_lora(lora_id)
  1297. def list_loras(self) -> Set[int]:
  1298. return self.model_executor.list_loras()
  def pin_lora(self, lora_id: int) -> bool:
      """Forward a pin request for LoRA ``lora_id`` to the model executor.

      Returns:
          The result of ``model_executor.pin_lora`` (annotated ``bool``).
      """
      pinned = self.model_executor.pin_lora(lora_id)
      return pinned
  def add_prompt_adapter(
          self, prompt_adapter_request: PromptAdapterRequest) -> bool:
      """Forward a prompt-adapter registration to the model executor.

      Args:
          prompt_adapter_request: the adapter request, passed through as-is.

      Returns:
          The result of ``model_executor.add_prompt_adapter``
          (annotated ``bool``).
      """
      executor = self.model_executor
      added = executor.add_prompt_adapter(prompt_adapter_request)
      return added
  def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
      """Forward removal of prompt adapter ``prompt_adapter_id`` to the executor.

      Returns:
          The result of ``model_executor.remove_prompt_adapter``
          (annotated ``bool``).
      """
      executor = self.model_executor
      removed = executor.remove_prompt_adapter(prompt_adapter_id)
      return removed
  def list_prompt_adapters(self) -> List[int]:
      """Return the prompt-adapter ids reported by the model executor."""
      adapter_ids = self.model_executor.list_prompt_adapters()
      return adapter_ids
  def check_health(self) -> None:
      """Run the tokenizer's and then the model executor's health checks.

      The tokenizer check is skipped when ``self.tokenizer`` is falsy
      (e.g. ``None``); the executor check always runs.
      """
      tokenizer = self.tokenizer
      if tokenizer:
          tokenizer.check_health()
      executor = self.model_executor
      executor.check_health()
  def shutdown(self) -> None:
      """Stop remote workers, drop engine-held state and shut the executor down.

      Each attribute is looked up with ``getattr(..., None)`` and skipped when
      it is missing or ``None``. That makes shutdown safe on an engine whose
      constructor failed part-way. Previously the first line dereferenced
      ``self.model_executor`` unguarded, which raised AttributeError before any
      cleanup could run, even though every later access was guarded.

      Order of operations:
        1. ``model_executor.stop_remote_worker_execution_loop()``
        2. release the tokenizer reference
        3. ``.clear()`` the scheduler / cached-output / context / stat-logger
           containers
        4. ``model_executor.shutdown()``
      """
      model_executor = getattr(self, "model_executor", None)

      # Stop the remote worker loop before tearing down local state.
      if model_executor is not None:
          model_executor.stop_remote_worker_execution_loop()

      if getattr(self, "tokenizer", None) is not None:
          self.tokenizer = None

      # Presumably lists/dicts; only their ``.clear()`` method is relied on.
      for attr_name in ("scheduler", "cached_scheduler_outputs",
                        "scheduler_contexts", "stat_loggers"):
          container = getattr(self, attr_name, None)
          if container is not None:
              container.clear()

      if model_executor is not None:
          model_executor.shutdown()
  1321. self.scheduler_contexts.clear()
  1322. if hasattr(self, 'stat_loggers'):
  1323. self.stat_loggers.clear()
  1324. if hasattr(self, 'model_executor'):
  1325. self.model_executor.shutdown()
  1326. def is_encoder_decoder_model(self):
  1327. return self.input_preprocessor.is_encoder_decoder_model()
  1328. def is_embedding_model(self):
  1329. return self.model_config.is_embedding_model
  1330. def _validate_model_inputs(self, inputs: Union[LLMInputs,
  1331. EncoderDecoderLLMInputs]):
  1332. if self.is_encoder_decoder_model():
  1333. prompt_ids = inputs.get("encoder_prompt_token_ids")
  1334. else:
  1335. prompt_ids = inputs.get("prompt_token_ids")
  1336. if prompt_ids is None or len(prompt_ids) == 0:
  1337. raise ValueError("Prompt cannot be empty")
  1338. if self.model_config.is_multimodal_model:
  1339. max_prompt_len = self.model_config.max_model_len
  1340. if len(prompt_ids) > max_prompt_len:
  1341. raise ValueError(
  1342. f"The prompt (total length {len(prompt_ids)}) is too long "
  1343. f"to fit into the model (context length {max_prompt_len}). "
  1344. "Make sure that `max_model_len` is no smaller than the "
  1345. "number of text tokens plus multimodal tokens. For image "
  1346. "inputs, the number of image tokens depends on the number "
  1347. "of images, and possibly their aspect ratios as well.")
  1348. # TODO: Find out how many placeholder tokens are there so we can
  1349. # check that chunked prefill does not truncate them
  1350. # max_batch_len = self.scheduler_config.max_num_batched_tokens
  1351. setup_logger()