123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- from abc import ABC, abstractmethod
- from typing import List, Optional
- from transformers import PreTrainedTokenizer
- from aphrodite.common.config import TokenizerPoolConfig
- from aphrodite.lora.request import LoRARequest
class BaseTokenizerGroup(ABC):
    """Abstract interface for a group of tokenizers usable with LoRA adapters.

    Every method below except ``check_health`` is abstract and must be
    provided by a concrete subclass.
    """

    @classmethod
    @abstractmethod
    def from_config(
        cls,
        tokenizer_pool_config: Optional[TokenizerPoolConfig],
        **init_kwargs
    ) -> "BaseTokenizerGroup":
        """Build a tokenizer group from an optional pool config.

        Args:
            tokenizer_pool_config: Pool configuration, or ``None``.
            **init_kwargs: Additional keyword arguments for construction.

        Returns:
            A tokenizer group instance.
        """

    @abstractmethod
    def ping(self) -> bool:
        """Report whether the tokenizer group is alive."""

    @abstractmethod
    def get_max_input_len(
        self,
        lora_request: Optional[LoRARequest] = None
    ) -> Optional[int]:
        """Return the maximum input length for the given LoRA request."""

    @abstractmethod
    def encode(
        self,
        prompt: str,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None
    ) -> List[int]:
        """Encode ``prompt`` with the tokenizer group.

        Args:
            prompt: Text to encode.
            request_id: Optional identifier of the originating request.
            lora_request: Optional LoRA request associated with the prompt.

        Returns:
            The encoded prompt as a list of ints.
        """

    @abstractmethod
    async def encode_async(
        self,
        prompt: str,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None
    ) -> List[int]:
        """Coroutine counterpart of ``encode``; takes the same arguments."""

    @abstractmethod
    def get_lora_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer to use for the given LoRA request."""

    @abstractmethod
    async def get_lora_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Coroutine counterpart of ``get_lora_tokenizer``."""

    def check_health(self):
        """Raise an exception if the tokenizer group is unhealthy.

        The base implementation performs no check and returns ``None``.
        """
        return None
|