# inputs.py
  1. from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
  2. TypedDict, Union, cast, overload)
  3. from typing_extensions import NotRequired
  4. if TYPE_CHECKING:
  5. from aphrodite.common.sequence import MultiModalData
  6. class ParsedText(TypedDict):
  7. content: str
  8. is_tokens: Literal[False]
  9. class ParsedTokens(TypedDict):
  10. content: List[int]
  11. is_tokens: Literal[True]
  12. @overload
  13. def parse_and_batch_prompt(
  14. prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
  15. ...
  16. @overload
  17. def parse_and_batch_prompt(
  18. prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
  19. ...
  20. def parse_and_batch_prompt(
  21. prompt: Union[str, List[str], List[int], List[List[int]]],
  22. ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
  23. if isinstance(prompt, str):
  24. # case 1: a string
  25. return [ParsedText(content=prompt, is_tokens=False)]
  26. if isinstance(prompt, list):
  27. if len(prompt) == 0:
  28. raise ValueError("please provide at least one prompt")
  29. if isinstance(prompt[0], str):
  30. # case 2: array of strings
  31. return [
  32. ParsedText(content=elem, is_tokens=False)
  33. for elem in cast(List[str], prompt)
  34. ]
  35. if isinstance(prompt[0], int):
  36. # case 3: array of tokens
  37. elem = cast(List[int], prompt)
  38. return [ParsedTokens(content=elem, is_tokens=True)]
  39. if isinstance(prompt[0], list):
  40. if len(prompt[0]) == 0:
  41. raise ValueError("please provide at least one prompt")
  42. if isinstance(prompt[0][0], int):
  43. # case 4: array of token arrays
  44. return [
  45. ParsedTokens(content=elem, is_tokens=True)
  46. for elem in cast(List[List[int]], prompt)
  47. ]
  48. raise ValueError("prompt must be a string, array of strings, "
  49. "array of tokens, or array of token arrays")
  50. class TextPrompt(TypedDict):
  51. """Schema for a text prompt."""
  52. prompt: str
  53. """The input text to be tokenized before passing to the model."""
  54. multi_modal_data: NotRequired["MultiModalData"]
  55. """
  56. Optional multi-modal data to pass to the model,
  57. if the model supports it.
  58. """
  59. class TokensPrompt(TypedDict):
  60. """Schema for a tokenized prompt."""
  61. prompt_token_ids: List[int]
  62. """A list of token IDs to pass to the model."""
  63. multi_modal_data: NotRequired["MultiModalData"]
  64. """
  65. Optional multi-modal data to pass to the model,
  66. if the model supports it.
  67. """
  68. class TextTokensPrompt(TypedDict):
  69. """It is assumed that :attr:`prompt` is consistent with
  70. :attr:`prompt_token_ids`. This is currently used in
  71. :class:`AsyncLLMEngine` for logging both the text and token IDs."""
  72. prompt: str
  73. """The prompt text."""
  74. prompt_token_ids: List[int]
  75. """The token IDs of the prompt. If None, we use the
  76. tokenizer to convert the prompts to token IDs."""
  77. multi_modal_data: NotRequired["MultiModalData"]
  78. """
  79. Optional multi-modal data to pass to the model,
  80. if the model supports it.
  81. """
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""

PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
    """Token-level input schema: a required list of prompt token IDs,
    plus the optional original prompt text and optional multi-modal data.
    """
    # Token IDs of the prompt; always present.
    prompt_token_ids: List[int]
    # Original prompt text, when available (may be None or absent).
    prompt: NotRequired[Optional[str]]
    # Optional multi-modal payload. NOTE(review): consumer of this field is
    # not visible in this file -- presumably the model executor; verify.
    multi_modal_data: NotRequired[Optional["MultiModalData"]]