base_tokenizer_group.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from abc import ABC, abstractmethod
  2. from typing import List, Optional, Union
  3. from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
  4. from aphrodite.common.config import TokenizerPoolConfig
  5. from aphrodite.lora.request import LoRARequest
  6. AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
  7. class BaseTokenizerGroup(ABC):
  8. """A group of tokenizers that can be used for LoRA adapters."""
  9. @classmethod
  10. @abstractmethod
  11. def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
  12. **init_kwargs) -> "BaseTokenizerGroup":
  13. pass
  14. @abstractmethod
  15. def ping(self) -> bool:
  16. """Check if the tokenizer group is alive."""
  17. pass
  18. @abstractmethod
  19. def get_max_input_len(self,
  20. lora_request: Optional[LoRARequest] = None
  21. ) -> Optional[int]:
  22. """Get the maximum input length for the LoRA request."""
  23. pass
  24. @abstractmethod
  25. def encode(self,
  26. prompt: str,
  27. request_id: Optional[str] = None,
  28. lora_request: Optional[LoRARequest] = None) -> List[int]:
  29. """Encode a prompt using the tokenizer group."""
  30. pass
  31. @abstractmethod
  32. async def encode_async(
  33. self,
  34. prompt: str,
  35. request_id: Optional[str] = None,
  36. lora_request: Optional[LoRARequest] = None) -> List[int]:
  37. """Encode a prompt using the tokenizer group."""
  38. pass
  39. @abstractmethod
  40. def get_lora_tokenizer(
  41. self,
  42. lora_request: Optional[LoRARequest] = None,
  43. ) -> AnyTokenizer:
  44. """Get a tokenizer for a LoRA request."""
  45. pass
  46. @abstractmethod
  47. async def get_lora_tokenizer_async(
  48. self,
  49. lora_request: Optional[LoRARequest] = None,
  50. ) -> AnyTokenizer:
  51. """Get a tokenizer for a LoRA request."""
  52. pass
  53. def check_health(self):
  54. """Raise exception if the tokenizer group is unhealthy."""
  55. return