1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from abc import ABC, abstractmethod
- from typing import List, Optional, Union
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
- from aphrodite.common.config import TokenizerPoolConfig
- from aphrodite.lora.request import LoRARequest
# Type alias accepting either tokenizer class imported from `transformers`:
# `PreTrainedTokenizer` or `PreTrainedTokenizerFast`.
AnyTokenizer = Union[
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
]
class BaseTokenizerGroup(ABC):
    """Abstract base for a collection of tokenizers serving LoRA adapters.

    Subclasses must implement every ``@abstractmethod`` below: construction
    from a pool config, a liveness probe, length limits, synchronous and
    asynchronous prompt encoding, and synchronous and asynchronous tokenizer
    lookup. ``check_health`` has a no-op default that subclasses may
    override.
    """

    @classmethod
    @abstractmethod
    def from_config(
        cls,
        tokenizer_pool_config: Optional[TokenizerPoolConfig],
        **init_kwargs,
    ) -> "BaseTokenizerGroup":
        """Build a tokenizer group from an optional pool configuration.

        Args:
            tokenizer_pool_config: Pool settings, or ``None``.
            **init_kwargs: Additional keyword arguments; presumably handed
                to the concrete subclass constructor (confirm in the
                implementations).

        Returns:
            A new tokenizer group instance.
        """

    @abstractmethod
    def ping(self) -> bool:
        """Report whether the tokenizer group is alive.

        Returns:
            ``True`` if the group is alive.
        """

    @abstractmethod
    def get_max_input_len(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> Optional[int]:
        """Look up the maximum input length for a LoRA request.

        Args:
            lora_request: The LoRA request to query, or ``None``.

        Returns:
            The maximum input length, or ``None``. Presumably ``None``
            means no limit applies; confirm against the implementations.
        """

    @abstractmethod
    def encode(
        self,
        prompt: str,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[int]:
        """Encode ``prompt`` with the tokenizer group.

        Args:
            prompt: Text to encode.
            request_id: Optional identifier of the originating request.
            lora_request: Optional LoRA request selecting the tokenizer.

        Returns:
            The encoded prompt as a list of token ids.
        """

    @abstractmethod
    async def encode_async(
        self,
        prompt: str,
        request_id: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[int]:
        """Asynchronously encode ``prompt`` with the tokenizer group.

        This is the coroutine counterpart of :meth:`encode` and takes the
        same arguments.

        Returns:
            The encoded prompt as a list of token ids.
        """

    @abstractmethod
    def get_lora_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Fetch the tokenizer associated with a LoRA request.

        Args:
            lora_request: The LoRA request, or ``None``. The ``None`` case
                presumably selects the base tokenizer; verify in the
                implementations.

        Returns:
            The matching tokenizer.
        """

    @abstractmethod
    async def get_lora_tokenizer_async(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Asynchronously fetch the tokenizer for a LoRA request.

        This is the coroutine counterpart of :meth:`get_lora_tokenizer`.

        Returns:
            The matching tokenizer.
        """

    def check_health(self):
        """Raise an exception if the tokenizer group is unhealthy.

        The base implementation performs no checks and returns ``None``.
        Subclasses override it to add real health checks.
        """
        return None
|