# base_tokenizer_group.py (1.9 KB)
from abc import ABC, abstractmethod
from typing import List, Optional

from transformers import PreTrainedTokenizer

from aphrodite.common.config import TokenizerPoolConfig
from aphrodite.lora.request import LoRARequest
  6. class BaseTokenizerGroup(ABC):
  7. """A group of tokenizers that can be used for LoRA adapters."""
  8. @classmethod
  9. @abstractmethod
  10. def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
  11. **init_kwargs) -> "BaseTokenizerGroup":
  12. pass
  13. @abstractmethod
  14. def ping(self) -> bool:
  15. """Check if the tokenizer group is alive."""
  16. pass
  17. @abstractmethod
  18. def get_max_input_len(self,
  19. lora_request: Optional[LoRARequest] = None
  20. ) -> Optional[int]:
  21. """Get the maximum input length for the LoRA request."""
  22. pass
  23. @abstractmethod
  24. def encode(self,
  25. prompt: str,
  26. request_id: Optional[str] = None,
  27. lora_request: Optional[LoRARequest] = None) -> List[int]:
  28. """Encode a prompt using the tokenizer group."""
  29. pass
  30. @abstractmethod
  31. async def encode_async(
  32. self,
  33. prompt: str,
  34. request_id: Optional[str] = None,
  35. lora_request: Optional[LoRARequest] = None) -> List[int]:
  36. """Encode a prompt using the tokenizer group."""
  37. pass
  38. @abstractmethod
  39. def get_lora_tokenizer(
  40. self,
  41. lora_request: Optional[LoRARequest] = None
  42. ) -> "PreTrainedTokenizer":
  43. """Get a tokenizer for a LoRA request."""
  44. pass
  45. @abstractmethod
  46. async def get_lora_tokenizer_async(
  47. self,
  48. lora_request: Optional[LoRARequest] = None
  49. ) -> "PreTrainedTokenizer":
  50. """Get a tokenizer for a LoRA request."""
  51. pass
  52. def check_health(self):
  53. """Raise exception if the tokenizer group is unhealthy."""
  54. return