base_tokenizer_group.py

from abc import ABC, abstractmethod
from typing import List, Optional

from transformers import PreTrainedTokenizer

from aphrodite.lora.request import LoRARequest


class BaseTokenizerGroup(ABC):
    """A group of tokenizers that can be used for LoRA adapters."""

    @abstractmethod
    def ping(self) -> bool:
        """Check if the tokenizer group is alive."""
        pass

    @abstractmethod
    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request."""
        pass

    @abstractmethod
    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group."""
        pass

    @abstractmethod
    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group."""
        pass

    @abstractmethod
    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Get a tokenizer for a LoRA request."""
        pass

    @abstractmethod
    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Get a tokenizer for a LoRA request."""
        pass
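

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original interface: a minimal concrete
# subclass that wraps a single HuggingFace tokenizer and ignores the LoRA
# request. The class name `SingleTokenizerGroup` and its constructor
# arguments are hypothetical; they only show how the abstract methods above
# are meant to fit together.
# ---------------------------------------------------------------------------
from transformers import AutoTokenizer


class SingleTokenizerGroup(BaseTokenizerGroup):
    """A tokenizer group backed by one shared tokenizer (sketch)."""

    def __init__(self,
                 tokenizer_name: str,
                 max_input_length: Optional[int] = None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_input_length = max_input_length

    def ping(self) -> bool:
        # An in-process tokenizer is always available.
        return True

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        # The same limit applies regardless of the LoRA request here.
        return self.max_input_length

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        return self.tokenizer.encode(prompt)

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        # Synchronous encoding is cheap enough to call directly in this sketch.
        return self.tokenizer.encode(prompt)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        # One shared tokenizer is returned for every LoRA request.
        return self.tokenizer

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        return self.tokenizer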