from abc import ABC, abstractmethod from typing import List, Optional, Union from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from aphrodite.common.config import TokenizerPoolConfig from aphrodite.lora.request import LoRARequest AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] class BaseTokenizerGroup(ABC): """A group of tokenizers that can be used for LoRA adapters.""" @classmethod @abstractmethod def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> "BaseTokenizerGroup": pass @abstractmethod def ping(self) -> bool: """Check if the tokenizer group is alive.""" pass @abstractmethod def get_max_input_len(self, lora_request: Optional[LoRARequest] = None ) -> Optional[int]: """Get the maximum input length for the LoRA request.""" pass @abstractmethod def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @abstractmethod async def encode_async( self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: """Encode a prompt using the tokenizer group.""" pass @abstractmethod def get_lora_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: """Get a tokenizer for a LoRA request.""" pass @abstractmethod async def get_lora_tokenizer_async( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: """Get a tokenizer for a LoRA request.""" pass def check_health(self): """Raise exception if the tokenizer group is unhealthy.""" return