1
0

aphrodite_engine.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518
  1. import os
  2. import time
  3. from contextlib import contextmanager
  4. from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
  5. from typing import Sequence as GenericSequence
  6. from typing import Tuple, Type, TypeVar, Union
  7. from loguru import logger
  8. from transformers import PreTrainedTokenizer
  9. from typing_extensions import assert_never
  10. from aphrodite.common.config import (CacheConfig, CFGConfig, DecodingConfig,
  11. DeviceConfig, EngineConfig, LoadConfig,
  12. LoRAConfig, ModelConfig, ParallelConfig,
  13. PromptAdapterConfig, SchedulerConfig,
  14. SpeculativeConfig)
  15. from aphrodite.common.logger import setup_logger
  16. from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
  17. RequestOutputFactory)
  18. from aphrodite.common.pooling_params import PoolingParams
  19. from aphrodite.common.sampling_params import SamplingParams
  20. from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
  21. ExecuteModelRequest, PoolerOutput,
  22. SamplerOutput, Sequence, SequenceGroup,
  23. SequenceGroupMetadata, SequenceStatus)
  24. from aphrodite.common.utils import Counter, Device
  25. from aphrodite.engine.args_tools import EngineArgs
  26. from aphrodite.engine.metrics_types import StatLoggerBase, Stats
  27. from aphrodite.engine.output_processor.interfaces import (
  28. SequenceGroupOutputProcessor)
  29. from aphrodite.engine.output_processor.stop_checker import StopChecker
  30. from aphrodite.engine.output_processor.util import (
  31. create_output_by_sequence_group)
  32. from aphrodite.executor.executor_base import ExecutorBase
  33. from aphrodite.executor.ray_utils import initialize_ray_cluster
  34. from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
  35. InputRegistry, LLMInputs, PromptInputs,
  36. SingletonPromptInputs)
  37. from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
  38. from aphrodite.lora.request import LoRARequest
  39. from aphrodite.multimodal import MultiModalDataDict
  40. from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
  41. SchedulerOutputs)
  42. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  43. from aphrodite.transformers_utils.config import try_get_generation_config
  44. from aphrodite.transformers_utils.detokenizer import Detokenizer
  45. from aphrodite.transformers_utils.tokenizer_group import (
  46. BaseTokenizerGroup, init_tokenizer_from_configs)
  47. from aphrodite.version import __version__ as APHRODITE_VERSION
  48. _LOCAL_LOGGING_INTERVAL_SEC = 5
  49. APHRODITE_USE_RAY_SPMD_WORKER = bool(
  50. os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))
  51. def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
  52. config = try_get_generation_config(
  53. model_config.model,
  54. trust_remote_code=model_config.trust_remote_code,
  55. revision=model_config.revision,
  56. )
  57. if config is None:
  58. return {}
  59. return config.to_diff_dict()
  60. _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
  61. PromptComponents = Tuple[Optional[str], List[int],
  62. Optional[MultiModalDataDict],
  63. Optional[None], Optional[None]]
  64. DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
  65. Optional[MultiModalDataDict],
  66. Optional[None], Optional[None]]
  67. class AphroditeEngine:
  68. """An LLM engine that receives requests and generates texts.
  69. This is the main class for the Aphrodite engine. It receives requests
  70. from clients and generates texts from the LLM. It includes a tokenizer, a
  71. language model (possibly distributed across multiple GPUs), and GPU memory
  72. space allocated for intermediate states (aka KV cache). This class utilizes
  73. iteration-level scheduling and efficient memory management to maximize the
  74. serving throughput.
  75. The `LLM` class wraps this class for offline batched inference and the
  76. `AsyncAphrodite` class wraps this class for online serving.
  77. NOTE: The config arguments are derived from the `EngineArgs` class. For the
  78. comprehensive list of arguments, see `EngineArgs`.
  79. Args:
  80. model_config: The configuration related to the LLM model.
  81. cache_config: The configuration related to the KV cache memory
  82. management.
  83. parallel_config: The configuration related to distributed execution.
  84. scheduler_config: The configuration related to the request scheduler.
  85. device_config: The configuration related to the device.
  86. lora_config (Optional): The configuration related to serving multi-LoRA.
  87. speculative_config (Optional): The configuration related to speculative
  88. decoding.
  89. executor_class: The model executor class for managing distributed
  90. execution.
  91. prompt_adapter_config (Optional): The configuration related to serving
  92. prompt adapters.
  93. log_stats: Whether to log statistics.
  94. """
  95. DO_VALIDATE_OUTPUT: ClassVar[bool] = False
  96. """A flag to toggle whether to validate the type of request output."""
  97. @classmethod
  98. @contextmanager
  99. def enable_output_validation(cls):
  100. cls.DO_VALIDATE_OUTPUT = True
  101. yield
  102. cls.DO_VALIDATE_OUTPUT = False
  103. @classmethod
  104. def validate_output(
  105. cls,
  106. output: object,
  107. output_type: Type[_O],
  108. ) -> _O:
  109. do_validate = cls.DO_VALIDATE_OUTPUT
  110. if ((TYPE_CHECKING or do_validate)
  111. and not isinstance(output, output_type)):
  112. raise TypeError(f"Expected output of type {output_type}, "
  113. f"but found type {type(output)}")
  114. return output
  115. @classmethod
  116. def validate_outputs(
  117. cls,
  118. outputs: GenericSequence[object],
  119. output_type: Type[_O],
  120. ) -> List[_O]:
  121. do_validate = cls.DO_VALIDATE_OUTPUT
  122. outputs_: List[_O]
  123. if TYPE_CHECKING or do_validate:
  124. outputs_ = []
  125. for output in outputs:
  126. if not isinstance(output, output_type):
  127. raise TypeError(f"Expected output of type {output_type}, "
  128. f"but found type {type(output)}")
  129. outputs_.append(output)
  130. else:
  131. outputs_ = outputs
  132. return outputs_
    # Tokenizer group used by the engine; set to None in __init__ when
    # model_config.skip_tokenizer_init is true.
    tokenizer: Optional[BaseTokenizerGroup]
    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        prompt_adapter_config: Optional[PromptAdapterConfig],
        cfg_config: Optional[CFGConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
    ) -> None:
        """Build every engine component from the given configs.

        Steps, in order:
          1. Log a config banner.
          2. Load general plugins and store the configs.
          3. Create the tokenizer and detokenizer, unless
             ``model_config.skip_tokenizer_init`` is set.
          4. Instantiate ``executor_class``.
          5. Profile and allocate the KV cache (skipped in embedding mode).
          6. Create one scheduler per pipeline-parallel stage.
          7. Set up stat loggers when ``log_stats`` is true.
          8. Create the sequence output processor.

        Args:
            decoding_config: Falls back to ``DecodingConfig()`` when None.
            executor_class: Executor type instantiated with the configs.
            log_stats: Whether to create/use stat loggers.
            stat_loggers: Loggers to use instead of the default "logging" and
                "prometheus" loggers. Only consulted when ``log_stats`` is
                true.
            input_registry: Registry used to build the input processor for
                ``model_config``.
            The remaining config arguments are described in the class
            docstring.
        """
        # Probe for the optional module aphrodite.commit_id; its presence
        # only selects which startup banner is logged below.
        try:
            import aphrodite.commit_id
            commit_id = True
        except ImportError:
            commit_id = False
        # Settings shown in the startup banner.
        config_dict = {
            "Model": model_config.model,
            "Speculative Config": speculative_config,
            "CFG Config": cfg_config,
            "DataType": model_config.dtype,
            "Model Load Format": load_config.load_format,
            "Tensor Parallel Size": parallel_config.tensor_parallel_size,
            "Pipeline Parallel Size": parallel_config.pipeline_parallel_size,
            "Disable Custom All-Reduce":
            parallel_config.disable_custom_all_reduce,
            "Quantization Format": model_config.quantization,
            "Context Length": model_config.max_model_len,
            "Enforce Eager Mode": model_config.enforce_eager,
            "Prefix Caching": cache_config.enable_prefix_caching,
            "KV Cache DataType": cache_config.cache_dtype,
            "Device": device_config.device,
            "Rope Scaling": model_config.rope_scaling,
            "Guided Decoding Backend": decoding_config
        }
        logger.info("-" * 85)
        if not commit_id:
            logger.info(
                f"Initializing Aphrodite Engine (v{APHRODITE_VERSION}) "
                "with the following config:")
        else:
            # `aphrodite` is bound locally by `import aphrodite.commit_id`.
            # NOTE(review): assumes the package exposes __short_commit__
            # whenever commit_id is importable; confirm.
            logger.info(f"Initializing Aphrodite Engine (v{APHRODITE_VERSION} "
                        f"commit {aphrodite.__short_commit__}) with the "
                        "following config:")
        # Skip unset (None) entries, and the load format / KV cache dtype
        # when they are just "auto".
        for key, value in config_dict.items():
            if value is not None and not ((key == "Model Load Format" or key ==\
                "KV Cache DataType") and value == \
                    "auto"):
                logger.info(f"{key} = {value!r}")
        logger.info("-" * 85)
        # TODO: Print more configs in debug mode.
        from aphrodite.plugins import load_general_plugins
        load_general_plugins()
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.prompt_adapter_config = prompt_adapter_config
        self.cfg_config = cfg_config
        self.log_stats = log_stats
        # The tokenizer group is also kept in a plain local so that the
        # closure below can reach it without capturing `self`.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None
        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(sequence: Sequence) -> PreTrainedTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)
        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)
        self.model_executor = executor_class(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            prompt_adapter_config=prompt_adapter_config,
            cfg_config=cfg_config,
        )
        # KV cache sizing/allocation is skipped in embedding mode.
        if not self.model_config.embedding_mode:
            self._initialize_kv_caches()
        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()
        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        # One scheduler is created per pipeline-parallel stage.
        self.scheduler = [
            Scheduler(scheduler_config, cache_config, lora_config,
                      parallel_config.pipeline_parallel_size)
            for _ in range(parallel_config.pipeline_parallel_size)
        ]
        # Metric Logging.
        # NOTE: self.stat_loggers is only assigned when log_stats is true.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from aphrodite.engine.metrics import (LoggingStatLogger,
                                                      PrometheusStatLogger)
                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(model_name=model_config.served_model_name),
                        max_model_len=self.model_config.max_model_len),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)
        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))
  289. def _initialize_kv_caches(self) -> None:
  290. """Initialize the KV cache in the worker(s).
  291. The workers will determine the number of blocks in both the GPU cache
  292. and the swap CPU cache.
  293. """
  294. num_gpu_blocks, num_cpu_blocks = (
  295. self.model_executor.determine_num_available_blocks())
  296. if self.cache_config.num_gpu_blocks_override is not None:
  297. num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
  298. logger.info(f"Overriding {num_gpu_blocks=} with "
  299. f"{num_gpu_blocks_override=}")
  300. num_gpu_blocks = num_gpu_blocks_override
  301. self.cache_config.num_gpu_blocks = num_gpu_blocks
  302. self.cache_config.num_cpu_blocks = num_cpu_blocks
  303. self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
  304. @classmethod
  305. def _get_executor_cls(cls,
  306. engine_config: EngineConfig) -> Type[ExecutorBase]:
  307. distributed_executor_backend = (
  308. engine_config.parallel_config.distributed_executor_backend)
  309. # Initialize the cluster and specify the executor class.
  310. if isinstance(distributed_executor_backend, type):
  311. if not issubclass(distributed_executor_backend, ExecutorBase):
  312. raise TypeError(
  313. "distributed_executor_backend must be a subclass of "
  314. f"ExecutorBase. Got {distributed_executor_backend}.")
  315. if distributed_executor_backend.uses_ray: # type: ignore
  316. initialize_ray_cluster(engine_config.parallel_config)
  317. executor_class = distributed_executor_backend
  318. elif engine_config.device_config.device_type == "neuron":
  319. from aphrodite.executor.neuron_executor import NeuronExecutor
  320. executor_class = NeuronExecutor
  321. elif engine_config.device_config.device_type == "tpu":
  322. if distributed_executor_backend == "ray":
  323. initialize_ray_cluster(engine_config.parallel_config)
  324. from aphrodite.executor.ray_tpu_executor import RayTPUExecutor
  325. executor_class = RayTPUExecutor
  326. else:
  327. assert distributed_executor_backend is None
  328. from aphrodite.executor.tpu_executor import TPUExecutor
  329. executor_class = TPUExecutor
  330. elif engine_config.device_config.device_type == "cpu":
  331. from aphrodite.executor.cpu_executor import CPUExecutor
  332. executor_class = CPUExecutor
  333. elif engine_config.device_config.device_type == "openvino":
  334. from aphrodite.executor.openvino_executor import OpenVINOExecutor
  335. executor_class = OpenVINOExecutor
  336. elif engine_config.device_config.device_type == "xpu":
  337. if distributed_executor_backend == "ray":
  338. initialize_ray_cluster(engine_config.parallel_config)
  339. from aphrodite.executor.ray_xpu_executor import RayXPUExecutor
  340. executor_class = RayXPUExecutor
  341. else:
  342. from aphrodite.executor.xpu_executor import XPUExecutor
  343. executor_class = XPUExecutor
  344. elif distributed_executor_backend == "ray":
  345. initialize_ray_cluster(engine_config.parallel_config)
  346. from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
  347. executor_class = RayGPUExecutor
  348. elif distributed_executor_backend == "mp":
  349. from aphrodite.executor.multiproc_gpu_executor import (
  350. MultiprocessingGPUExecutor)
  351. assert not APHRODITE_USE_RAY_SPMD_WORKER, (
  352. "multiprocessing distributed executor backend does not "
  353. "support APHRODITE_USE_RAY_SPMD_WORKER=1")
  354. executor_class = MultiprocessingGPUExecutor
  355. else:
  356. from aphrodite.executor.gpu_executor import GPUExecutor
  357. executor_class = GPUExecutor
  358. return executor_class
  359. @classmethod
  360. def from_engine_args(
  361. cls,
  362. engine_args: EngineArgs,
  363. stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
  364. ) -> "AphroditeEngine":
  365. """Creates an LLM engine from the engine arguments."""
  366. # Create the engine configs.
  367. engine_config = engine_args.create_engine_config()
  368. executor_class = cls._get_executor_cls(engine_config)
  369. # Create the LLM engine.
  370. engine = cls(
  371. **engine_config.to_dict(),
  372. executor_class=executor_class,
  373. log_stats=not engine_args.disable_log_stats,
  374. stat_loggers=stat_loggers,
  375. )
  376. return engine
  377. def __reduce__(self):
  378. # This is to ensure that the AphroditeEngine is not referenced in
  379. # the closure used to initialize Ray worker actors
  380. raise RuntimeError("AphroditeEngine should not be pickled!")
  381. def __del__(self):
  382. # Shutdown the model executor when engine is garbage collected.
  383. # Use getattr since __init__ can fail before the field is set
  384. if model_executor := getattr(self, "model_executor", None):
  385. model_executor.shutdown()
  386. MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
  387. "skip_tokenizer_init is True")
  388. def get_tokenizer_group(
  389. self,
  390. fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
  391. if self.tokenizer is None:
  392. raise ValueError(fail_msg)
  393. return self.tokenizer
  394. def get_tokenizer(
  395. self,
  396. lora_request: Optional[LoRARequest] = None
  397. ) -> "PreTrainedTokenizer":
  398. return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
  399. def _init_tokenizer(self) -> BaseTokenizerGroup:
  400. return init_tokenizer_from_configs(
  401. model_config=self.model_config,
  402. scheduler_config=self.scheduler_config,
  403. parallel_config=self.parallel_config,
  404. enable_lora=bool(self.lora_config))
  405. def _verify_args(self) -> None:
  406. self.model_config.verify_with_parallel_config(self.parallel_config)
  407. self.cache_config.verify_with_parallel_config(self.parallel_config)
  408. if self.lora_config:
  409. self.lora_config.verify_with_model_config(self.model_config)
  410. self.lora_config.verify_with_scheduler_config(
  411. self.scheduler_config)
  412. if self.prompt_adapter_config:
  413. self.prompt_adapter_config.verify_with_model_config(
  414. self.model_config)
  415. def _get_bos_token_id(self,
  416. lora_request: Optional[LoRARequest] = None
  417. ) -> Optional[int]:
  418. if self.tokenizer is None:
  419. logger.warning("Using None for BOS token id because tokenizer "
  420. "is not initialized")
  421. return None
  422. return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
  423. def _get_eos_token_id(self,
  424. lora_request: Optional[LoRARequest] = None
  425. ) -> Optional[int]:
  426. if self.tokenizer is None:
  427. logger.warning("Using None for EOS token id because tokenizer "
  428. "is not initialized")
  429. return None
  430. return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
  431. def _get_decoder_start_token_id(self) -> Optional[int]:
  432. '''
  433. Obtain the decoder start token id employed by an encoder/decoder
  434. model. Returns None for non-encoder/decoder models or if the
  435. model config is unavailable.
  436. '''
  437. if not self.is_encoder_decoder_model():
  438. logger.warning("Using None for decoder start token id because "
  439. "this is not an encoder/decoder model.")
  440. return None
  441. if (self.model_config is None or self.model_config.hf_config is None):
  442. logger.warning("Using None for decoder start token id because "
  443. "model config is not available.")
  444. return None
  445. dec_start_token_id = getattr(self.model_config.hf_config,
  446. 'decoder_start_token_id', None)
  447. if dec_start_token_id is None:
  448. logger.warning("Falling back on <BOS> for decoder start token id "
  449. "because decoder start token id is not available.")
  450. dec_start_token_id = self._get_bos_token_id()
  451. return dec_start_token_id
  452. def _add_processed_request(
  453. self,
  454. request_id: str,
  455. processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
  456. params: Union[SamplingParams, PoolingParams],
  457. arrival_time: float,
  458. lora_request: Optional[LoRARequest],
  459. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  460. ) -> None:
  461. # Create the sequences.
  462. block_size = self.cache_config.block_size
  463. seq_id = next(self.seq_counter)
  464. eos_token_id = self._get_eos_token_id(lora_request)
  465. seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
  466. lora_request, prompt_adapter_request)
  467. negative_seq = None
  468. if 'negative_prompt_token_ids' in processed_inputs:
  469. negative_seq = Sequence(seq_id,
  470. processed_inputs,
  471. block_size,
  472. eos_token_id,
  473. lora_request,
  474. prompt_adapter_request,
  475. from_negative_prompt=True)
  476. encoder_seq = None
  477. if 'encoder_prompt_token_ids' in processed_inputs:
  478. encoder_seq = Sequence(seq_id,
  479. processed_inputs,
  480. block_size,
  481. eos_token_id,
  482. lora_request,
  483. prompt_adapter_request,
  484. from_decoder_prompt=False)
  485. # Create a SequenceGroup based on SamplingParams or PoolingParams
  486. if isinstance(params, SamplingParams):
  487. seq_group = self._create_sequence_group_with_sampling(
  488. request_id,
  489. seq,
  490. params,
  491. arrival_time=arrival_time,
  492. lora_request=lora_request,
  493. prompt_adapter_request=prompt_adapter_request,
  494. encoder_seq=encoder_seq,
  495. negative_seq=negative_seq,
  496. )
  497. elif isinstance(params, PoolingParams):
  498. seq_group = self._create_sequence_group_with_pooling(
  499. request_id,
  500. seq,
  501. params,
  502. arrival_time=arrival_time,
  503. lora_request=lora_request,
  504. prompt_adapter_request=prompt_adapter_request,
  505. encoder_seq=encoder_seq,
  506. )
  507. else:
  508. raise ValueError(
  509. "Either SamplingParams or PoolingParams must be provided.")
  510. # Add the sequence group to the scheduler with least unfinished seqs.
  511. costs = [
  512. scheduler.get_num_unfinished_seq_groups()
  513. for scheduler in self.scheduler
  514. ]
  515. min_cost_scheduler = self.scheduler[costs.index(min(costs))]
  516. min_cost_scheduler.add_seq_group(seq_group)
  517. def stop_remote_worker_execution_loop(self) -> None:
  518. self.model_executor.stop_remote_worker_execution_loop()
    # Alias for a (prompt text, prompt token ids) pair.
    # NOTE(review): not referenced in the visible part of this file; confirm
    # it is still used elsewhere before relying on or removing it.
    _LLMInputComponentsType = Tuple[str, List[int]]
  520. def _prepare_decoder_input_ids_for_generation(
  521. self,
  522. decoder_input_ids: Optional[List[int]],
  523. ) -> List[int]:
  524. """
  525. Prepares `decoder_input_ids` for generation with encoder-decoder models.
  526. Based on
  527. https://github.com/huggingface/transformers/blob/
  528. 4037a2b5b1278736e566aec12e169100275545ea/
  529. src/transformers/generation/utils.py
  530. specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
  531. Arguments:
  532. * decoder_input_ids: input token ids to preprocess
  533. Returns:
  534. * Processed token list
  535. """
  536. decoder_start_token_id = self._get_decoder_start_token_id()
  537. assert decoder_start_token_id is not None
  538. if decoder_input_ids is None:
  539. # no decoder prompt input ->
  540. # use decoder_start_token_id as decoder_input_ids
  541. decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
  542. if (len(decoder_input_ids) == 0
  543. or decoder_input_ids[0] != decoder_start_token_id):
  544. decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
  545. return decoder_input_ids
  546. def _tokenize_prompt(
  547. self,
  548. prompt: str,
  549. request_id: str,
  550. lora_request: Optional[LoRARequest],
  551. ) -> List[int]:
  552. '''
  553. Wrapper around application of the model's tokenizer.
  554. Arguments:
  555. * prompt
  556. * request_id
  557. * lora_request
  558. Returns:
  559. * prompt token ids
  560. '''
  561. tokenizer = self.get_tokenizer_group("prompts must be None if "
  562. "skip_tokenizer_init is True")
  563. return tokenizer.encode(request_id=request_id,
  564. prompt=prompt,
  565. lora_request=lora_request)
  566. def _extract_prompt_components(
  567. self,
  568. inputs: SingletonPromptInputs,
  569. request_id: str,
  570. lora_request: Optional[LoRARequest] = None,
  571. ) -> PromptComponents:
  572. '''
  573. Extract the components of any single encoder or decoder input prompt.
  574. Arguments:
  575. * request_id
  576. * inputs: single encoder or decoder input prompt
  577. * lora_request: this is only valid for decoder prompts
  578. Returns:
  579. * prompt
  580. * prompt_token_ids
  581. * multi_modal_data
  582. '''
  583. if isinstance(inputs, str):
  584. prompt = inputs
  585. prompt_token_ids = self._tokenize_prompt(
  586. prompt,
  587. request_id=request_id,
  588. lora_request=lora_request,
  589. )
  590. multi_modal_data = None
  591. negative_prompt = None
  592. negative_prompt_token_ids = None
  593. elif isinstance(inputs, dict):
  594. if "prompt_token_ids" in inputs:
  595. prompt = None
  596. prompt_token_ids = inputs["prompt_token_ids"]
  597. else:
  598. # NOTE: This extra assignment is required to pass mypy
  599. prompt = parsed_prompt = inputs["prompt"]
  600. prompt_token_ids = self._tokenize_prompt(
  601. parsed_prompt,
  602. request_id=request_id,
  603. lora_request=lora_request,
  604. )
  605. if "negative_prompt_token_ids" in inputs:
  606. negative_prompt = None
  607. negative_prompt_token_ids = inputs["negative_prompt_token_ids"]
  608. elif "negative_prompt" in inputs:
  609. negative_prompt = parsed_negative_prompt = inputs[
  610. "negative_prompt"]
  611. negative_prompt_token_ids = self._tokenize_prompt(
  612. parsed_negative_prompt,
  613. request_id=request_id,
  614. lora_request=lora_request,
  615. )
  616. else:
  617. negative_prompt = None
  618. negative_prompt_token_ids = None
  619. multi_modal_data = inputs.get("multi_modal_data")
  620. else:
  621. assert_never(inputs)
  622. return (prompt, prompt_token_ids, multi_modal_data,
  623. negative_prompt, negative_prompt_token_ids)
  624. def _apply_prompt_adapter(
  625. self,
  626. prompt_token_ids: List[int],
  627. prompt_adapter_request: Optional[PromptAdapterRequest],
  628. ) -> List[int]:
  629. if prompt_adapter_request:
  630. prompt_token_ids = (
  631. [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
  632. + prompt_token_ids)
  633. return prompt_token_ids
  634. def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
  635. '''
  636. Specifically for encoder/decoder models:
  637. generate a default decoder prompt for when
  638. the user specifies only the encoder prompt.
  639. Encoder/decoder models utilize the decoder
  640. prompt in different ways; as new models are
  641. added, it is intended that this function
  642. will be extended to produce differing
  643. default decoder prompts, depending on the
  644. model variety.
  645. Absent a special case, the default behavior
  646. of this method is to mirror the behavior of
  647. the HuggingFace (HF) GenerationMixin for a None
  648. decoder prompt, which is to employ a logit processor
  649. setting to force the first decoded token to be <BOS>.
  650. Here, this behavior is approximated by having the
  651. "default" decoder prompt be <BOS>.
  652. However, it is possible that in the future
  653. other models may have different or more
  654. complex logic for the default decoder prompt.
  655. This motivates having a special helper method
  656. for default decoder prompts.
  657. Returns:
  658. * prompt_token_ids
  659. '''
  660. bos_token_id = self._get_bos_token_id()
  661. assert bos_token_id is not None
  662. return [bos_token_id]
  663. def _build_enc_dec_llm_inputs(
  664. self,
  665. encoder_comps: PromptComponents,
  666. decoder_comps: DecoderPromptComponents,
  667. ) -> EncoderDecoderLLMInputs:
  668. encoder_prompt, encoder_prompt_ids, encoder_mm_data, \
  669. encoder_negative_prompt, encoder_negative_prompt_ids = encoder_comps
  670. decoder_prompt, decoder_prompt_ids, decoder_mm_data, \
  671. decoder_negative_prompt, decoder_negative_prompt_ids= decoder_comps
  672. if encoder_mm_data is not None or decoder_mm_data is not None:
  673. raise ValueError("Multi-modal encoder-decoder models are "
  674. "not supported yet")
  675. decoder_prompt_ids = (
  676. self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
  677. decoder_negative_prompt_ids = (
  678. self._prepare_decoder_input_ids_for_generation(decoder_negative_prompt_ids))
  679. return EncoderDecoderLLMInputs(
  680. prompt_token_ids=decoder_prompt_ids,
  681. prompt=decoder_prompt,
  682. negative_prompt_token_ids=decoder_negative_prompt_ids,
  683. negative_prompt=decoder_negative_prompt,
  684. encoder_prompt_token_ids=encoder_prompt_ids,
  685. encoder_prompt=encoder_prompt,
  686. encoder_negative_prompt_token_ids=encoder_negative_prompt_ids,
  687. encoder_negative_prompt=encoder_negative_prompt,
  688. )
  689. def _process_encoder_decoder_prompt(
  690. self,
  691. inputs: PromptInputs,
  692. request_id: str,
  693. ) -> EncoderDecoderLLMInputs:
  694. '''
  695. For encoder/decoder models only:
  696. Process an input prompt into an
  697. :class:`EncoderDecoderLLMInputs` instance.
  698. There are two types of input prompts:
  699. singleton prompts which carry only the
  700. encoder prompt, and explicit encoder/decoder
  701. prompts which carry both the encoder and the
  702. decoder prompts as member variables.
  703. This function handles the following scenarios:
  704. * Singleton encoder prompt: extract encoder prompt
  705. token ids & infer default decoder prompt token ids
  706. * Explicit encoder/decoder prompt: extract encoder
  707. and decoder prompt token ids
  708. Note that for Explicit encoder/decoder prompts,
  709. each sub-prompt (encoder or decoder prompt) can
  710. have any possible singleton type; thus this
  711. method relies on helper functions to obtain
  712. token ids for the sub-prompts.
  713. Arguments:
  714. * inputs: an input prompt
  715. * request_id
  716. Returns:
  717. * :class:`EncoderDecoderLLMInputs` instance
  718. '''
  719. encoder_comps: PromptComponents
  720. decoder_comps: DecoderPromptComponents
  721. if is_explicit_encoder_decoder_prompt(inputs):
  722. encoder_comps = self._extract_prompt_components(
  723. inputs["encoder_prompt"],
  724. request_id=request_id,
  725. )
  726. if (decoder_input := inputs["decoder_prompt"]) is None:
  727. decoder_comps = None, None, None, None, None
  728. else:
  729. decoder_comps = self._extract_prompt_components(
  730. decoder_input,
  731. request_id=request_id,
  732. )
  733. else:
  734. encoder_comps = self._extract_prompt_components(
  735. inputs,
  736. request_id=request_id,
  737. )
  738. decoder_comps = None, None, None, None, None
  739. return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
  740. def _build_decoder_only_llm_inputs(
  741. self,
  742. prompt_comps: PromptComponents,
  743. prompt_adapter_request: Optional[PromptAdapterRequest],
  744. ) -> LLMInputs:
  745. prompt, prompt_token_ids, multi_modal_data, \
  746. negative_prompt, negative_prompt_token_ids = prompt_comps
  747. prompt_token_ids = self._apply_prompt_adapter(
  748. prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
  749. return LLMInputs(prompt_token_ids=prompt_token_ids,
  750. prompt=prompt,
  751. multi_modal_data=multi_modal_data,
  752. negative_prompt_token_ids=negative_prompt_token_ids,
  753. negative_prompt=negative_prompt)
  754. def _process_decoder_only_prompt(
  755. self,
  756. inputs: SingletonPromptInputs,
  757. request_id: str,
  758. lora_request: Optional[LoRARequest] = None,
  759. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  760. ) -> LLMInputs:
  761. '''
  762. For decoder-only models:
  763. Process an input prompt into an :class:`LLMInputs` instance.
  764. Arguments:
  765. * inputs: input prompt
  766. * request_id
  767. * lora_request
  768. * prompt_adapter_request
  769. Returns:
  770. * :class:`LLMInputs` instance
  771. '''
  772. prompt_comps = self._extract_prompt_components(
  773. inputs,
  774. request_id=request_id,
  775. lora_request=lora_request,
  776. )
  777. return self._build_decoder_only_llm_inputs(
  778. prompt_comps,
  779. prompt_adapter_request=prompt_adapter_request,
  780. )
  781. def process_model_inputs(
  782. self,
  783. inputs: PromptInputs,
  784. request_id: str,
  785. lora_request: Optional[LoRARequest] = None,
  786. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  787. ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
  788. if self.is_encoder_decoder_model():
  789. # Encoder-decoder model requires special mapping of
  790. # input prompts to encoder & decoder
  791. model_inputs = self._process_encoder_decoder_prompt(
  792. inputs,
  793. request_id=request_id,
  794. )
  795. else:
  796. if is_explicit_encoder_decoder_prompt(inputs):
  797. raise ValueError("Cannot pass encoder-decoder prompt "
  798. "to decoder-only models")
  799. # Decoder-only operation
  800. model_inputs = self._process_decoder_only_prompt(
  801. inputs,
  802. request_id=request_id,
  803. lora_request=lora_request,
  804. prompt_adapter_request=prompt_adapter_request,
  805. )
  806. return self.input_processor(model_inputs)
  807. def add_request(
  808. self,
  809. request_id: str,
  810. inputs: PromptInputs,
  811. params: Union[SamplingParams, PoolingParams],
  812. arrival_time: Optional[float] = None,
  813. lora_request: Optional[LoRARequest] = None,
  814. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  815. ) -> None:
  816. """Add a request to the engine's request pool.
  817. The request is added to the request pool and will be processed by the
  818. scheduler as `engine.step()` is called. The exact scheduling policy is
  819. determined by the scheduler.
  820. Args:
  821. request_id: The unique ID of the request.
  822. prompt: The prompt string. Can be None if prompt_token_ids is
  823. provided.
  824. params: Parameters for sampling or pooling. SamplingParams
  825. for text generation. PoolingParams for pooling.
  826. prompt_token_ids: The token IDs of the prompt. If None, we
  827. use the tokenizer to convert the prompts to token IDs.
  828. arrival_time: The arrival time of the request. If None, we use
  829. the current monotonic time.
  830. multi_modal_data: Multi modal data per request.
  831. Details:
  832. - Set arrival_time to the current time if it is None.
  833. - Set prompt_token_ids to the encoded prompt if it is None.
  834. - Create `best_of` number of :class:`~aphrodite.common.sequence`
  835. objects.
  836. - Create a :class:`~aphrodite.common.sequenceGroup` object
  837. from the list of :class:`~aphrodite.common.sequence`.
  838. - Add the :class:`~aphrodite.common.sequenceGroup` object to the
  839. scheduler.
  840. Example:
  841. >>> # initialize engine
  842. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  843. >>> # set request arguments
  844. >>> example_prompt = "Who is the president of the United States?"
  845. >>> sampling_params = SamplingParams(temperature=0.0)
  846. >>> request_id = 0
  847. >>>
  848. >>> # add the request to the engine
  849. >>> engine.add_request(
  850. >>> str(request_id),
  851. >>> example_prompt,
  852. >>> SamplingParams(temperature=0.0))
  853. >>> # continue the request processing
  854. >>> ...
  855. """
  856. if lora_request is not None and not self.lora_config:
  857. raise ValueError(f"Got lora_request {lora_request} but LoRA is "
  858. "not enabled!")
  859. if arrival_time is None:
  860. arrival_time = time.time()
  861. processed_inputs = self.process_model_inputs(
  862. inputs,
  863. request_id=request_id,
  864. lora_request=lora_request,
  865. prompt_adapter_request=prompt_adapter_request,
  866. )
  867. self._add_processed_request(
  868. request_id=request_id,
  869. processed_inputs=processed_inputs,
  870. params=params,
  871. arrival_time=arrival_time,
  872. lora_request=lora_request,
  873. prompt_adapter_request=prompt_adapter_request,
  874. )
  875. def _create_sequence_group_with_sampling(
  876. self,
  877. request_id: str,
  878. seq: Sequence,
  879. sampling_params: SamplingParams,
  880. arrival_time: float,
  881. lora_request: Optional[LoRARequest],
  882. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  883. encoder_seq: Optional[Sequence] = None,
  884. negative_seq: Optional[Sequence] = None,
  885. ) -> SequenceGroup:
  886. """Creates a SequenceGroup with SamplingParams."""
  887. max_logprobs = self.get_model_config().max_logprobs
  888. if (sampling_params.logprobs
  889. and sampling_params.logprobs > max_logprobs) or (
  890. sampling_params.prompt_logprobs
  891. and sampling_params.prompt_logprobs > max_logprobs):
  892. raise ValueError(f"Cannot request more than "
  893. f"{max_logprobs} logprobs.")
  894. # Defensive copy of SamplingParams, which are used by the sampler,
  895. # this doesn't deep-copy LogitsProcessor objects
  896. sampling_params = sampling_params.clone()
  897. sampling_params.update_from_generation_config(
  898. self.generation_config_fields, seq.eos_token_id)
  899. # Create the sequence group.
  900. seq_group = SequenceGroup(
  901. request_id=request_id,
  902. seqs=[seq],
  903. arrival_time=arrival_time,
  904. sampling_params=sampling_params,
  905. lora_request=lora_request,
  906. prompt_adapter_request=prompt_adapter_request,
  907. encoder_seq=encoder_seq,
  908. negative_seq=negative_seq)
  909. return seq_group
  910. def _create_sequence_group_with_pooling(
  911. self,
  912. request_id: str,
  913. seq: Sequence,
  914. pooling_params: PoolingParams,
  915. arrival_time: float,
  916. lora_request: Optional[LoRARequest],
  917. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  918. encoder_seq: Optional[Sequence] = None,
  919. negative_seq: Optional[Sequence] = None,
  920. ) -> SequenceGroup:
  921. """Creates a SequenceGroup with PoolingParams."""
  922. # Defensive copy of PoolingParams, which are used by the pooler
  923. pooling_params = pooling_params.clone()
  924. # Create the sequence group.
  925. seq_group = SequenceGroup(
  926. request_id=request_id,
  927. seqs=[seq],
  928. arrival_time=arrival_time,
  929. lora_request=lora_request,
  930. pooling_params=pooling_params,
  931. prompt_adapter_request=prompt_adapter_request,
  932. encoder_seq=encoder_seq,
  933. negative_seq=negative_seq)
  934. return seq_group
  935. def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
  936. """Aborts a request(s) with the given ID.
  937. Args:
  938. request_id: The ID(s) of the request to abort.
  939. Details:
  940. - Refer to the
  941. :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
  942. from class :class:`~aphrodite.processing.scheduler.Scheduler`.
  943. Example:
  944. >>> # initialize engine and add a request with request_id
  945. >>> request_id = str(0)
  946. >>> # abort the request
  947. >>> engine.abort_request(request_id)
  948. """
  949. for scheduler in self.scheduler:
  950. scheduler.abort_seq_group(request_id)
  951. def get_model_config(self) -> ModelConfig:
  952. """Gets the model configuration."""
  953. return self.model_config
  954. def get_parallel_config(self) -> ParallelConfig:
  955. """Gets the parallel configuration."""
  956. return self.parallel_config
  957. def get_decoding_config(self) -> DecodingConfig:
  958. """Gets the decoding configuration."""
  959. return self.decoding_config
  960. def get_scheduler_config(self) -> SchedulerConfig:
  961. """Gets the scheduler configuration."""
  962. return self.scheduler_config
  963. def get_lora_config(self) -> LoRAConfig:
  964. """Gets the LoRA configuration."""
  965. return self.lora_config
  966. def get_num_unfinished_requests(self) -> int:
  967. """Gets the number of unfinished requests."""
  968. return sum(scheduler.get_num_unfinished_seq_groups()
  969. for scheduler in self.scheduler)
  970. def has_unfinished_requests(self) -> bool:
  971. """Returns True if there are unfinished requests."""
  972. return any(scheduler.has_unfinished_seqs()
  973. for scheduler in self.scheduler)
  974. def has_unfinished_requests_for_virtual_engine(
  975. self, virtual_engine: int) -> bool:
  976. """
  977. Returns True if there are unfinished requests for the virtual engine.
  978. """
  979. return self.scheduler[virtual_engine].has_unfinished_seqs()
  980. def _process_sequence_group_outputs(
  981. self,
  982. seq_group: SequenceGroup,
  983. outputs: List[EmbeddingSequenceGroupOutput],
  984. ) -> None:
  985. seq_group.embeddings = outputs[0].embeddings
  986. for seq in seq_group.get_seqs():
  987. seq.status = SequenceStatus.FINISHED_STOPPED
  988. return
  989. def _process_model_outputs(
  990. self,
  991. output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
  992. scheduled_seq_groups: List[ScheduledSequenceGroup],
  993. ignored_seq_groups: List[SequenceGroup],
  994. seq_group_metadata_list: List[SequenceGroupMetadata],
  995. ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
  996. """Apply the model output to the sequences in the scheduled seq groups.
  997. Returns RequestOutputs that can be returned to the client.
  998. """
  999. now = time.time()
  1000. # Organize outputs by [sequence group][step] instead of
  1001. # [step][sequence group].
  1002. output_by_sequence_group = create_output_by_sequence_group(
  1003. output, num_seq_groups=len(scheduled_seq_groups))
  1004. # Update the scheduled sequence groups with the model outputs.
  1005. for scheduled_seq_group, outputs, seq_group_meta in zip(
  1006. scheduled_seq_groups, output_by_sequence_group,
  1007. seq_group_metadata_list):
  1008. seq_group = scheduled_seq_group.seq_group
  1009. seq_group.update_num_computed_tokens(
  1010. scheduled_seq_group.token_chunk_size)
  1011. if self.model_config.embedding_mode:
  1012. self._process_sequence_group_outputs(seq_group, outputs)
  1013. continue
  1014. self.output_processor.process_prompt_logprob(seq_group, outputs)
  1015. if seq_group_meta.do_sample:
  1016. self.output_processor.process_outputs(seq_group, outputs)
  1017. # Free the finished sequence groups.
  1018. for scheduler in self.scheduler:
  1019. scheduler.free_finished_seq_groups()
  1020. # Create the outputs.
  1021. request_outputs: List[Union[RequestOutput,
  1022. EmbeddingRequestOutput]] = []
  1023. for scheduled_seq_group in scheduled_seq_groups:
  1024. seq_group = scheduled_seq_group.seq_group
  1025. seq_group.maybe_set_first_token_time(now)
  1026. request_output = RequestOutputFactory.create(seq_group)
  1027. request_outputs.append(request_output)
  1028. for seq_group in ignored_seq_groups:
  1029. request_output = RequestOutputFactory.create(seq_group)
  1030. request_outputs.append(request_output)
  1031. return request_outputs
  1032. def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
  1033. """Performs one decoding iteration and returns newly generated results.
  1034. .. figure:: https://i.imgur.com/sv2HssD.png
  1035. :alt: Overview of the step function
  1036. :align: center
  1037. Overview of the step function.
  1038. Details:
  1039. - Step 1: Schedules the sequences to be executed in the next
  1040. iteration and the token blocks to be swapped in/out/copy.
  1041. - Depending on the scheduling policy,
  1042. sequences may be `preempted/reordered`.
  1043. - A Sequence Group (SG) refer to a group of sequences
  1044. that are generated from the same prompt.
  1045. - Step 2: Calls the distributed executor to execute the model.
  1046. - Step 3: Processes the model output. This mainly includes:
  1047. - Decodes the relevant outputs.
  1048. - Updates the scheduled sequence groups with model outputs
  1049. based on its `sampling parameters` (`use_beam_search` or not).
  1050. - Frees the finished sequence groups.
  1051. - Finally, it creates and returns the newly generated results.
  1052. Example:
  1053. >>> # Please see the example/ folder for more detailed examples.
  1054. >>>
  1055. >>> # initialize engine and request arguments
  1056. >>> engine = AphroditeEngine.from_engine_args(engine_args)
  1057. >>> example_inputs = [(0, "What is LLM?",
  1058. >>> SamplingParams(temperature=0.0))]
  1059. >>>
  1060. >>> # Start the engine with an event loop
  1061. >>> while True:
  1062. >>> if example_inputs:
  1063. >>> req_id, prompt, sampling_params = example_inputs.pop(0)
  1064. >>> engine.add_request(str(req_id), prompt, sampling_params)
  1065. >>>
  1066. >>> # continue the request processing
  1067. >>> request_outputs = engine.step()
  1068. >>> for request_output in request_outputs:
  1069. >>> if request_output.finished:
  1070. >>> # return or show the request output
  1071. >>>
  1072. >>> if not (engine.has_unfinished_requests() or example_inputs):
  1073. >>> break
  1074. """
  1075. if self.parallel_config.pipeline_parallel_size > 1:
  1076. raise NotImplementedError(
  1077. "Pipeline parallelism is only supported through AsyncAphrodite "
  1078. "as performance will be severely degraded otherwise.")
  1079. seq_group_metadata_list, scheduler_outputs = self.scheduler[
  1080. 0].schedule()
  1081. if not scheduler_outputs.is_empty():
  1082. finished_requests_ids = self.scheduler[
  1083. 0].get_and_reset_finished_requests_ids()
  1084. execute_model_req = ExecuteModelRequest(
  1085. seq_group_metadata_list=seq_group_metadata_list,
  1086. blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
  1087. blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
  1088. blocks_to_copy=scheduler_outputs.blocks_to_copy,
  1089. num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
  1090. running_queue_size=scheduler_outputs.running_queue_size,
  1091. finished_requests_ids=finished_requests_ids,
  1092. )
  1093. output = self.model_executor.execute_model(
  1094. execute_model_req=execute_model_req)
  1095. else:
  1096. output = []
  1097. request_outputs = self._process_model_outputs(
  1098. output, scheduler_outputs.scheduled_seq_groups,
  1099. scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
  1100. # Log stats.
  1101. self.do_log_stats(scheduler_outputs, output)
  1102. if not self.has_unfinished_requests():
  1103. # Stop the execute model loop in parallel workers until there are
  1104. # more requests to process. This avoids waiting indefinitely in
  1105. # torch.distributed ops which may otherwise timeout, and unblocks
  1106. # the RPC thread in the workers so that they can process any other
  1107. # queued control plane messages, such as add/remove lora adapters.
  1108. self.model_executor.stop_remote_worker_execution_loop()
  1109. return request_outputs
  1110. def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
  1111. if logger_name in self.stat_loggers:
  1112. raise KeyError(f"Logger with name {logger_name} already exists.")
  1113. self.stat_loggers[logger_name] = logger
  1114. def remove_logger(self, logger_name: str) -> None:
  1115. if logger_name not in self.stat_loggers:
  1116. raise KeyError(f"Logger with name {logger_name} does not exist.")
  1117. del self.stat_loggers[logger_name]
  1118. def do_log_stats(
  1119. self,
  1120. scheduler_outputs: Optional[SchedulerOutputs] = None,
  1121. model_output: Optional[List[SamplerOutput]] = None) -> None:
  1122. """Forced log when no requests active."""
  1123. if self.log_stats:
  1124. stats = self._get_stats(scheduler_outputs, model_output)
  1125. for loggers in self.stat_loggers.values():
  1126. loggers.log(stats)
  1127. def _get_stats(
  1128. self,
  1129. scheduler_outputs: Optional[SchedulerOutputs],
  1130. model_output: Optional[List[SamplerOutput]] = None) -> Stats:
  1131. """Get Stats to be Logged to Prometheus.
  1132. Args:
  1133. scheduler_outputs: Optional, used to populate metrics related to
  1134. the scheduled batch,
  1135. model_output: Optional, used to emit speculative decoding metrics
  1136. which are created by the workers.
  1137. """
  1138. now = time.time()
  1139. # System State
  1140. # Scheduler State
  1141. num_running_sys = sum(
  1142. len(scheduler.running) for scheduler in self.scheduler)
  1143. num_swapped_sys = sum(
  1144. len(scheduler.swapped) for scheduler in self.scheduler)
  1145. num_waiting_sys = sum(
  1146. len(scheduler.waiting) for scheduler in self.scheduler)
  1147. # KV Cache Usage in %
  1148. num_total_gpu = self.cache_config.num_gpu_blocks
  1149. gpu_cache_usage_sys = 0.
  1150. if num_total_gpu is not None:
  1151. num_free_gpu = sum(
  1152. scheduler.block_manager.get_num_free_gpu_blocks()
  1153. for scheduler in self.scheduler)
  1154. if not self.model_config.is_attention_free():
  1155. gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
  1156. else:
  1157. gpu_cache_usage_sys = 0.0
  1158. num_total_cpu = self.cache_config.num_cpu_blocks
  1159. cpu_cache_usage_sys = 0.
  1160. if num_total_cpu is not None and num_total_cpu > 0:
  1161. num_free_cpu = sum(
  1162. scheduler.block_manager.get_num_free_cpu_blocks()
  1163. for scheduler in self.scheduler)
  1164. if not self.model_config.is_attention_free():
  1165. cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
  1166. else:
  1167. cpu_cache_usage_sys = 0.0
  1168. # Prefix Cache Hit Rate. Note that we always use
  1169. # the cache hit rate of the first virtual engine.
  1170. cpu_prefix_cache_hit_rate = self.scheduler[
  1171. 0].get_prefix_cache_hit_rate(Device.CPU)
  1172. gpu_prefix_cache_hit_rate = self.scheduler[
  1173. 0].get_prefix_cache_hit_rate(Device.GPU)
  1174. # Iteration stats
  1175. num_prompt_tokens_iter = 0
  1176. num_generation_tokens_iter = 0
  1177. time_to_first_tokens_iter: List[float] = []
  1178. time_per_output_tokens_iter: List[float] = []
  1179. num_preemption_iter = (0 if scheduler_outputs is None else
  1180. scheduler_outputs.preempted)
  1181. # Request stats
  1182. # Latency
  1183. time_e2e_requests: List[float] = []
  1184. # Metadata
  1185. num_prompt_tokens_requests: List[int] = []
  1186. num_generation_tokens_requests: List[int] = []
  1187. best_of_requests: List[int] = []
  1188. n_requests: List[int] = []
  1189. finished_reason_requests: List[str] = []
  1190. # NOTE: This loop assumes prefill seq_groups are before
  1191. # decode seq_groups in scheduled_seq_groups.
  1192. if scheduler_outputs is not None:
  1193. num_generation_tokens_from_prefill_groups = 0.
  1194. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and
  1195. # the len of scheduler_outputs.scheduled_seq_groups is !=
  1196. # scheduler_outputs.num_prefill_groups, this means that
  1197. # chunked prefills have been detected.
  1198. for idx, scheduled_seq_group in enumerate(
  1199. scheduler_outputs.scheduled_seq_groups):
  1200. group_was_prefill = idx < scheduler_outputs.num_prefill_groups
  1201. seq_group = scheduled_seq_group.seq_group
  1202. # NOTE: a seq_group that completed all of its prefill tokens
  1203. # in the last iteration will have seq_group.is_prefill() = False
  1204. # with group_was_prefill = True
  1205. if group_was_prefill:
  1206. # Number of prompt tokens.
  1207. num_prompt_tokens_iter += (
  1208. scheduled_seq_group.token_chunk_size)
  1209. # If the seq_group just finished the prefill state
  1210. # get TTFT.
  1211. if not seq_group.is_prefill():
  1212. latency = seq_group.get_last_latency(now)
  1213. time_to_first_tokens_iter.append(latency)
  1214. # One generation token per finished prefill.
  1215. num_generation_tokens_from_prefill_groups += (
  1216. seq_group.num_seqs())
  1217. else:
  1218. # TPOTs.
  1219. latency = seq_group.get_last_latency(now)
  1220. time_per_output_tokens_iter.append(latency)
  1221. # Because of chunked prefill, we can have a single sequence
  1222. # group that does multiple prompt_runs. To prevent logging
  1223. # the same metadata more than once per request, we standardize
  1224. # on logging request level information for finished requests,
  1225. # which can only happen once.
  1226. if seq_group.is_finished():
  1227. # Latency timings
  1228. time_e2e_requests.append(now -
  1229. seq_group.metrics.arrival_time)
  1230. # Metadata
  1231. num_prompt_tokens_requests.append(
  1232. len(seq_group.prompt_token_ids))
  1233. num_generation_tokens_requests.extend([
  1234. seq.get_output_len()
  1235. for seq in seq_group.get_finished_seqs()
  1236. ])
  1237. if seq_group.sampling_params is not None:
  1238. best_of_requests.append(
  1239. seq_group.sampling_params.best_of)
  1240. n_requests.append(seq_group.sampling_params.n)
  1241. finished_reason_requests.extend([
  1242. SequenceStatus.get_finished_reason(seq.status)
  1243. for seq in seq_group.get_finished_seqs()
  1244. ])
  1245. # Number of generation tokens.
  1246. # num_batched_tokens equals the number of prompt_tokens plus the
  1247. # number of decode_tokens in a single iteration. So,
  1248. # num_generation_tokens = num_batched_tokens - num_prompt_tokens
  1249. # + num_generation_tokens_from_prefill_groups (since we generate
  1250. # one token on prefills on iters where the prefill finishes).
  1251. num_generation_tokens_iter = (
  1252. scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
  1253. num_generation_tokens_from_prefill_groups)
  1254. # Spec decode, if enabled, emits specialized metrics from the worker in
  1255. # sampler output.
  1256. if model_output and (model_output[0].spec_decode_worker_metrics
  1257. is not None):
  1258. spec_decode_metrics = model_output[0].spec_decode_worker_metrics
  1259. else:
  1260. spec_decode_metrics = None
  1261. return Stats(
  1262. now=now,
  1263. # System stats
  1264. # Scheduler State
  1265. num_running_sys=num_running_sys,
  1266. num_swapped_sys=num_swapped_sys,
  1267. num_waiting_sys=num_waiting_sys,
  1268. # KV Cache Usage in %
  1269. gpu_cache_usage_sys=gpu_cache_usage_sys,
  1270. cpu_cache_usage_sys=cpu_cache_usage_sys,
  1271. # Prefix Cache Hit Rate
  1272. cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
  1273. gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
  1274. # Iteration stats
  1275. num_prompt_tokens_iter=num_prompt_tokens_iter,
  1276. num_generation_tokens_iter=num_generation_tokens_iter,
  1277. time_to_first_tokens_iter=time_to_first_tokens_iter,
  1278. time_per_output_tokens_iter=time_per_output_tokens_iter,
  1279. spec_decode_metrics=spec_decode_metrics,
  1280. num_preemption_iter=num_preemption_iter,
  1281. # Request stats
  1282. # Latency
  1283. time_e2e_requests=time_e2e_requests,
  1284. # Metadata
  1285. num_prompt_tokens_requests=num_prompt_tokens_requests,
  1286. num_generation_tokens_requests=num_generation_tokens_requests,
  1287. best_of_requests=best_of_requests,
  1288. n_requests=n_requests,
  1289. finished_reason_requests=finished_reason_requests,
  1290. )
  1291. def add_lora(self, lora_request: LoRARequest) -> bool:
  1292. return self.model_executor.add_lora(lora_request)
  1293. def remove_lora(self, lora_id: int) -> bool:
  1294. return self.model_executor.remove_lora(lora_id)
  1295. def list_loras(self) -> List[int]:
  1296. return self.model_executor.list_loras()
  1297. def pin_lora(self, lora_id: int) -> bool:
  1298. return self.model_executor.pin_lora(lora_id)
  1299. def add_prompt_adapter(
  1300. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  1301. return self.model_executor.add_prompt_adapter(prompt_adapter_request)
  1302. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  1303. return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
  1304. def list_prompt_adapters(self) -> List[int]:
  1305. return self.model_executor.list_prompt_adapters()
  1306. def check_health(self) -> None:
  1307. if self.tokenizer:
  1308. self.tokenizer.check_health()
  1309. self.model_executor.check_health()
  1310. def is_encoder_decoder_model(self):
  1311. return self.model_config.is_encoder_decoder_model
  1312. def is_embedding_model(self):
  1313. return self.model_config.is_embedding_model
  1314. setup_logger()