parse.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from typing import List, Literal, Sequence, TypedDict, Union, overload
  2. from typing_extensions import TypeIs
  3. from aphrodite.common.utils import is_list_of
  4. from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
  5. LLMInputs, PromptInputs)
  6. class ParsedText(TypedDict):
  7. content: str
  8. is_tokens: Literal[False]
  9. class ParsedTokens(TypedDict):
  10. content: List[int]
  11. is_tokens: Literal[True]
  12. @overload
  13. def parse_and_batch_prompt(
  14. prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
  15. ...
  16. @overload
  17. def parse_and_batch_prompt(
  18. prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
  19. ...
  20. def parse_and_batch_prompt(
  21. prompt: Union[str, List[str], List[int], List[List[int]]],
  22. ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
  23. if isinstance(prompt, str):
  24. # case 1: a string
  25. return [ParsedText(content=prompt, is_tokens=False)]
  26. if isinstance(prompt, list):
  27. if len(prompt) == 0:
  28. raise ValueError("please provide at least one prompt")
  29. if is_list_of(prompt, str):
  30. # case 2: array of strings
  31. return [
  32. ParsedText(content=elem, is_tokens=False) for elem in prompt
  33. ]
  34. if is_list_of(prompt, int):
  35. # case 3: array of tokens
  36. return [ParsedTokens(content=prompt, is_tokens=True)]
  37. if is_list_of(prompt, list):
  38. if len(prompt[0]) == 0:
  39. raise ValueError("please provide at least one prompt")
  40. if is_list_of(prompt[0], int):
  41. # case 4: array of token arrays
  42. return [
  43. ParsedTokens(content=elem, is_tokens=True)
  44. for elem in prompt
  45. ]
  46. raise ValueError("prompt must be a string, array of strings, "
  47. "array of tokens, or array of token arrays")
  48. def is_explicit_encoder_decoder_prompt(
  49. inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
  50. return isinstance(inputs, dict) and "encoder_prompt" in inputs
  51. def is_valid_encoder_decoder_llm_inputs(
  52. inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
  53. ) -> TypeIs[EncoderDecoderLLMInputs]:
  54. return "encoder_prompt_token_ids" in inputs