
core: rename `PromptInputs,inputs` -> `PromptType,prompt` (#1080)

* core: rename `PromptInputs,inputs` -> `PromptType,prompt`

* format

* attempt fixing codespell

* codespell attempt 2
AlpinDale committed 2 months ago
commit 86bf2cc4f3
39 changed files with 850 additions and 329 deletions
  1. aphrodite/endpoints/llm.py (+42 -38)
  2. aphrodite/engine/aphrodite_engine.py (+6 -5)
  3. aphrodite/engine/async_aphrodite.py (+13 -14)
  4. aphrodite/engine/multiprocessing/__init__.py (+2 -2)
  5. aphrodite/engine/multiprocessing/client.py (+11 -11)
  6. aphrodite/engine/multiprocessing/engine.py (+1 -1)
  7. aphrodite/engine/protocol.py (+3 -3)
  8. aphrodite/inputs/__init__.py (+4 -3)
  9. aphrodite/inputs/data.py (+25 -22)
  10. aphrodite/inputs/parse.py (+12 -11)
  11. aphrodite/inputs/preprocess.py (+172 -128)
  12. pyproject.toml (+2 -2)
  13. tests/benchmarks/engine/latency.py (+9 -15)
  14. tests/models/decoder_only/audio_language/test_ultravox.py (+2 -1)
  15. tests/models/decoder_only/language/test_aqlm.py (+2 -1)
  16. tests/models/decoder_only/language/test_big_models.py (+6 -3)
  17. tests/models/decoder_only/language/test_danube3_4b.py (+2 -1)
  18. tests/models/decoder_only/language/test_granite.py (+2 -1)
  19. tests/models/decoder_only/language/test_jamba.py (+13 -5)
  20. tests/models/decoder_only/language/test_models.py (+2 -1)
  21. tests/models/decoder_only/vision_language/test_blip2.py (+6 -2)
  22. tests/models/decoder_only/vision_language/test_chameleon.py (+4 -4)
  23. tests/models/decoder_only/vision_language/test_fuyu.py (+4 -3)
  24. tests/models/decoder_only/vision_language/test_internvl.py (+3 -3)
  25. tests/models/decoder_only/vision_language/test_llava.py (+4 -4)
  26. tests/models/decoder_only/vision_language/test_llava_image_embeds.py (+2 -2)
  27. tests/models/decoder_only/vision_language/test_llava_next.py (+2 -2)
  28. tests/models/decoder_only/vision_language/test_llava_next_video.py (+5 -5)
  29. tests/models/decoder_only/vision_language/test_minicpmv.py (+2 -2)
  30. tests/models/decoder_only/vision_language/test_paligemma.py (+4 -4)
  31. tests/models/decoder_only/vision_language/test_phi3v.py (+4 -3)
  32. tests/models/decoder_only/vision_language/test_pixtral.py (+3 -2)
  33. tests/models/embedding/language/test_embedding.py (+3 -2)
  34. tests/models/encoder_decoder/language/test_bart.py (+28 -23)
  35. tests/mq_aphrodite_engine/__init__.py (+0 -0)
  36. tests/mq_aphrodite_engine/test_abort.py (+68 -0)
  37. tests/mq_aphrodite_engine/test_error_handling.py (+243 -0)
  38. tests/mq_aphrodite_engine/test_load.py (+58 -0)
  39. tests/mq_aphrodite_engine/utils.py (+76 -0)
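The user-visible effect of the rename is that the first positional argument of `LLM.generate()` and `LLM.encode()` is now called `prompts` and is typed as `PromptType`; the accepted shapes (plain string, `TextPrompt`, `TokensPrompt`, or a sequence of these) are unchanged. A minimal sketch of the renamed API, assuming the usual top-level exports; the model name and token ids are illustrative placeholders, not taken from this commit:

from aphrodite import LLM, SamplingParams
from aphrodite.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=16)

# Each of these is a valid PromptType; a sequence of them enables batching.
outputs = llm.generate("Hello, my name is", sampling_params=params)
outputs = llm.generate(TextPrompt(prompt="Hello, my name is"), sampling_params=params)
outputs = llm.generate(TokensPrompt(prompt_token_ids=[1, 15043]), sampling_params=params)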

+ 42 - 38
aphrodite/endpoints/llm.py

@@ -14,7 +14,7 @@ from aphrodite.endpoints.chat_utils import (ChatCompletionMessageParam,
                                             parse_chat_messages)
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.inputs import PromptInputs, TextPrompt, TokensPrompt
+from aphrodite.inputs import PromptType, TextPrompt, TokensPrompt
 from aphrodite.inputs.parse import parse_and_batch_prompt
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.guided_decoding import (
@@ -254,8 +254,8 @@ class LLM:
     @overload
     def generate(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
-        /,  # We may enable `inputs` keyword after removing the old API
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
         *,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -272,7 +272,7 @@ class LLM:
     )
     def generate(
         self,
-        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[Union[PromptType, Sequence[PromptType]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -290,7 +290,9 @@ class LLM:
         into a single list and pass it to this method.

         Args:
-            inputs: A list of inputs to generate completions for.
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~aphrodite.inputs.PromptType`
+                for more details about the format of each prompts.
             sampling_params: The sampling parameters for text generation. If
                 None, we use the default sampling parameters.
                 When it is a single value, it is applied to every prompt.
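A sketch of the batched form described in the new docstring, reusing the `llm` object from the earlier example; the prompts and parameter values are illustrative:

prompts = ["A robot may not injure a human being",
           "It is only with the heart that one can see rightly"]
sampling = [SamplingParams(temperature=0.0, max_tokens=8),
            SamplingParams(temperature=0.8, max_tokens=32)]

# One SamplingParams per prompt; a single value would apply to every prompt.
outputs = llm.generate(prompts, sampling_params=sampling)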
@@ -316,12 +318,13 @@
                 "models (XForCausalLM, XForConditionalGeneration).")

         if prompt_token_ids is not None:
-            inputs = self._convert_v1_inputs(
+            parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
+                                  prompts)

         if isinstance(guided_options_request, dict):
             if len(guided_options_request) > 1:
@@ -336,7 +339,7 @@ class LLM:
             sampling_params = SamplingParams()

         self._validate_and_add_requests(
-            inputs=inputs,
+            prompts=parsed_prompts,
             params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -392,9 +395,9 @@ class LLM:
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                     tokenizer)

-        prompt: Union[str, List[int]]
+        prompt_data: Union[str, List[int]]
         if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
+            prompt_data = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
                 chat_template=chat_template,
@@ -402,7 +405,7 @@ class LLM:
                 tools=tools,
             )
         else:
-            prompt = apply_hf_chat_template(
+            prompt_data = apply_hf_chat_template(
                 tokenizer,
                 conversation=conversation,
                 chat_template=chat_template,
@@ -410,17 +413,17 @@ class LLM:
                 tools=tools,
             )

-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
+        prompt: PromptType
+        if is_list_of(prompt_data, int):
+            prompt = TokensPrompt(prompt_token_ids=prompt_data)
         else:
-            inputs = TextPrompt(prompt=prompt)
+            prompt = TextPrompt(prompt=prompt_data)

         if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            prompt["multi_modal_data"] = mm_data

         return self.generate(
-            inputs,
+            prompt,
             sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
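The chat path above now collects the rendered template into a single `prompt` of type `PromptType` (with optional `multi_modal_data`) before delegating to `generate()`. A roughly equivalent hand-built version, reusing `llm` and `params` from the earlier sketch and assuming `image` is an image object loaded elsewhere:

from aphrodite.inputs import TextPrompt

prompt = TextPrompt(prompt="USER: What is in this image?\nASSISTANT:")
prompt["multi_modal_data"] = {"image": image}  # `image` is a placeholder
outputs = llm.generate(prompt, sampling_params=params)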
@@ -490,8 +493,8 @@ class LLM:
     @overload
     def encode(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
-        /,  # We may enable `inputs` keyword after removing the old API
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
         *,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -508,7 +511,7 @@ class LLM:
     )
     def encode(
         self,
-        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[Union[PromptType, Sequence[PromptType]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -524,9 +527,9 @@ class LLM:
         into a single list and pass it to this method.

         Args:
-            inputs: The inputs to the LLM. You may pass a sequence of inputs for
-                batch inference. See :class:`~aphrodite.inputs.PromptInputs`
-                for more details about the format of each input.
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See :class:`~aphrodite.inputs.PromptType`
+                for more details about the format of each prompts.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
@@ -549,19 +552,20 @@ class LLM:
             )

         if prompt_token_ids is not None:
-            inputs = self._convert_v1_inputs(
+            parsed_prompts = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
+                                  prompts)

         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()

         self._validate_and_add_requests(
-            inputs=inputs,
+            prompts=parsed_prompts,
             params=pooling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -599,9 +603,9 @@ class LLM:
             raise ValueError("Either prompts or prompt_token_ids must be "
                              "provided.")

-        inputs: List[PromptInputs] = []
+        parsed_prompts: List[PromptType] = []
         for i in range(num_requests):
-            item: PromptInputs
+            item: PromptType

             if prompts is not None:
                 item = TextPrompt(prompt=prompts[i])
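`_convert_v1_inputs` keeps the legacy keyword-based call style working by wrapping each entry in the new prompt types. A rough illustration, assuming the deprecated `prompt_token_ids` keyword is still accepted as shown above:

# Legacy call style: token-id lists instead of prompt objects.
outputs = llm.generate(prompt_token_ids=[[1, 2, 3], [4, 5, 6]],
                       sampling_params=params)
# Internally this becomes roughly:
#   [TokensPrompt(prompt_token_ids=[1, 2, 3]),
#    TokensPrompt(prompt_token_ids=[4, 5, 6])]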
@@ -610,24 +614,24 @@ class LLM:
             else:
                 raise AssertionError

-            inputs.append(item)
+            parsed_prompts.append(item)

-        return inputs
+        return parsed_prompts

     def _validate_and_add_requests(
         self,
-        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        prompts: Union[PromptType, Sequence[PromptType]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         guided_options: Optional[GuidedDecodingRequest] = None,
     ) -> None:
-        if isinstance(inputs, (str, dict)):
+        if isinstance(prompts, (str, dict)):
             # Convert a single prompt to a list.
-            inputs = [inputs]
+            prompts = [prompts]

-        num_requests = len(inputs)
+        num_requests = len(prompts)
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
@@ -644,9 +648,9 @@ class LLM:
                 sp.output_kind = RequestOutputKind.FINAL_ONLY

         # Add requests to the engine.
-        for i, request_inputs in enumerate(inputs):
+        for i, prompt in enumerate(prompts):
             self._add_request(
-                request_inputs,
+                prompt,
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
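`_validate_and_add_requests` wraps a single str/dict prompt in a list and then requires any per-request parameter list to match in length. A small sketch of the two accepted shapes (values illustrative, reusing `llm` from above):

llm.generate("one prompt", sampling_params=SamplingParams())
llm.generate(["p1", "p2"],
             sampling_params=[SamplingParams(), SamplingParams()])
# A params list of the wrong length raises:
#   ValueError: The lengths of prompts and params must be the same.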
@@ -655,7 +659,7 @@ class LLM:
 
 
     def _add_request(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -663,7 +667,7 @@ class LLM:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
-            inputs,
+            prompt,
             params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,

+ 6 - 5
aphrodite/engine/aphrodite_engine.py

@@ -38,7 +38,7 @@ from aphrodite.engine.output_processor.util import (
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.executor.ray_utils import initialize_ray_cluster
 from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-                              InputRegistry, LLMInputs, PromptInputs)
+                              InputRegistry, LLMInputs, PromptType)
 from aphrodite.inputs.preprocess import InputPreprocessor
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
@@ -613,7 +613,7 @@ class AphroditeEngine:
     def add_request(
         self,
         request_id: str,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -627,8 +627,9 @@ class AphroditeEngine:
 
 
         Args:
             request_id: The unique ID of the request.
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
+            prompt: The prompt to the LLM. See
+                :class:`~aphrodite.common.inputs.PromptType`
+                for more details about the format of each input.
             params: Parameters for sampling or pooling. SamplingParams
                 for text generation. PoolingParams for pooling.
             prompt_token_ids: The token IDs of the prompt. If None, we
@@ -669,7 +670,7 @@ class AphroditeEngine:
             arrival_time = time.time()

         preprocessed_inputs = self.input_preprocessor.preprocess(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
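At the engine level the same rename applies to `add_request`. A hedged sketch, assuming the vLLM-style `from_engine_args`, `has_unfinished_requests`, and `step` helpers that Aphrodite inherits; the model name and request values are placeholders:

from aphrodite import SamplingParams
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs

engine = AphroditeEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request(request_id="0",
                   prompt="The capital of France is",   # formerly `inputs=`
                   params=SamplingParams(max_tokens=4))
while engine.has_unfinished_requests():
    step_outputs = engine.step()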

+ 13 - 14
aphrodite/engine/async_aphrodite.py

@@ -24,7 +24,7 @@ from aphrodite.engine.async_timeout import asyncio_timeout
 from aphrodite.engine.metrics_types import StatLoggerBase
 from aphrodite.engine.metrics_types import StatLoggerBase
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.ray_utils import initialize_ray_cluster
 from aphrodite.executor.ray_utils import initialize_ray_cluster
-from aphrodite.inputs import PromptInputs
+from aphrodite.inputs import PromptType
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.processing.scheduler import SchedulerOutputs
@@ -397,7 +397,7 @@ class _AsyncAphrodite(AphroditeEngine):
     async def add_request_async(
     async def add_request_async(
         self,
         self,
         request_id: str,
         request_id: str,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -411,7 +411,7 @@ class _AsyncAphrodite(AphroditeEngine):
             arrival_time = time.time()
             arrival_time = time.time()
 
 
         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-            inputs,
+            prompt,
             request_id=request_id,
             request_id=request_id,
             lora_request=lora_request,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -766,7 +766,7 @@ class AsyncAphrodite:
     async def add_request(
     async def add_request(
         self,
         self,
         request_id: str,
         request_id: str,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -785,7 +785,7 @@ class AsyncAphrodite:
         stream = self._request_tracker.add_request(
         stream = self._request_tracker.add_request(
             request_id,
             request_id,
             verbose=self.log_requests,
             verbose=self.log_requests,
-            inputs=inputs,
+            prompt=prompt,
             params=params,
             params=params,
             arrival_time=arrival_time or time.time(),
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             lora_request=lora_request,
@@ -795,7 +795,7 @@ class AsyncAphrodite:
 
 
     async def generate(
     async def generate(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         sampling_params: SamplingParams,
         sampling_params: SamplingParams,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -808,9 +808,8 @@ class AsyncAphrodite:
         outputs from the AphroditeEngine to the caller.
         outputs from the AphroditeEngine to the caller.
 
 
         Args:
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~aphrodite.inputs.PromptInputs`
-                for more details about the format of each input.
+            prompt: The prompt to the LLM. See
+                :class:`~aphrodite.inputs.PromptType`
             sampling_params: The sampling parameters of the request.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
             lora_request: LoRA request to use for generation, if any.
@@ -867,7 +866,7 @@ class AsyncAphrodite:
         """
         """
         async for output in await self.add_request(
         async for output in await self.add_request(
                 request_id,
                 request_id,
-                inputs,
+                prompt,
                 sampling_params,
                 sampling_params,
                 lora_request=lora_request,
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request,
                 prompt_adapter_request=prompt_adapter_request,
@@ -876,7 +875,7 @@ class AsyncAphrodite:
 
 
     async def encode(
     async def encode(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         pooling_params: PoolingParams,
         pooling_params: PoolingParams,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -888,8 +887,8 @@ class AsyncAphrodite:
         outputs from the AphroditeEngine to the caller.
         outputs from the AphroditeEngine to the caller.
 
 
         Args:
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~aphrodite.inputs.PromptInputs`
+            prompt: The prompt to the LLM. See
+                :class:`~aphrodite.inputs.PromptType`
                 for more details about the format of each input.
                 for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
             request_id: The unique id of the request.
@@ -942,7 +941,7 @@ class AsyncAphrodite:
         """
         """
         async for output in await self.add_request(
         async for output in await self.add_request(
                 request_id,
                 request_id,
-                inputs,
+                prompt,
                 pooling_params,
                 pooling_params,
                 lora_request=lora_request,
                 lora_request=lora_request,
         ):
         ):
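For the async engine, `generate()` and `encode()` now take `prompt` as their first argument and still stream results back. A sketch assuming `async_engine` is an already-constructed `AsyncAphrodite` instance:

import asyncio

from aphrodite import SamplingParams

async def run(async_engine):
    # `prompt` (formerly `inputs`) is the first argument; outputs are streamed.
    async for output in async_engine.generate(
            "Once upon a time",
            SamplingParams(max_tokens=16),
            request_id="req-0"):
        print(output.outputs[0].text)

# asyncio.run(run(async_engine))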

+ 2 - 2
aphrodite/engine/multiprocessing/__init__.py

@@ -5,7 +5,7 @@ from typing import List, Mapping, Optional, Union
 from aphrodite import PoolingParams
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.inputs import PromptInputs
+from aphrodite.inputs import PromptType
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest

@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
 
 
 @dataclass
 class RPCProcessRequest:
-    inputs: PromptInputs
+    prompt: PromptType
     params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
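RPC payloads are now built with the renamed field. A sketch mirroring the client code later in this diff; the request values are placeholders:

request = RPCProcessRequest(
    prompt="Hello, world",                  # formerly `inputs=`
    params=SamplingParams(max_tokens=8),
    request_id="req-0",
)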

+ 11 - 11
aphrodite/engine/multiprocessing/client.py

@@ -26,7 +26,7 @@ from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
                                               RPCProcessRequest,
                                               RPCProcessRequest,
                                               RPCStartupRequest,
                                               RPCStartupRequest,
                                               RPCStartupResponse)
                                               RPCStartupResponse)
-from aphrodite.inputs import PromptInputs
+from aphrodite.inputs import PromptType
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.transformers_utils.tokenizer_group import (
 from aphrodite.transformers_utils.tokenizer_group import (
@@ -375,7 +375,7 @@ class MQAphroditeEngineClient:
 
 
     def generate(
     def generate(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         sampling_params: SamplingParams,
         sampling_params: SamplingParams,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -386,8 +386,8 @@ class MQAphroditeEngineClient:
         request into the waiting queue of the AphroditeEngine and streams the
         request into the waiting queue of the AphroditeEngine and streams the
         outputs from the AphroditeEngine to the caller.
         outputs from the AphroditeEngine to the caller.
         Args:
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~aphrodite.inputs.PromptInputs`
+            prompt: The prompt to the LLM. See
+                :class:`~aphrodite.inputs.PromptType`
                 for more details about the format of each input.
                 for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
             request_id: The unique id of the request.
@@ -395,12 +395,12 @@ class MQAphroditeEngineClient:
             prompt_adapter_request: Prompt Adapter request to use
             prompt_adapter_request: Prompt Adapter request to use
                                             for generation, if any.
                                             for generation, if any.
         """
         """
-        return self._process_request(inputs, sampling_params, request_id,
+        return self._process_request(prompt, sampling_params, request_id,
                                      lora_request, prompt_adapter_request)
                                      lora_request, prompt_adapter_request)
 
 
     def encode(
     def encode(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         pooling_params: PoolingParams,
         pooling_params: PoolingParams,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -410,8 +410,8 @@ class MQAphroditeEngineClient:
         request into the waiting queue of the AphroditeEngine and streams the
         request into the waiting queue of the AphroditeEngine and streams the
         outputs from the AphroditeEngine to the caller.
         outputs from the AphroditeEngine to the caller.
         Args:
         Args:
-            inputs: The inputs to the LLM. See
-                :class:`~aphrodite.inputs.PromptInputs`
+            prompt: The prompt to the LLM. See
+                :class:`~aphrodite.inputs.PromptType`
                 for more details about the format of each input.
                 for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
             request_id: The unique id of the request.
@@ -420,12 +420,12 @@ class MQAphroditeEngineClient:
             The output `EmbeddingRequestOutput` objects from the AphroditeEngine
             The output `EmbeddingRequestOutput` objects from the AphroditeEngine
             for the request.
             for the request.
         """
         """
-        return self._process_request(inputs, pooling_params, request_id,
+        return self._process_request(prompt, pooling_params, request_id,
                                      lora_request)
                                      lora_request)
 
 
     async def _process_request(
     async def _process_request(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -457,7 +457,7 @@ class MQAphroditeEngineClient:
 
 
             request_bytes = pickle.dumps(
             request_bytes = pickle.dumps(
                 RPCProcessRequest(
                 RPCProcessRequest(
-                    inputs=inputs,
+                    prompt=prompt,
                     params=params,
                     params=params,
                     request_id=request_id,
                     request_id=request_id,
                     lora_request=lora_request,
                     lora_request=lora_request,

+ 1 - 1
aphrodite/engine/multiprocessing/engine.py

@@ -246,7 +246,7 @@ class MQAphroditeEngine:
         try:
             self.engine.add_request(
                 request_id=request_id,
-                inputs=request.inputs,
+                prompt=request.prompt,
                 params=request.params,
                 lora_request=request.lora_request,
                 prompt_adapter_request=request.prompt_adapter_request)

+ 3 - 3
aphrodite/engine/protocol.py

@@ -6,7 +6,7 @@ from aphrodite.common.config import DecodingConfig, ModelConfig
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.inputs.data import PromptInputs
+from aphrodite.inputs.data import PromptType
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.processing.scheduler import SchedulerOutputs
@@ -35,7 +35,7 @@ class EngineClient(Protocol):
 
 
     def generate(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
@@ -46,7 +46,7 @@ class EngineClient(Protocol):
 
 
     def encode(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

+ 4 - 3
aphrodite/inputs/__init__.py

@@ -1,5 +1,5 @@
 from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
+                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
                    TokensPrompt, build_explicit_enc_dec_prompt,
                    to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry
@@ -9,6 +9,7 @@ INPUT_REGISTRY = InputRegistry()
 The global :class:`~InputRegistry` which is used by
 :class:`~aphrodite.AphroditeEngine`
 to dispatch data processing according to the target model.
+
 See also:
     :ref:`input_processing_pipeline`
 """
@@ -16,8 +17,8 @@ See also:
 __all__ = [
     "TextPrompt",
     "TokensPrompt",
-    "PromptInputs",
-    "SingletonPromptInputs",
+    "PromptType",
+    "SingletonPrompt",
     "ExplicitEncoderDecoderPrompt",
     "LLMInputs",
     "EncoderDecoderLLMInputs",

+ 25 - 22
aphrodite/inputs/data.py

@@ -33,31 +33,34 @@ class TokensPrompt(TypedDict):
     """
     """
 
 
 
 
-SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
+SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """
 """
 Set of possible schemas for a single LLM input:
 Set of possible schemas for a single LLM input:
+
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
+
 Note that "singleton" is as opposed to a data structure
 Note that "singleton" is as opposed to a data structure
 which encapsulates multiple prompts, i.e. of the sort
 which encapsulates multiple prompts, i.e. of the sort
 which may be utilized for encoder/decoder models when
 which may be utilized for encoder/decoder models when
 the user desires to express both the encoder & decoder
 the user desires to express both the encoder & decoder
-prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
-A prompt of type SingletonPromptInputs may be employed
+prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
+
+A prompt of type :class:`SingletonPromptType` may be employed
 as (1) input to a decoder-only model, (2) input to
 as (1) input to a decoder-only model, (2) input to
 the encoder of an encoder/decoder model, in the scenario
 the encoder of an encoder/decoder model, in the scenario
 where the decoder-prompt is not specified explicitly, or
 where the decoder-prompt is not specified explicitly, or
 (3) as a member of a larger data structure encapsulating
 (3) as a member of a larger data structure encapsulating
-more than one prompt, i.e. ExplicitEncoderDecoderPrompt
+more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
 """
 """
 
 
 _T1_co = TypeVar("_T1_co",
 _T1_co = TypeVar("_T1_co",
-                 bound=SingletonPromptInputs,
-                 default=SingletonPromptInputs,
+                 bound=SingletonPrompt,
+                 default=SingletonPrompt,
                  covariant=True)
                  covariant=True)
 _T2_co = TypeVar("_T2_co",
 _T2_co = TypeVar("_T2_co",
-                 bound=SingletonPromptInputs,
-                 default=SingletonPromptInputs,
+                 bound=SingletonPrompt,
+                 default=SingletonPrompt,
                  covariant=True)
                  covariant=True)
 
 
 
 
@@ -66,16 +69,19 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
     """Represents an encoder/decoder model input prompt,
     """Represents an encoder/decoder model input prompt,
     comprising an explicit encoder prompt and a 
     comprising an explicit encoder prompt and a 
     decoder prompt.
     decoder prompt.
+
     The encoder and decoder prompts, respectively,
     The encoder and decoder prompts, respectively,
     may formatted according to any of the
     may formatted according to any of the
-    SingletonPromptInputs schemas, and are not
+    :class:`SingletonPromptType` schemas, and are not
     required to have the same schema.
     required to have the same schema.
+
     Only the encoder prompt may have multi-modal data.
     Only the encoder prompt may have multi-modal data.
-    Note that an ExplicitEncoderDecoderPrompt may not
+
+    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
     be used as an input to a decoder-only model,
     be used as an input to a decoder-only model,
     and that the `encoder_prompt` and `decoder_prompt`
     and that the `encoder_prompt` and `decoder_prompt`
-    fields of this data structure may not themselves
-    must be SingletonPromptInputs instances.
+    fields of this data structure themselves must be
+    :class:`SingletonPromptType` instances.
     """
     """
 
 
     encoder_prompt: _T1_co
     encoder_prompt: _T1_co
@@ -83,10 +89,11 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
     decoder_prompt: Optional[_T2_co]
     decoder_prompt: Optional[_T2_co]
 
 
 
 
-PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
+PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
 """
 """
 Set of possible schemas for an LLM input, including
 Set of possible schemas for an LLM input, including
 both decoder-only and encoder/decoder input types:
 both decoder-only and encoder/decoder input types:
+
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
 - A single data structure containing both an encoder and a decoder prompt
 - A single data structure containing both an encoder and a decoder prompt
@@ -96,12 +103,11 @@ both decoder-only and encoder/decoder input types:
 
 
 class LLMInputs(TypedDict):
 class LLMInputs(TypedDict):
     """
     """
-    The inputs in :class:`~aphrodite.AphroditeEngine` before they are
+    The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.
     passed to the model executor.
 
 
     This specifies the data required for decoder-only models.
     This specifies the data required for decoder-only models.
     """
     """
-
     prompt_token_ids: List[int]
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
     """The token IDs of the prompt."""
 
 
@@ -119,8 +125,9 @@ class LLMInputs(TypedDict):
 
 
 class EncoderDecoderLLMInputs(LLMInputs):
 class EncoderDecoderLLMInputs(LLMInputs):
     """
     """
-    The inputs in :class:`~aphrodite.AphroditeEngine` before they are
+    The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.
     passed to the model executor.
+
     This specifies the required data for encoder-decoder models.
     This specifies the required data for encoder-decoder models.
     """
     """
     encoder_prompt_token_ids: List[int]
     encoder_prompt_token_ids: List[int]
@@ -133,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
     """
     """
 
 
 
 
-_T1 = TypeVar("_T1",
-              bound=SingletonPromptInputs,
-              default=SingletonPromptInputs)
-_T2 = TypeVar("_T2",
-              bound=SingletonPromptInputs,
-              default=SingletonPromptInputs)
+_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
+_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
 
 
 
 
 def build_explicit_enc_dec_prompt(
 def build_explicit_enc_dec_prompt(
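Putting the renamed aliases together: `SingletonPrompt` covers a single text or tokens prompt, while `PromptType` additionally admits the explicit encoder/decoder structure. A sketch of each shape (the token ids are arbitrary):

from aphrodite.inputs import (TextPrompt, TokensPrompt,
                              build_explicit_enc_dec_prompt)

text_prompt = TextPrompt(prompt="Translate to German: Hello")    # SingletonPrompt
tokens_prompt = TokensPrompt(prompt_token_ids=[101, 7592, 102])  # SingletonPrompt

# PromptType also allows the explicit encoder/decoder form:
enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt=text_prompt,
    decoder_prompt=tokens_prompt,
)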

+ 12 - 11
aphrodite/inputs/parse.py

@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
 from aphrodite.common.utils import is_list_of

 from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
+                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
                    TokensPrompt)

@@ -81,22 +81,23 @@ class ParsedTokensPrompt(TypedDict):
 
 
 
 
 def parse_singleton_prompt(
-    inputs: SingletonPromptInputs,
+    prompt: SingletonPrompt,
 ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
-    if isinstance(inputs, str):
-        return ParsedStrPrompt(type="str", content=inputs)
-    elif isinstance(inputs, dict):
-        if "prompt_token_ids" in inputs:
+    if isinstance(prompt, str):
+        return ParsedStrPrompt(type="str", content=prompt)
+    elif isinstance(prompt, dict):
+        if "prompt_token_ids" in prompt:
             return ParsedTokensPrompt(type="tokens",
-                                      content=inputs)  # type: ignore
-        elif "prompt" in inputs:
-            return ParsedTextPrompt(type="text", content=inputs)
+                                      content=prompt)  # type: ignore
+        elif "prompt" in prompt:
+            return ParsedTextPrompt(type="text", content=prompt)
+
     raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


 def is_explicit_encoder_decoder_prompt(
-        inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
-    return isinstance(inputs, dict) and "encoder_prompt" in inputs
+        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
+    return isinstance(prompt, dict) and "encoder_prompt" in prompt


 def is_valid_encoder_decoder_llm_inputs(
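The parsing helpers above dispatch purely on the shape of the value; the expected classifications, based on the dict-key checks shown in this hunk:

parse_singleton_prompt("Hello")                          # -> type "str"
parse_singleton_prompt({"prompt": "Hello"})              # -> type "text"
parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})  # -> type "tokens"

is_explicit_encoder_decoder_prompt({"encoder_prompt": "Hi",
                                    "decoder_prompt": None})     # True
is_explicit_encoder_decoder_prompt("Hi")                         # False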

+ 172 - 128
aphrodite/inputs/preprocess.py

@@ -9,102 +9,98 @@ from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
 
-from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                   SingletonPromptInputs)
+from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
+                   SingletonPrompt)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from aphrodite.multimodal import MultiModalDataDict
     from aphrodite.multimodal import MultiModalDataDict
 
 
 
 
-PromptComponents = Tuple[
-    Optional[str], List[int], Optional["MultiModalDataDict"]
-]
-DecoderPromptComponents = Tuple[
-    Optional[str], Optional[List[int]], Optional["MultiModalDataDict"]
-]
+PromptComponents = Tuple[Optional[str], List[int],
+                         Optional["MultiModalDataDict"]]
+DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
+                                Optional["MultiModalDataDict"]]
 
 
 
 
 class InputPreprocessor:
 class InputPreprocessor:
+
     def __init__(
     def __init__(
         self,
         self,
         model_config: ModelConfig,
         model_config: ModelConfig,
         tokenizer: Optional[BaseTokenizerGroup],
         tokenizer: Optional[BaseTokenizerGroup],
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
+
         self.model_config = model_config
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.tokenizer = tokenizer
 
 
     def get_tokenizer_group(self) -> BaseTokenizerGroup:
     def get_tokenizer_group(self) -> BaseTokenizerGroup:
         if self.tokenizer is None:
         if self.tokenizer is None:
-            raise ValueError(
-                "You cannot pass text prompts when "
-                "`skip_tokenizer_init` is True"
-            )
+            raise ValueError("You cannot pass text prompts when "
+                             "`skip_tokenizer_init` is True")
+
         return self.tokenizer
         return self.tokenizer
 
 
-    def get_bos_token_id(
-        self, lora_request: Optional[LoRARequest] = None
-    ) -> Optional[int]:
+    def get_bos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
         if self.tokenizer is None:
         if self.tokenizer is None:
-            logger.warning(
-                "Using None for BOS token id because tokenizer "
-                "is not initialized"
-            )
+            logger.warning("Using None for BOS token id because tokenizer "
+                           "is not initialized")
             return None
             return None
+
         return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
         return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
 
 
-    def get_eos_token_id(
-        self, lora_request: Optional[LoRARequest] = None
-    ) -> Optional[int]:
+    def get_eos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
         if self.tokenizer is None:
         if self.tokenizer is None:
-            logger.warning(
-                "Using None for EOS token id because tokenizer "
-                "is not initialized"
-            )
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
             return None
             return None
+
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
 
 
     def get_decoder_start_token_id(self) -> Optional[int]:
     def get_decoder_start_token_id(self) -> Optional[int]:
-        """
+        '''
         Obtain the decoder start token id employed by an encoder/decoder
         Obtain the decoder start token id employed by an encoder/decoder
         model. Returns None for non-encoder/decoder models or if the
         model. Returns None for non-encoder/decoder models or if the
         model config is unavailable.
         model config is unavailable.
-        """
+        '''
+
         if not self.is_encoder_decoder_model():
         if not self.is_encoder_decoder_model():
-            logger.warning(
-                "Using None for decoder start token id because "
-                "this is not an encoder/decoder model."
-            )
+            logger.warning("Using None for decoder start token id because "
+                           "this is not an encoder/decoder model.")
             return None
             return None
-        if self.model_config is None or self.model_config.hf_config is None:
-            logger.warning(
-                "Using None for decoder start token id because "
-                "model config is not available."
-            )
+
+        if (self.model_config is None or self.model_config.hf_config is None):
+            logger.warning("Using None for decoder start token id because "
+                           "model config is not available.")
             return None
             return None
-        dec_start_token_id = getattr(
-            self.model_config.hf_config, "decoder_start_token_id", None
-        )
+
+        dec_start_token_id = getattr(self.model_config.hf_config,
+                                     'decoder_start_token_id', None)
         if dec_start_token_id is None:
         if dec_start_token_id is None:
-            logger.warning(
-                "Falling back on <BOS> for decoder start token id "
-                "because decoder start token id is not available."
-            )
+            logger.warning("Falling back on <BOS> for decoder start token id "
+                           "because decoder start token id is not available.")
             dec_start_token_id = self.get_bos_token_id()
             dec_start_token_id = self.get_bos_token_id()
+
         return dec_start_token_id
         return dec_start_token_id
 
 
     def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
     def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
-        """
+        '''
         Specifically for encoder/decoder models:
         Specifically for encoder/decoder models:
         generate a default decoder prompt for when
         generate a default decoder prompt for when
         the user specifies only the encoder prompt.
         the user specifies only the encoder prompt.
+
         Encoder/decoder models utilize the decoder
         Encoder/decoder models utilize the decoder
         prompt in different ways; as new models are
         prompt in different ways; as new models are
         added, it is intended that this function
         added, it is intended that this function
         will be extended to produce differing
         will be extended to produce differing
         default decoder prompts, depending on the
         default decoder prompts, depending on the
         model variety.
         model variety.
+
         Absent a special case, the default behavior
         Absent a special case, the default behavior
         of this method is to mirror the behavior of
         of this method is to mirror the behavior of
         the HuggingFace (HF) GenerationMixin for a None
         the HuggingFace (HF) GenerationMixin for a None
@@ -112,14 +108,18 @@ class InputPreprocessor:
         setting to force the first decoded token to be <BOS>.
         setting to force the first decoded token to be <BOS>.
         Here, this behavior is approximated by having the
         Here, this behavior is approximated by having the
         "default" decoder prompt be <BOS>.
         "default" decoder prompt be <BOS>.
+
         However, it is possible that in the future
         However, it is possible that in the future
-        other models may have different or more
+        other models may have different or more 
         complex logic for the default decoder prompt.
         complex logic for the default decoder prompt.
         This motivates having a special helper method
         This motivates having a special helper method
         for default decoder prompts.
         for default decoder prompts.
+
         Returns:
         Returns:
+
         * prompt_token_ids
         * prompt_token_ids
-        """
+        '''
+
         bos_token_id = self.get_bos_token_id()
         bos_token_id = self.get_bos_token_id()
         assert bos_token_id is not None
         assert bos_token_id is not None
         return [bos_token_id]
         return [bos_token_id]
@@ -130,27 +130,36 @@ class InputPreprocessor:
     ) -> List[int]:
     ) -> List[int]:
         """
         """
         Prepares `decoder_input_ids` for generation with encoder-decoder models.
         Prepares `decoder_input_ids` for generation with encoder-decoder models.
+
         Based on
         Based on
+
         https://github.com/huggingface/transformers/blob/
         https://github.com/huggingface/transformers/blob/
         4037a2b5b1278736e566aec12e169100275545ea/
         4037a2b5b1278736e566aec12e169100275545ea/
         src/transformers/generation/utils.py
         src/transformers/generation/utils.py
+
         specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
         specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
+
         Arguments:
         Arguments:
+
         * decoder_input_ids: input token ids to preprocess
         * decoder_input_ids: input token ids to preprocess
+
         Returns:
         Returns:
+
         * Processed token list
         * Processed token list
         """
         """
+
         decoder_start_token_id = self.get_decoder_start_token_id()
         decoder_start_token_id = self.get_decoder_start_token_id()
         assert decoder_start_token_id is not None
         assert decoder_start_token_id is not None
+
         if decoder_input_ids is None:
         if decoder_input_ids is None:
             # no decoder prompt input ->
             # no decoder prompt input ->
             # use decoder_start_token_id as decoder_input_ids
             # use decoder_start_token_id as decoder_input_ids
             decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
             decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
-        if (
-            len(decoder_input_ids) == 0
-            or decoder_input_ids[0] != decoder_start_token_id
-        ):
+
+        if (len(decoder_input_ids) == 0
+                or decoder_input_ids[0] != decoder_start_token_id):
             decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
             decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
         return decoder_input_ids
         return decoder_input_ids
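`_prepare_decoder_input_ids_for_generation` keeps the HF GenerationMixin behaviour described in the docstring: with no decoder prompt the default is the `<BOS>`-only prompt, and the decoder start token is prepended unless it is already the first token. A standalone re-implementation of that rule, with illustrative token ids (2 = decoder start, 0 = `<BOS>`):

    from typing import List, Optional

    def prepare_decoder_input_ids(decoder_input_ids: Optional[List[int]],
                                  decoder_start_token_id: int,
                                  bos_token_id: int) -> List[int]:
        if decoder_input_ids is None:
            # No decoder prompt given: mirror HF and start from <BOS>.
            decoder_input_ids = [bos_token_id]
        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            # Force the sequence to begin with the decoder start token.
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
        return decoder_input_ids

    assert prepare_decoder_input_ids(None, 2, 0) == [2, 0]
    assert prepare_decoder_input_ids([5, 6], 2, 0) == [2, 5, 6]
    assert prepare_decoder_input_ids([2, 5, 6], 2, 0) == [2, 5, 6]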
 
 
     def _apply_prompt_adapter(
     def _apply_prompt_adapter(
@@ -161,8 +170,8 @@ class InputPreprocessor:
         if prompt_adapter_request:
         if prompt_adapter_request:
             prompt_token_ids = (
             prompt_token_ids = (
                 [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                 [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
-                + prompt_token_ids
-            )
+                + prompt_token_ids)
+
         return prompt_token_ids
         return prompt_token_ids
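`_apply_prompt_adapter` is unchanged in behaviour: when a request carries a prompt adapter, one placeholder id (`0`) per virtual token is prepended to the prompt. A toy sketch, with a made-up request class standing in for `PromptAdapterRequest`:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class FakePromptAdapterRequest:  # stand-in for illustration only
        prompt_adapter_num_virtual_tokens: int

    def apply_prompt_adapter(prompt_token_ids: List[int],
                             prompt_adapter_request: Optional[FakePromptAdapterRequest] = None) -> List[int]:
        if prompt_adapter_request:
            # Reserve one placeholder id (0) per virtual token, in front of the prompt.
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)
        return prompt_token_ids

    assert apply_prompt_adapter([7, 8, 9]) == [7, 8, 9]
    assert apply_prompt_adapter([7, 8, 9], FakePromptAdapterRequest(2)) == [0, 0, 7, 8, 9]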
 
 
     def _tokenize_prompt(
     def _tokenize_prompt(
@@ -176,9 +185,10 @@ class InputPreprocessor:
         corresponding token IDs.
         corresponding token IDs.
         """
         """
         tokenizer = self.get_tokenizer_group()
         tokenizer = self.get_tokenizer_group()
-        return tokenizer.encode(
-            request_id=request_id, prompt=prompt, lora_request=lora_request
-        )
+
+        return tokenizer.encode(request_id=request_id,
+                                prompt=prompt,
+                                lora_request=lora_request)
 
 
     async def _tokenize_prompt_async(
     async def _tokenize_prompt_async(
         self,
         self,
@@ -188,83 +198,93 @@ class InputPreprocessor:
     ) -> List[int]:
     ) -> List[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
         tokenizer = self.get_tokenizer_group()
-        return await tokenizer.encode_async(
-            request_id=request_id, prompt=prompt, lora_request=lora_request
-        )
+
+        return await tokenizer.encode_async(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
 
 
     def _extract_prompt_components(
     def _extract_prompt_components(
         self,
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
     ) -> PromptComponents:
-        """
+        '''
         Extract the components of any single encoder or decoder input prompt.
         Extract the components of any single encoder or decoder input prompt.
+
         Arguments:
         Arguments:
+
         * request_id
         * request_id
-        * inputs: single encoder or decoder input prompt
+        * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
         * lora_request: this is only valid for decoder prompts
+
         Returns:
         Returns:
+
         * prompt
         * prompt
         * prompt_token_ids
         * prompt_token_ids
         * multi_modal_data
         * multi_modal_data
-        """
-        parsed = parse_singleton_prompt(inputs)
+        '''
+
+        parsed = parse_singleton_prompt(prompt)
+
         if parsed["type"] == "str":
         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = self._tokenize_prompt(
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 request_id=request_id,
                 lora_request=lora_request,
                 lora_request=lora_request,
             )
             )
             multi_modal_data = None
             multi_modal_data = None
         elif parsed["type"] == "tokens":
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = self._tokenize_prompt(
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 request_id=request_id,
                 lora_request=lora_request,
                 lora_request=lora_request,
             )
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
             multi_modal_data = parsed["content"].get("multi_modal_data")
         else:
         else:
             assert_never(parsed)
             assert_never(parsed)
-        return prompt, prompt_token_ids, multi_modal_data
+
+        return prompt_text, prompt_token_ids, multi_modal_data
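`_extract_prompt_components` (now taking `prompt: SingletonPrompt`) handles the three singleton shapes that `parse_singleton_prompt` distinguishes: a plain string, a tokens dict carrying `prompt_token_ids`, and a text dict carrying `prompt`, the latter two optionally with `multi_modal_data`. A rough standalone sketch of that dispatch, with a trivial word-length tokenizer standing in for the real tokenizer group:

    def toy_tokenize(text: str):
        # Stand-in for the tokenizer group; the real code calls tokenizer.encode().
        return [len(tok) for tok in text.split()]

    def extract_prompt_components(prompt):
        if isinstance(prompt, str):                      # plain "str" prompt
            return prompt, toy_tokenize(prompt), None
        if "prompt_token_ids" in prompt:                 # "tokens" prompt
            return None, prompt["prompt_token_ids"], prompt.get("multi_modal_data")
        text = prompt["prompt"]                          # "text" prompt
        return text, toy_tokenize(text), prompt.get("multi_modal_data")

    print(extract_prompt_components("Hello world"))
    print(extract_prompt_components({"prompt_token_ids": [1, 2, 3]}))
    print(extract_prompt_components({"prompt": "Hello world"}))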
 
 
     async def _extract_prompt_components_async(
     async def _extract_prompt_components_async(
         self,
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
     ) -> PromptComponents:
         """Async version of :meth:`_extract_prompt_components`."""
         """Async version of :meth:`_extract_prompt_components`."""
-        parsed = parse_singleton_prompt(inputs)
+        parsed = parse_singleton_prompt(prompt)
+
         if parsed["type"] == "str":
         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = await self._tokenize_prompt_async(
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 request_id=request_id,
                 lora_request=lora_request,
                 lora_request=lora_request,
             )
             )
             multi_modal_data = None
             multi_modal_data = None
         elif parsed["type"] == "tokens":
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 request_id=request_id,
                 lora_request=lora_request,
                 lora_request=lora_request,
             )
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
             multi_modal_data = parsed["content"].get("multi_modal_data")
         else:
         else:
             assert_never(parsed)
             assert_never(parsed)
-        return prompt, prompt_token_ids, multi_modal_data
+
+        return prompt_text, prompt_token_ids, multi_modal_data
 
 
     def _build_enc_dec_llm_inputs(
     def _build_enc_dec_llm_inputs(
         self,
         self,
@@ -273,13 +293,14 @@ class InputPreprocessor:
     ) -> EncoderDecoderLLMInputs:
     ) -> EncoderDecoderLLMInputs:
         encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
         encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+
         if encoder_mm_data is not None or decoder_mm_data is not None:
         if encoder_mm_data is not None or decoder_mm_data is not None:
-            raise ValueError(
-                "Multi-modal encoder-decoder models are " "not supported yet"
-            )
-        decoder_prompt_ids = self._prepare_decoder_input_ids_for_generation(
-            decoder_prompt_ids
-        )
+            raise ValueError("Multi-modal encoder-decoder models are "
+                             "not supported yet")
+
+        decoder_prompt_ids = (
+            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
+
         return EncoderDecoderLLMInputs(
         return EncoderDecoderLLMInputs(
             prompt_token_ids=decoder_prompt_ids,
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
             prompt=decoder_prompt,
@@ -289,43 +310,52 @@ class InputPreprocessor:
 
 
     def _process_encoder_decoder_prompt(
     def _process_encoder_decoder_prompt(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
     ) -> EncoderDecoderLLMInputs:
-        """
+        '''
         For encoder/decoder models only:
         For encoder/decoder models only:
         Process an input prompt into an
         Process an input prompt into an
         :class:`EncoderDecoderLLMInputs` instance.
         :class:`EncoderDecoderLLMInputs` instance.
+
         There are two types of input prompts:
         There are two types of input prompts:
         singleton prompts which carry only the
         singleton prompts which carry only the
         encoder prompt, and explicit encoder/decoder
         encoder prompt, and explicit encoder/decoder
         prompts which carry both the encoder and the
         prompts which carry both the encoder and the
         decoder prompts as member variables.
         decoder prompts as member variables.
+
         This function handles the following scenarios:
         This function handles the following scenarios:
         * Singleton encoder prompt: extract encoder prompt
         * Singleton encoder prompt: extract encoder prompt
           token ids & infer default decoder prompt token ids
           token ids & infer default decoder prompt token ids
         * Explicit encoder/decoder prompt: extract encoder
         * Explicit encoder/decoder prompt: extract encoder
           and decoder prompt token ids
           and decoder prompt token ids
+
         Note that for Explicit encoder/decoder prompts,
         Note that for Explicit encoder/decoder prompts,
         each sub-prompt (encoder or decoder prompt) can
         each sub-prompt (encoder or decoder prompt) can
         have any possible singleton type; thus this
         have any possible singleton type; thus this
         method relies on helper functions to obtain
         method relies on helper functions to obtain
         token ids for the sub-prompts.
         token ids for the sub-prompts.
-
+        
         Arguments:
         Arguments:
-        * inputs: an input prompt
+
+        * prompt: an input prompt
         * request_id
         * request_id
+
         Returns:
         Returns:
+
         * :class:`EncoderDecoderLLMInputs` instance
         * :class:`EncoderDecoderLLMInputs` instance
-        """
+        '''
+
         encoder_comps: PromptComponents
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents
         decoder_comps: DecoderPromptComponents
-        if is_explicit_encoder_decoder_prompt(inputs):
+
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_comps = self._extract_prompt_components(
             encoder_comps = self._extract_prompt_components(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
                 request_id=request_id,
             )
             )
-            if (decoder_input := inputs["decoder_prompt"]) is None:
+
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_comps = None, None, None
                 decoder_comps = None, None, None
             else:
             else:
                 decoder_comps = self._extract_prompt_components(
                 decoder_comps = self._extract_prompt_components(
@@ -334,26 +364,30 @@ class InputPreprocessor:
                 )
                 )
         else:
         else:
             encoder_comps = self._extract_prompt_components(
             encoder_comps = self._extract_prompt_components(
-                inputs,
+                prompt,
                 request_id=request_id,
                 request_id=request_id,
             )
             )
+
             decoder_comps = None, None, None
             decoder_comps = None, None, None
+
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
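For encoder/decoder models the renamed `prompt` argument may be either a singleton prompt (encoder half only, with the decoder prompt defaulted) or an explicit mapping that the code above reads via `prompt["encoder_prompt"]` and `prompt["decoder_prompt"]`. The two shapes, sketched with placeholder text and token ids:

    # Singleton form: only the encoder prompt is given; the decoder prompt is
    # inferred (see _get_default_enc_dec_decoder_prompt above).
    singleton_prompt = "summarize: the quick brown fox jumps over the lazy dog"

    # Explicit form: both halves are spelled out; each half may itself be any
    # singleton shape (string, text dict, or tokens dict).
    explicit_prompt = {
        "encoder_prompt": "summarize: the quick brown fox jumps over the lazy dog",
        "decoder_prompt": {"prompt_token_ids": [2]},  # e.g. just a decoder start token
    }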
 
 
     async def _process_encoder_decoder_prompt_async(
     async def _process_encoder_decoder_prompt_async(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
     ) -> EncoderDecoderLLMInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_comps: PromptComponents
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents
         decoder_comps: DecoderPromptComponents
-        if is_explicit_encoder_decoder_prompt(inputs):
+
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._extract_prompt_components_async(
             encoder_task = self._extract_prompt_components_async(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
                 request_id=request_id,
             )
             )
-            if (decoder_input := inputs["decoder_prompt"]) is None:
+
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_comps = await encoder_task
                 encoder_comps = await encoder_task
                 decoder_comps = None, None, None
                 decoder_comps = None, None, None
             else:
             else:
@@ -361,15 +395,17 @@ class InputPreprocessor:
                     decoder_input,
                     decoder_input,
                     request_id=request_id,
                     request_id=request_id,
                 )
                 )
+
                 encoder_comps, decoder_comps = await asyncio.gather(
                 encoder_comps, decoder_comps = await asyncio.gather(
-                    encoder_task, decoder_task
-                )
+                    encoder_task, decoder_task)
         else:
         else:
             encoder_comps = await self._extract_prompt_components_async(
             encoder_comps = await self._extract_prompt_components_async(
-                inputs,
+                prompt,
                 request_id=request_id,
                 request_id=request_id,
             )
             )
+
             decoder_comps = None, None, None
             decoder_comps = None, None, None
+
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
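The async variant differs only in that, for explicit encoder/decoder prompts, both component extractions run concurrently through `asyncio.gather`. A minimal sketch of that pattern, with dummy coroutines in place of `_extract_prompt_components_async`:

    import asyncio

    async def extract(name: str):
        # Stand-in for _extract_prompt_components_async.
        await asyncio.sleep(0)
        return f"{name}-components"

    async def process_explicit_prompt():
        encoder_task = extract("encoder")
        decoder_task = extract("decoder")
        # Both sub-prompts are tokenized concurrently.
        return await asyncio.gather(encoder_task, decoder_task)

    print(asyncio.run(process_explicit_prompt()))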
 
 
     def _build_decoder_only_llm_inputs(
     def _build_decoder_only_llm_inputs(
@@ -378,38 +414,43 @@ class InputPreprocessor:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> LLMInputs:
     ) -> LLMInputs:
         prompt, prompt_token_ids, multi_modal_data = prompt_comps
         prompt, prompt_token_ids, multi_modal_data = prompt_comps
+
         prompt_token_ids = self._apply_prompt_adapter(
         prompt_token_ids = self._apply_prompt_adapter(
-            prompt_token_ids, prompt_adapter_request=prompt_adapter_request
-        )
-        return LLMInputs(
-            prompt_token_ids=prompt_token_ids,
-            prompt=prompt,
-            multi_modal_data=multi_modal_data,
-        )
+            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=prompt,
+                         multi_modal_data=multi_modal_data)
 
 
     def _process_decoder_only_prompt(
     def _process_decoder_only_prompt(
         self,
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
     ) -> LLMInputs:
-        """
+        '''
         For decoder-only models:
         For decoder-only models:
         Process an input prompt into an :class:`LLMInputs` instance.
         Process an input prompt into an :class:`LLMInputs` instance.
+
         Arguments:
         Arguments:
-        * inputs: input prompt
+
+        * prompt: input prompt
         * request_id
         * request_id
         * lora_request
         * lora_request
         * prompt_adapter_request
         * prompt_adapter_request
+
         Returns:
         Returns:
+
         * :class:`LLMInputs` instance
         * :class:`LLMInputs` instance
-        """
+        '''
+
         prompt_comps = self._extract_prompt_components(
         prompt_comps = self._extract_prompt_components(
-            inputs,
+            prompt,
             request_id=request_id,
             request_id=request_id,
             lora_request=lora_request,
             lora_request=lora_request,
         )
         )
+
         return self._build_decoder_only_llm_inputs(
         return self._build_decoder_only_llm_inputs(
             prompt_comps,
             prompt_comps,
             prompt_adapter_request=prompt_adapter_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -417,17 +458,18 @@ class InputPreprocessor:
 
 
     async def _process_decoder_only_prompt_async(
     async def _process_decoder_only_prompt_async(
         self,
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
     ) -> LLMInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._extract_prompt_components_async(
         prompt_comps = await self._extract_prompt_components_async(
-            inputs,
+            prompt,
             request_id=request_id,
             request_id=request_id,
             lora_request=lora_request,
             lora_request=lora_request,
         )
         )
+
         return self._build_decoder_only_llm_inputs(
         return self._build_decoder_only_llm_inputs(
             prompt_comps,
             prompt_comps,
             prompt_adapter_request=prompt_adapter_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -435,7 +477,7 @@ class InputPreprocessor:
 
 
     def preprocess(
     def preprocess(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -445,16 +487,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
             return self._process_encoder_decoder_prompt(
-                inputs,
+                prompt,
                 request_id=request_id,
                 request_id=request_id,
             )
             )
-        if is_explicit_encoder_decoder_prompt(inputs):
-            raise ValueError(
-                "Cannot pass encoder-decoder prompt " "to decoder-only models"
-            )
+
+        if is_explicit_encoder_decoder_prompt(prompt):
+            raise ValueError("Cannot pass encoder-decoder prompt "
+                             "to decoder-only models")
+
         # Decoder-only operation
         # Decoder-only operation
         return self._process_decoder_only_prompt(
         return self._process_decoder_only_prompt(
-            inputs,
+            prompt,
             request_id=request_id,
             request_id=request_id,
             lora_request=lora_request,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -462,7 +505,7 @@ class InputPreprocessor:
 
 
     async def preprocess_async(
     async def preprocess_async(
         self,
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -472,16 +515,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(
             return await self._process_encoder_decoder_prompt_async(
-                inputs,
+                prompt,
                 request_id=request_id,
                 request_id=request_id,
             )
             )
-        if is_explicit_encoder_decoder_prompt(inputs):
-            raise ValueError(
-                "Cannot pass encoder-decoder prompt " "to decoder-only models"
-            )
+
+        if is_explicit_encoder_decoder_prompt(prompt):
+            raise ValueError("Cannot pass encoder-decoder prompt "
+                             "to decoder-only models")
+
         # Decoder-only operation
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
         return await self._process_decoder_only_prompt_async(
-            inputs,
+            prompt,
             request_id=request_id,
             request_id=request_id,
             lora_request=lora_request,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             prompt_adapter_request=prompt_adapter_request,
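Taken together, the `preprocess`/`preprocess_async` entry points above dispatch on the model type: encoder/decoder models send every `PromptType` through the encoder/decoder path, while decoder-only models reject explicit encoder/decoder prompts and otherwise take the decoder-only path. A condensed, illustrative version of that control flow (the flag and return values are stand-ins, not the real `LLMInputs`):

    def preprocess(prompt, *, model_is_encoder_decoder: bool):
        # `prompt` is any PromptType: a str, a text/tokens dict, or an explicit
        # encoder/decoder mapping.
        explicit = isinstance(prompt, dict) and "encoder_prompt" in prompt

        if model_is_encoder_decoder:
            # Encoder/decoder models map every prompt form to encoder + decoder ids.
            return {"path": "encoder-decoder", "explicit": explicit}
        if explicit:
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")
        # Decoder-only operation.
        return {"path": "decoder-only"}

    print(preprocess("Hello", model_is_encoder_decoder=False))
    print(preprocess({"encoder_prompt": "Hello", "decoder_prompt": None},
                     model_is_encoder_decoder=True))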

+ 2 - 2
pyproject.toml

@@ -46,8 +46,8 @@ ignore = [
 ]
 ]
 
 
 [tool.codespell]
 [tool.codespell]
-ignore-words-list = "dout, te, indicies, ist, subtile, wit, whit, beseige, devlop, serie, vor, holliday, discus, tennant, carin, parma, mor, slac, revered, chanel, sammon, nast, shepard, insead, bloc, clea"
-skip = "./tests/,./aphrodite/endpoints/kobold/klite.embd,./kernels/,./tests/benchmarks/sonnet.txt,./docs/,./tests/lora/data/long_context_test_data.py"
+ignore-words-list = "dout, te, indicies, ist, subtile, wit, whit, beseige, devlop, serie, vor, holliday, discus, tennant, carin, parma, mor, slac, revered, chanel, sammon, nast, shepard, insead, bloc, clea, appy, ser, fter"
+skip = "./tests/,./aphrodite/endpoints/kobold/klite.embd,./kernels/,./tests/benchmarks/sonnet.txt,./docs/,./tests/lora/data/long_context_test_data.py,./tests/models/fixtures/"
 
 
 [tool.isort]
 [tool.isort]
 use_parentheses = true
 use_parentheses = true

+ 9 - 15
tests/benchmarks/engine/latency.py

@@ -12,7 +12,7 @@ from tqdm import tqdm
 from aphrodite import LLM, SamplingParams
 from aphrodite import LLM, SamplingParams
 from aphrodite.common.utils import FlexibleArgumentParser
 from aphrodite.common.utils import FlexibleArgumentParser
 from aphrodite.engine.args_tools import DEVICE_OPTIONS, EngineArgs
 from aphrodite.engine.args_tools import DEVICE_OPTIONS, EngineArgs
-from aphrodite.inputs import PromptInputs
+from aphrodite.inputs import PromptType
 from aphrodite.quantization import QUANTIZATION_METHODS
 from aphrodite.quantization import QUANTIZATION_METHODS
 
 
 
 
@@ -27,8 +27,6 @@ def main(args: argparse.Namespace):
         num_speculative_tokens=args.num_speculative_tokens,
         num_speculative_tokens=args.num_speculative_tokens,
         speculative_draft_tensor_parallel_size=\
         speculative_draft_tensor_parallel_size=\
             args.speculative_draft_tensor_parallel_size,
             args.speculative_draft_tensor_parallel_size,
-        ngram_prompt_lookup_max=args.ngram_prompt_lookup_max,
-        ngram_prompt_lookup_min=args.ngram_prompt_lookup_min,
         tokenizer=args.tokenizer,
         tokenizer=args.tokenizer,
         quantization=args.quantization,
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
         tensor_parallel_size=args.tensor_parallel_size,
@@ -62,7 +60,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                size=(args.batch_size,
                                                      args.input_len))
                                                      args.input_len))
-    dummy_inputs: List[PromptInputs] = [{
+    dummy_prompts: List[PromptType] = [{
         "prompt_token_ids": batch
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
     } for batch in dummy_prompt_token_ids.tolist()]
 
 
@@ -75,13 +73,13 @@ def main(args: argparse.Namespace):
                     ],
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
                         str(profile_dir))) as p:
                         str(profile_dir))) as p:
-                llm.generate(dummy_inputs,
+                llm.generate(dummy_prompts,
                              sampling_params=sampling_params,
                              sampling_params=sampling_params,
                              use_tqdm=False)
                              use_tqdm=False)
             print(p.key_averages())
             print(p.key_averages())
         else:
         else:
             start_time = time.perf_counter()
             start_time = time.perf_counter()
-            llm.generate(dummy_inputs,
+            llm.generate(dummy_prompts,
                          sampling_params=sampling_params,
                          sampling_params=sampling_params,
                          use_tqdm=False)
                          use_tqdm=False)
             end_time = time.perf_counter()
             end_time = time.perf_counter()
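Call sites only swap the type name: the benchmark now annotates its dummy batch as `List[PromptType]` and passes it straight to `LLM.generate`, exactly as before. A hedged usage sketch of that pattern outside the benchmark; the model name, sizes, and sampling settings are placeholders, not the benchmark's defaults:

    from typing import List

    import numpy as np

    from aphrodite import LLM, SamplingParams
    from aphrodite.inputs import PromptType

    # Placeholder sizes; pick whatever fits your setup.
    batch_size, input_len = 4, 32
    dummy_prompt_token_ids = np.random.randint(10000, size=(batch_size, input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    llm = LLM(model="facebook/opt-125m")  # placeholder model choice
    sampling_params = SamplingParams(max_tokens=16)

    outputs = llm.generate(dummy_prompts,
                           sampling_params=sampling_params,
                           use_tqdm=False)
    print(len(outputs))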
@@ -135,8 +133,6 @@ if __name__ == '__main__':
                         '-spec-draft-tp',
                         '-spec-draft-tp',
                         type=int,
                         type=int,
                         default=None)
                         default=None)
-    parser.add_argument('--ngram-prompt-lookup-max', type=int, default=None)
-    parser.add_argument('--ngram-prompt-lookup-min', type=int, default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
     parser.add_argument('--quantization',
                         '-q',
                         '-q',
@@ -208,13 +204,11 @@ if __name__ == '__main__':
         default=None,
         default=None,
         help=('path to save the pytorch profiler output. Can be visualized '
         help=('path to save the pytorch profiler output. Can be visualized '
               'with ui.perfetto.dev or Tensorboard.'))
               'with ui.perfetto.dev or Tensorboard.'))
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="auto",
-        choices=DEVICE_OPTIONS,
-        help='device type for Aphrodite execution, supporting CUDA, OpenVINO '
-        'and CPU.')
+    parser.add_argument("--device",
+                        type=str,
+                        default="auto",
+                        choices=DEVICE_OPTIONS,
+                        help='device type for vLLM execution')
     parser.add_argument('--block-size',
     parser.add_argument('--block-size',
                         type=int,
                         type=int,
                         default=16,
                         default=16,

+ 2 - 1
tests/models/decoder_only/audio_language/test_ultravox.py

@@ -164,7 +164,8 @@ def run_multi_audio_test(
 def test_models(hf_runner, aphrodite_runner, audio, dtype: str, max_tokens: int,
 def test_models(hf_runner, aphrodite_runner, audio, dtype: str, max_tokens: int,
                 num_logprobs: int) -> None:
                 num_logprobs: int) -> None:
 
 
-    aphrodite_prompt = _get_prompt(1, "Describe the audio above.", APHRODITE_PLACEHOLDER)
+    aphrodite_prompt = _get_prompt(
+        1, "Describe the audio above.", APHRODITE_PLACEHOLDER)
     hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
     hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
     run_test(
     run_test(
         hf_runner,
         hf_runner,

+ 2 - 1
tests/models/decoder_only/language/test_aqlm.py

@@ -59,7 +59,8 @@ def test_models(
 
 
     # loop through the prompts to compare against the ground truth generations
     # loop through the prompts to compare against the ground truth generations
     for prompt_idx in range(len(example_prompts)):
     for prompt_idx in range(len(example_prompts)):
-        aphrodite_output_ids, aphrodite_output_str, aphrodite_logprobs = aphrodite_outputs[
+        (aphrodite_output_ids, aphrodite_output_str,
+        aphrodite_logprobs) = aphrodite_outputs[
             prompt_idx]
             prompt_idx]
 
 
         print("Prompt:          ", repr(example_prompts[prompt_idx]))
         print("Prompt:          ", repr(example_prompts[prompt_idx]))

+ 6 - 3
tests/models/decoder_only/language/test_big_models.py

@@ -42,8 +42,10 @@ def test_models(
     with hf_runner(model, dtype=dtype) as hf_model:
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
 
-    with aphrodite_runner(model, dtype=dtype, enforce_eager=True) as aphrodite_model:
-        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+    with aphrodite_runner(model, dtype=dtype,
+                          enforce_eager=True) as aphrodite_model:
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
 
 
     check_outputs_equal(
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_0_lst=hf_outputs,
@@ -60,7 +62,8 @@ def test_model_print(
     model: str,
     model: str,
     dtype: str,
     dtype: str,
 ) -> None:
 ) -> None:
-    with aphrodite_runner(model, dtype=dtype, enforce_eager=True) as aphrodite_model:
+    with aphrodite_runner(
+        model, dtype=dtype, enforce_eager=True) as aphrodite_model:
         # This test is for verifying whether the model's extra_repr
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
         # can be printed correctly.
         print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
         print(aphrodite_model.model.llm_engine.model_executor.driver_worker.

+ 2 - 1
tests/models/decoder_only/language/test_danube3_4b.py

@@ -28,7 +28,8 @@ def test_models(
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
 
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
-        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
 
 
     check_outputs_equal(
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_0_lst=hf_outputs,

+ 2 - 1
tests/models/decoder_only/language/test_granite.py

@@ -1,4 +1,5 @@
-"""Compare the outputs of HF and Aphrodite for Granite models using greedy sampling.
+"""Compare the outputs of HF and Aphrodite for Granite models using greedy
+sampling.
 
 
 Run `pytest tests/models/test_granite.py`.
 Run `pytest tests/models/test_granite.py`.
 """
 """

+ 13 - 5
tests/models/decoder_only/language/test_jamba.py

@@ -26,15 +26,18 @@ def test_models(
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
 
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
-        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
 
 
     for i in range(len(example_prompts)):
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
         hf_output_ids, hf_output_str = hf_outputs[i]
         aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
         aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
         assert hf_output_str == aphrodite_output_str, (
         assert hf_output_str == aphrodite_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nAPHRODITE: {aphrodite_output_str!r}")
+            f"Test{i}:\nHF: {hf_output_str!r}\nAPHRODITE: "
+            f"{aphrodite_output_str!r}")
         assert hf_output_ids == aphrodite_output_ids, (
         assert hf_output_ids == aphrodite_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nAPHRODITE: {aphrodite_output_ids}")
+            f"Test{i}:\nHF: {hf_output_ids}\nAPHRODITE: "
+            f"{aphrodite_output_ids}")
 
 
 
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("model", MODELS)
@@ -113,7 +116,8 @@ def test_models_preemption_recompute(
 
 
         aphrodite_model.model.llm_engine.scheduler[
         aphrodite_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = False
             0].ENABLE_ARTIFICIAL_PREEMPT = False
-        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
 
 
     check_outputs_equal(
     check_outputs_equal(
         outputs_0_lst=preempt_aphrodite_outputs,
         outputs_0_lst=preempt_aphrodite_outputs,
@@ -138,7 +142,11 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
     # statelessness mechanism where it can cleanup new incoming requests in
     # statelessness mechanism where it can cleanup new incoming requests in
     # a single step.
     # a single step.
     try:
     try:
-        with aphrodite_runner(model, dtype=dtype, max_num_seqs=10) as aphrodite_model:
+        with aphrodite_runner(
+            model,
+            dtype=dtype,
+            max_num_seqs=10,
+        ) as aphrodite_model:
             aphrodite_model.generate_greedy([example_prompts[0]] * 100, 10)
             aphrodite_model.generate_greedy([example_prompts[0]] * 100, 10)
     except ValueError:
     except ValueError:
         pytest.fail("Jamba inner state wasn't cleaned up properly between"
         pytest.fail("Jamba inner state wasn't cleaned up properly between"

+ 2 - 1
tests/models/decoder_only/language/test_models.py

@@ -41,7 +41,8 @@ def test_models(
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
 
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
     with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
-        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+        aphrodite_outputs = aphrodite_model.generate_greedy(
+            example_prompts, max_tokens)
 
 
     check_outputs_equal(
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_0_lst=hf_outputs,

+ 6 - 2
tests/models/decoder_only/vision_language/test_blip2.py

@@ -3,8 +3,8 @@ from typing import List, Optional, Tuple
 import pytest
 import pytest
 from transformers import AutoModelForVision2Seq, AutoTokenizer
 from transformers import AutoModelForVision2Seq, AutoTokenizer
 
 
-from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.multimodal.utils import rescale_image_size
 
 
 from ....conftest import IMAGE_ASSETS
 from ....conftest import IMAGE_ASSETS
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
@@ -69,7 +69,11 @@ def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
 
     # max_model_len should be greater than image_feature_size
     # max_model_len should be greater than image_feature_size
-    with aphrodite_runner(model, dtype=dtype, enforce_eager=True) as aphrodite_model:
+    with aphrodite_runner(
+        model,
+        dtype=dtype,
+        enforce_eager=True,
+    ) as aphrodite_model:
         aphrodite_outputs_per_image = [
         aphrodite_outputs_per_image = [
             aphrodite_model.generate_greedy_logprobs(prompts,
             aphrodite_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 max_tokens,

+ 4 - 4
tests/models/decoder_only/vision_language/test_chameleon.py

@@ -3,10 +3,10 @@ from typing import List, Optional, Type
 import pytest
 import pytest
 from transformers import AutoModelForVision2Seq, BatchEncoding
 from transformers import AutoModelForVision2Seq, BatchEncoding
 
 
-from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
 from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
+from aphrodite.multimodal.utils import rescale_image_size
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, AphroditeRunner, _ImageAssets
+from ....conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
 from ...utils import check_outputs_equal
 from ...utils import check_outputs_equal
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
@@ -83,8 +83,8 @@ def run_test(
 
 
     for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
     for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                         aphrodite_outputs_per_image):
                                         aphrodite_outputs_per_image):
-        # HF Logprobs include image tokens, unlike Aphrodite, so we don't directly
-        # compare them
+        # HF Logprobs include image tokens, unlike Aphrodite, so we don't
+        # directly compare them
         check_outputs_equal(
         check_outputs_equal(
             outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
             outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
             outputs_1_lst=[outputs[:2] for outputs in aphrodite_outputs],
             outputs_1_lst=[outputs[:2] for outputs in aphrodite_outputs],

+ 4 - 3
tests/models/decoder_only/vision_language/test_fuyu.py

@@ -2,11 +2,11 @@ from typing import List, Optional, Tuple, Type
 
 
 import pytest
 import pytest
 
 
-from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.common.utils import is_cpu
 from aphrodite.common.utils import is_cpu
+from aphrodite.multimodal.utils import rescale_image_size
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, AphroditeRunner, _ImageAssets
+from ....conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
@@ -97,7 +97,8 @@ def run_test(
         check_logprobs_close(
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_0_lst=hf_outputs,
             outputs_1_lst=[
             outputs_1_lst=[
-                aphrodite_to_hf_output(aphrodite_output) for aphrodite_output in aphrodite_outputs
+                aphrodite_to_hf_output(
+                    aphrodite_output) for aphrodite_output in aphrodite_outputs
             ],
             ],
             name_0="hf",
             name_0="hf",
             name_1="aphrodite",
             name_1="aphrodite",

+ 3 - 3
tests/models/decoder_only/vision_language/test_internvl.py

@@ -6,11 +6,11 @@ import torch
 from PIL.Image import Image
 from PIL.Image import Image
 from transformers import AutoConfig
 from transformers import AutoConfig
 
 
+from aphrodite.common.utils import is_cpu
 from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.multimodal.utils import rescale_image_size
-from aphrodite.utils import is_cpu
 
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, AphroditeRunner,
-                          _ImageAssets)
+from ....conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
+                          PromptImageInput, _ImageAssets)
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({

+ 4 - 4
tests/models/decoder_only/vision_language/test_llava.py

@@ -4,12 +4,12 @@ import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                           BatchEncoding)
                           BatchEncoding)
 
 
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
 from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.multimodal.utils import rescale_image_size
-from aphrodite.sequence import SampleLogprobs
-from aphrodite.utils import STR_DTYPE_TO_TORCH_DTYPE
 
 
-from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, AphroditeRunner,
-                          _ImageAssets)
+from ....conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
+                          PromptImageInput, _ImageAssets)
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 _LIMIT_IMAGE_PER_PROMPT = 4
 _LIMIT_IMAGE_PER_PROMPT = 4

+ 2 - 2
tests/models/decoder_only/vision_language/test_llava_image_embeds.py

@@ -3,9 +3,9 @@ from typing import List, Optional, Tuple, Type
 import pytest
 import pytest
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 
-from aphrodite.sequence import SampleLogprobs
+from aphrodite.common.sequence import SampleLogprobs
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, AphroditeRunner, _ImageAssets
+from ....conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({

+ 2 - 2
tests/models/decoder_only/vision_language/test_llava_next.py

@@ -223,8 +223,8 @@ def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models_fixed_sizes(hf_runner, aphrodite_runner, image_assets, model, sizes,
-                            dtype, max_tokens, num_logprobs) -> None:
+def test_models_fixed_sizes(hf_runner, aphrodite_runner, image_assets, model,
+                            sizes, dtype, max_tokens, num_logprobs) -> None:
     run_test(
     run_test(
         hf_runner,
         hf_runner,
         aphrodite_runner,
         aphrodite_runner,

+ 5 - 5
tests/models/decoder_only/vision_language/test_llava_next_video.py

@@ -4,11 +4,11 @@ import pytest
 import transformers
 import transformers
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 
+from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.multimodal.utils import (rescale_video_size, resize_video,
 from aphrodite.multimodal.utils import (rescale_video_size, resize_video,
-                                   sample_frames_from_video)
-from aphrodite.sequence import SampleLogprobs
+                                        sample_frames_from_video)
 
 
-from ....conftest import VIDEO_ASSETS, HfRunner, AphroditeRunner, _VideoAssets
+from ....conftest import VIDEO_ASSETS, AphroditeRunner, HfRunner, _VideoAssets
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 _PREFACE = (
 _PREFACE = (
@@ -217,8 +217,8 @@ def test_models(hf_runner, aphrodite_runner, video_assets, model, size_factors,
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("num_frames", [16])
 @pytest.mark.parametrize("num_frames", [16])
-def test_models_fixed_sizes(hf_runner, aphrodite_runner, video_assets, model, sizes,
-                            dtype, max_tokens, num_logprobs,
+def test_models_fixed_sizes(hf_runner, aphrodite_runner, video_assets, model,
+                            sizes, dtype, max_tokens, num_logprobs,
                             num_frames) -> None:
                             num_frames) -> None:
     run_test(
     run_test(
         hf_runner,
         hf_runner,

+ 2 - 2
tests/models/decoder_only/vision_language/test_minicpmv.py

@@ -6,10 +6,10 @@ import torch.types
 from PIL import Image
 from PIL import Image
 from transformers import BatchEncoding
 from transformers import BatchEncoding
 
 
+from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.multimodal.utils import rescale_image_size
-from aphrodite.sequence import SampleLogprobs
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, AphroditeRunner
+from ....conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 # The image token is placed before "user" on purpose so that the test can pass
 # The image token is placed before "user" on purpose so that the test can pass

+ 4 - 4
tests/models/decoder_only/vision_language/test_paligemma.py

@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple, Type
 import pytest
 import pytest
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import is_hip
 from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.multimodal.utils import rescale_image_size
-from aphrodite.sequence import SampleLogprobs
-from aphrodite.utils import is_hip
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, AphroditeRunner, _ImageAssets
+from ....conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
@@ -24,7 +24,7 @@ models = ["google/paligemma-3b-mix-224"]
 # excessive use of shared memory. Use other backends in the meantime.
 # excessive use of shared memory. Use other backends in the meantime.
 # FIXME (mattwong, gshtrasb, hongxiayan)
 # FIXME (mattwong, gshtrasb, hongxiayan)
 if is_hip():
 if is_hip():
-    os.environ["Aphrodite_USE_TRITON_FLASH_ATTN"] = "0"
+    os.environ["APHRODITE_USE_TRITON_FLASH_ATTN"] = "0"
 
 
 
 
 def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,
 def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,

+ 4 - 3
tests/models/decoder_only/vision_language/test_phi3v.py

@@ -5,11 +5,12 @@ from typing import List, Optional, Tuple, Type
 import pytest
 import pytest
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 
 
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import is_cpu, is_hip
 from aphrodite.multimodal.utils import rescale_image_size
 from aphrodite.multimodal.utils import rescale_image_size
-from aphrodite.sequence import SampleLogprobs
-from aphrodite.utils import is_cpu, is_hip
 
 
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, AphroditeRunner
+from ....conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
+                          PromptImageInput)
 from ...utils import check_logprobs_close
 from ...utils import check_logprobs_close
 
 
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({

+ 3 - 2
tests/models/decoder_only/vision_language/test_pixtral.py

@@ -1,6 +1,7 @@
-"""Compare the outputs of HF and Aphrodite for Mistral models using greedy sampling.
+"""Compare the outputs of HF and Aphrodite for Mistral models using greedy
+sampling.
 
 
-Run `pytest tests/models/test_mistral.py`.
+Run `pytest tests/models/decoder_only/vision_language/test_pixtral.py`.
 """
 """
 import json
 import json
 import uuid
 import uuid

+ 3 - 2
tests/models/embedding/language/test_embedding.py

@@ -1,6 +1,7 @@
-"""Compare the outputs of HF and Aphrodite for Mistral models using greedy sampling.
+"""Compare the outputs of HF and Aphrodite for Mistral models using greedy
+sampling.
 
 
-Run `pytest tests/models/test_llama_embedding.py`.
+Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 """
 import pytest
 import pytest
 import torch
 import torch

+ 28 - 23
tests/models/encoder_decoder/language/test_bart.py

@@ -1,4 +1,5 @@
-"""Compare the outputs of HF and Aphrodite for BART models using greedy sampling.
+"""Compare the outputs of HF and Aphrodite for BART models using greedy
+sampling.
 
 
 Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
 Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
 """
 """
@@ -50,8 +51,8 @@ if not is_cpu():
         distributed_executor_backend: Optional[str] = None,
         distributed_executor_backend: Optional[str] = None,
     ) -> None:
     ) -> None:
         '''
         '''
-        Test the Aphrodite BART model for a variety of encoder/decoder input prompts,
-        by validating it against HuggingFace (HF) BART.
+        Test the Aphrodite BART model for a variety of encoder/decoder input
+        prompts, by validating it against HuggingFace (HF) BART.
 
 
         Arguments:
         Arguments:
 
 
@@ -87,20 +88,21 @@ if not is_cpu():
           then (4) after computing logits during prefill, override the model
           then (4) after computing logits during prefill, override the model
           logits & force <BOS> to be the first generated token.
           logits & force <BOS> to be the first generated token.
         
         
-        * Aphrodite will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
-          start-token to the beginning, yielding [<decoder-start-token><BOS>],
-          (3) pass these tokens to the model & proceed with generation.
+        * Aphrodite will (1) tokenize the None prompt as [<BOS>], (2) append
+          <decoder-start-token> to the beginning, yielding
+          [<decoder-start-token><BOS>], (3) pass these tokens to the model &
+          proceed with generation.
         
         
-        The net effect is that compared to Aphrodite, the list of HF *decoded* tokens
-        will contain one more initial <BOS> than the Aphrodite generated tokens,
-        because Aphrodite's <BOS> token is injected into the prompt rather than into
-        the generated output. This is in spite of the fact that overall, the
-        complete sequences (prompt + decoded tokens) produced by Aphrodite will match
-        HF.
+        The net effect is that compared to Aphrodite, the list of HF *decoded*
+        tokens will contain one more initial <BOS> than the Aphrodite generated
+        tokens, because Aphrodite's <BOS> token is injected into the prompt
+        rather than into the generated output. This is in spite of the fact
+        that overall, the complete sequences (prompt + decoded tokens) produced
+        by Aphrodite will match HF.
         
-        So when we use HF decoded token output to validate Aphrodite's decoded token
-        output, the testing process must account for the difference in decoded
-        token sequences between Aphrodite and HF specifically in the
+        So when we use HF decoded token output to validate Aphrodite's decoded
+        token output, the testing process must account for the difference in
+        decoded token sequences between Aphrodite and HF specifically in the
         decoder-prompt-is-None case. 
 
         One option is to disable the logit processor feature that forces the
@@ -126,19 +128,21 @@ if not is_cpu():
         # for encoder/decoder models Aphrodite will
         # default to enforce_eager=True if enforce_eager
         # is left unspecified. However, the
-        # AphroditeRunner test fixture (which wraps around the LLM class) defaults to
-        # enforce_eager=False (a behavior which a number of already-exisitng
-        # decoder-only unit tests expect), so when testing an encoder/decoder
-        # model we must explicitly specify enforce_eager=True in the AphroditeRunner
-        # constructor.
+        # AphroditeRunner test fixture (which wraps around the LLM class)
+        # defaults to enforce_eager=False (a behavior which a number of
+        # already-exisitng decoder-only unit tests expect), so when testing
+        # an encoder/decoder model we must explicitly specify enforce_eager=True
+        # in the AphroditeRunner constructor.
         with aphrodite_runner(
                 model,
                 dtype=dtype,
                 tensor_parallel_size=tensor_parallel_size,
                 distributed_executor_backend=distributed_executor_backend,
                 enforce_eager=True) as aphrodite_model:
-            aphrodite_outputs = aphrodite_model.generate_encoder_decoder_greedy_logprobs(
-                prompts, max_tokens, num_logprobs)
+            aphrodite_outputs = (
+                aphrodite_model.generate_encoder_decoder_greedy_logprobs(
+                    prompts, max_tokens, num_logprobs)
+            )
 
 
         # Configuration settings for HF baseline
         hf_kwargs = {
@@ -181,7 +185,8 @@ if not is_cpu():
     @pytest.mark.parametrize("max_tokens", [64])
     @pytest.mark.parametrize("num_logprobs", [5])
     @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
-    def test_models(hf_runner, aphrodite_runner, example_encoder_decoder_prompts,
+    def test_models(hf_runner, aphrodite_runner,
+                    example_encoder_decoder_prompts,
                     model, dtype, max_tokens, num_logprobs,
                     decoder_prompt_type) -> None:
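
To make the <BOS> bookkeeping described in the docstring above concrete: when the decoder prompt is None, HF's list of decoded tokens starts with one extra <BOS> that Aphrodite instead injects into the prompt, so a comparison only needs to drop HF's first decoded token in that case. A minimal sketch, assuming token-ID lists as produced by the two runners (the helper name is hypothetical and not part of this change):

    def align_hf_decoded_tokens(hf_output_ids, decoder_prompt_is_none):
        # In the None-decoder-prompt case HF emits <BOS> as its first decoded
        # token, while Aphrodite places <BOS> in the prompt, so drop it before
        # comparing the generated-token sequences.
        return hf_output_ids[1:] if decoder_prompt_is_none else hf_output_ids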
 
 

+ 0 - 0
tests/mq_aphrodite_engine/__init__.py


+ 68 - 0
tests/mq_aphrodite_engine/test_abort.py

@@ -0,0 +1,68 @@
+"""Test that aborting is handled properly."""
+
+import asyncio
+import tempfile
+import uuid
+
+import pytest
+
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine, generate
+
+MODEL = "google/gemma-1.1-2b-it"
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
+RAISED_ERROR = KeyError
+RAISED_VALUE = "foo"
+EXPECTED_TOKENS = 250
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest.mark.asyncio
+async def test_abort(tmp_socket):
+    with RemoteMQAphroditeEngine(
+        engine_args=ENGINE_ARGS,
+        ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        request_id_to_be_aborted = "request-aborted"
+        request_ids_a = [f"request-a-{idx}" for idx in range(10)]
+        request_ids_b = [f"request-b-{idx}" for idx in range(10)]
+
+        # Requests started before one to be aborted.
+        tasks = []
+        for request_id in request_ids_a:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, EXPECTED_TOKENS)))
+
+        # Aborted.
+        task_aborted = asyncio.create_task(
+            generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
+
+        # Requests started after one to be aborted.
+        for request_id in request_ids_b:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, EXPECTED_TOKENS)))
+
+        # Actually abort.
+        await asyncio.sleep(0.5)
+        await client.abort(request_id_to_be_aborted)
+
+        # Confirm that we got all the EXPECTED tokens from the requests.
+        for task in tasks:
+            count, request_id = await task
+            assert count == EXPECTED_TOKENS, (
+                f"{request_id} generated only {count} tokens")
+
+        # Cancel task (this will hang indefinitely if not).
+        task_aborted.cancel()
+
+        # Shutdown.
+        client.close()
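
The explicit task_aborted.cancel() above encodes the client-side contract for aborts: client.abort() stops the work inside the engine, but the local async generator for that request does not finish on its own, so the caller must also cancel its consumer task. A rough application-side sketch of that pattern, using only the client calls exercised in this test (the helper names and the timeout behaviour are hypothetical, not part of this change):

    import asyncio

    async def consume(client, request_id, prompt, sampling_params):
        final = None
        async for out in client.generate(prompt=prompt,
                                         sampling_params=sampling_params,
                                         request_id=request_id):
            final = out
        return final

    async def generate_with_timeout(client, request_id, prompt,
                                    sampling_params, timeout_s):
        try:
            # wait_for cancels the consumer coroutine when the timeout expires.
            return await asyncio.wait_for(
                consume(client, request_id, prompt, sampling_params), timeout_s)
        except asyncio.TimeoutError:
            # The engine keeps generating unless it is explicitly told to abort.
            await client.abort(request_id)
            raise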

+ 243 - 0
tests/mq_aphrodite_engine/test_error_handling.py

@@ -0,0 +1,243 @@
+"""Test that various errors are handled properly."""
+
+import asyncio
+import tempfile
+import time
+import uuid
+from unittest.mock import Mock
+
+import pytest
+
+from aphrodite import SamplingParams
+from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.endpoints.openai.api_server import build_engine_client
+from aphrodite.endpoints.openai.args import make_arg_parser
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.multiprocessing import MQEngineDeadError
+from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine
+from aphrodite.lora.request import LoRARequest
+from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine
+
+MODEL = "google/gemma-1.1-2b-it"
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
+RAISED_ERROR = KeyError
+RAISED_VALUE = "foo"
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQAphroditeEngine.from_engine_args(
+        engine_args=engine_args,
+        ipc_path=ipc_path)
+
+    # Raise error during first forward pass.
+    engine.engine.model_executor.execute_model = Mock(
+        side_effect=RAISED_ERROR(RAISED_VALUE))
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_evil_forward(tmp_socket):
+    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_forward) as engine:
+
+        client = await engine.make_client()
+
+        # Server should be healthy after initial probe.
+        await asyncio.sleep(2.0)
+        await client.check_health()
+
+        # Throws an error in first forward pass.
+        with pytest.raises(RAISED_ERROR):
+            async for _ in client.generate(prompt="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+        assert client.errored
+
+        # Engine is errored, should get ENGINE_DEAD_ERROR.
+        with pytest.raises(MQEngineDeadError):
+            async for _ in client.generate(prompt="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+        assert client.errored
+
+        await asyncio.sleep(1.0)
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+        assert client.errored
+
+        # Shutdown.
+        client.close()
+
+
+def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
+                                        ipc_path: str):
+    # Make engine.
+    engine = MQAphroditeEngine.from_engine_args(
+        engine_args=engine_args,
+        ipc_path=ipc_path)
+
+    # Raise error during health check.
+    engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_health_check(tmp_socket):
+    with RemoteMQAphroditeEngine(
+            engine_args=ENGINE_ARGS,
+            ipc_path=tmp_socket,
+            run_fn=run_with_evil_model_executor_health) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # Health probe should throw RAISED_ERROR.
+        await asyncio.sleep(15.)
+
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+        assert client.errored
+
+        # Generate call should throw ENGINE_DEAD_ERROR
+        with pytest.raises(MQEngineDeadError):
+            async for _ in client.generate(prompt="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+
+        client.close()
+
+
+def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQAphroditeEngine.from_engine_args(
+        engine_args=engine_args,
+        ipc_path=ipc_path)
+
+    # Raise error during abort call.
+    engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_abort(tmp_socket):
+    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_abort) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # First check health should work.
+        await client.check_health()
+
+        # Trigger an abort on the client side.
+        async def bad_abort_after_2s():
+            await asyncio.sleep(2.0)
+            await client.abort(request_id="foo")
+
+        # Trigger an abort in 2s from now.
+        abort_task = asyncio.create_task(bad_abort_after_2s())
+
+        # Exception in abort() will happen during this generation.
+        # This will kill the engine and should return ENGINE_DEAD_ERROR
+        # with reference to the original KeyError("foo")
+        with pytest.raises(MQEngineDeadError) as execinfo:
+            async for _ in client.generate(
+                    prompt="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=2000),
+                    request_id=uuid.uuid4()):
+                pass
+        assert "KeyError" in repr(execinfo.value)
+        assert client.errored
+
+        await abort_task
+
+        # This should raise the original error.
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_bad_request(tmp_socket):
+    with RemoteMQAphroditeEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        # Invalid request should fail, but not crash the server.
+        with pytest.raises(ValueError):
+            async for _ in client.generate(prompt="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id="abcd-1",
+                                           lora_request=LoRARequest(
+                                               "invalid-lora", 1,
+                                               "invalid-path")):
+                pass
+
+        # This request should be okay.
+        async for _ in client.generate(prompt="Hello my name is",
+                                       sampling_params=SamplingParams(),
+                                       request_id="abcd-2"):
+            pass
+
+        # Shutdown.
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_mp_crash_detection(monkeypatch):
+
+    parser = FlexibleArgumentParser(
+        description="Aphrodite's remote OpenAI server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args([])
+
+    # When AphroditeEngine is loaded, it will crash.
+    def mock_init():
+        raise ValueError
+
+    monkeypatch.setattr(AphroditeEngine, "__init__", mock_init)
+
+    start = time.perf_counter()
+    async with build_engine_client(args):
+        pass
+    end = time.perf_counter()
+
+    assert end - start < 60, (
+        "Expected Aphrodite to gracefully shutdown in <60s "
+        "if there is an error in the startup.")
+
+
+@pytest.mark.asyncio
+async def test_mp_cuda_init():
+    # it should not crash, when cuda is initialized
+    # in the API server process
+    import torch
+    torch.cuda.init()
+    parser = FlexibleArgumentParser(
+        description="Aphrodite's remote OpenAI server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args([])
+
+    async with build_engine_client(args):
+        pass

+ 58 - 0
tests/mq_aphrodite_engine/test_load.py

@@ -0,0 +1,58 @@
+"""Test that the MQAphroditeEngine is able to handle 10k concurrent requests."""
+
+import asyncio
+import tempfile
+import uuid
+
+import pytest
+
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from tests.mq_aphrodite_engine.utils import RemoteMQAphroditeEngine, generate
+
+MODEL = "google/gemma-1.1-2b-it"
+NUM_EXPECTED_TOKENS = 10
+NUM_REQUESTS = 10000
+
+# Scenarios to test for num generated token.
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest.mark.asyncio
+async def test_load(tmp_socket):
+    with RemoteMQAphroditeEngine(
+        engine_args=ENGINE_ARGS,
+        ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests.
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, NUM_EXPECTED_TOKENS)))
+
+        # Confirm that we got all the EXPECTED tokens from the requests.
+        failed_request_id = None
+        tokens = None
+        for task in tasks:
+            num_generated_tokens, request_id = await task
+            if (num_generated_tokens != NUM_EXPECTED_TOKENS
+                    and failed_request_id is None):
+                failed_request_id = request_id
+                tokens = num_generated_tokens
+
+        assert failed_request_id is None, (
+            f"{failed_request_id} generated {tokens} but "
+            f"expected {NUM_EXPECTED_TOKENS}")
+
+        # Shutdown.
+        client.close()

+ 76 - 0
tests/mq_aphrodite_engine/utils.py

@@ -0,0 +1,76 @@
+import asyncio
+import multiprocessing
+from typing import Callable, Tuple, Union
+
+from aphrodite import SamplingParams
+from aphrodite.common.outputs import RequestOutput
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
+from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine
+
+
+async def generate(
+        client: MQAphroditeEngineClient,
+        request_id: str,
+        num_tokens: int,
+        return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
+
+    final_output = None
+    count = 0
+    async for out in client.generate(
+            request_id=request_id,
+            prompt="Hello my name is Robert and",
+            sampling_params=SamplingParams(max_tokens=num_tokens,
+                                           temperature=0)):
+
+        count += 1
+        final_output = out
+        await asyncio.sleep(0.)
+
+    if return_output:
+        return final_output
+
+    # Confirm we generated all the tokens we expected.
+    return count, request_id
+
+
+def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQAphroditeEngine.from_engine_args(
+        engine_args=engine_args,
+        ipc_path=ipc_path)
+
+    # Run engine.
+    engine.start()
+
+
+class RemoteMQAphroditeEngine:
+
+    def __init__(self,
+                 engine_args: AsyncEngineArgs,
+                 ipc_path: str,
+                 run_fn: Callable = run_normal) -> None:
+
+        self.engine_args = engine_args
+        self.ipc_path = ipc_path
+        context = multiprocessing.get_context("spawn")
+        self.proc = context.Process(target=run_fn,
+                                    args=(engine_args, ipc_path))
+        self.proc.start()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.kill()
+
+    async def make_client(self) -> MQAphroditeEngineClient:
+        engine_config = self.engine_args.create_engine_config()
+        client = MQAphroditeEngineClient(self.ipc_path, engine_config)
+        while True:
+            try:
+                await client.setup()
+                break
+            except TimeoutError:
+                assert self.proc.is_alive()
+        return client