parse.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from typing import List, Literal, Sequence, TypedDict, Union, overload
  2. from typing_extensions import TypeIs
  3. from aphrodite.common.utils import is_list_of
  4. from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
  5. LLMInputs, PromptType, SingletonPrompt, TextPrompt,
  6. TokensPrompt)
  7. class ParsedText(TypedDict):
  8. content: str
  9. is_tokens: Literal[False]
  10. class ParsedTokens(TypedDict):
  11. content: List[int]
  12. is_tokens: Literal[True]
  13. @overload
  14. def parse_and_batch_prompt(
  15. prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
  16. ...
  17. @overload
  18. def parse_and_batch_prompt(
  19. prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
  20. ...
  21. def parse_and_batch_prompt(
  22. prompt: Union[str, List[str], List[int], List[List[int]]],
  23. ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
  24. if isinstance(prompt, str):
  25. # case 1: a string
  26. return [ParsedText(content=prompt, is_tokens=False)]
  27. if isinstance(prompt, list):
  28. if len(prompt) == 0:
  29. raise ValueError("please provide at least one prompt")
  30. if is_list_of(prompt, str):
  31. # case 2: array of strings
  32. return [
  33. ParsedText(content=elem, is_tokens=False) for elem in prompt
  34. ]
  35. if is_list_of(prompt, int):
  36. # case 3: array of tokens
  37. return [ParsedTokens(content=prompt, is_tokens=True)]
  38. if is_list_of(prompt, list):
  39. if len(prompt[0]) == 0:
  40. raise ValueError("please provide at least one prompt")
  41. if is_list_of(prompt[0], int):
  42. # case 4: array of token arrays
  43. return [
  44. ParsedTokens(content=elem, is_tokens=True)
  45. for elem in prompt
  46. ]
  47. raise TypeError("prompt must be a string, array of strings, "
  48. "array of tokens, or array of token arrays")
  49. class ParsedStrPrompt(TypedDict):
  50. type: Literal["str"]
  51. content: str
  52. class ParsedTextPrompt(TypedDict):
  53. type: Literal["text"]
  54. content: TextPrompt
  55. class ParsedTokensPrompt(TypedDict):
  56. type: Literal["tokens"]
  57. content: TokensPrompt
  58. def parse_singleton_prompt(
  59. prompt: SingletonPrompt,
  60. ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
  61. if isinstance(prompt, str):
  62. return ParsedStrPrompt(type="str", content=prompt)
  63. elif isinstance(prompt, dict):
  64. if "prompt_token_ids" in prompt:
  65. return ParsedTokensPrompt(type="tokens",
  66. content=prompt) # type: ignore
  67. elif "prompt" in prompt:
  68. return ParsedTextPrompt(type="text", content=prompt)
  69. raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
  70. def is_explicit_encoder_decoder_prompt(
  71. prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
  72. return isinstance(prompt, dict) and "encoder_prompt" in prompt
  73. def is_valid_encoder_decoder_llm_inputs(
  74. inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
  75. ) -> TypeIs[EncoderDecoderLLMInputs]:
  76. return "encoder_prompt_token_ids" in inputs