from typing import List, Optional

from aphrodite.common.config import TokenizerPoolConfig
from aphrodite.common.utils import LRUCache
from aphrodite.lora.request import LoRARequest
from aphrodite.transformers_utils.tokenizer import (get_lora_tokenizer,
                                                    get_lora_tokenizer_async,
                                                    get_tokenizer)

from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
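
# A TokenizerGroup wraps one base tokenizer plus an LRU cache of per-adapter
# tokenizers, so requests that carry a LoRARequest can be tokenized with the
# adapter's own vocabulary when it provides one.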


class TokenizerGroup(BaseTokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        self.lora_tokenizers = LRUCache[AnyTokenizer](
            capacity=max_num_seqs if enable_lora else 0)
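    # The capacity choice above is presumably a worst-case bound: at most
    # max_num_seqs sequences run concurrently, so at most that many distinct
    # LoRA adapters can need their own tokenizer at the same time. With
    # enable_lora=False the zero-capacity cache makes every lookup a miss.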

    @classmethod
    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
                    **init_kwargs) -> "TokenizerGroup":
        # The pool config is accepted for interface compatibility but unused:
        # this group tokenizes in-process rather than in a worker pool.
        return cls(**init_kwargs)

    def ping(self) -> bool:
        """Check if the tokenizer group is alive."""
        return True

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request."""
        return self.max_input_length

    def _raise_if_input_too_long(self,
                                 encoded_tokens: List[int],
                                 lora_request: Optional[LoRARequest] = None):
        input_length = len(encoded_tokens)
        if lora_request:
            # A long-context LoRA adapter may raise the cap beyond the
            # group-wide default.
            max_input_length = (lora_request.long_lora_max_len
                                or self.max_input_length)
        else:
            max_input_length = self.max_input_length
        if max_input_length is not None and input_length > max_input_length:
            raise ValueError("Input too long.", input_length, max_input_length)
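    # The tuple-style ValueError above carries both the offending length and
    # the limit, so a caller can unpack them, e.g.:
    #
    #     try:
    #         tokenizer_group.encode(prompt, lora_request=req)
    #     except ValueError as e:
    #         _, input_len, max_len = e.args
    #
    # (illustrative only; tokenizer_group, prompt, and req are placeholder
    # names, not part of this module)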

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        tokenizer = self.get_lora_tokenizer(lora_request)
        ret = tokenizer.encode(prompt)
        self._raise_if_input_too_long(ret, lora_request)
        return ret

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        ret = tokenizer.encode(prompt)
        self._raise_if_input_too_long(ret, lora_request)
        return ret

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # get_lora_tokenizer may return None when the adapter ships no
            # tokenizer of its own; fall back to the base tokenizer.
            tokenizer = (get_lora_tokenizer(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers[lora_request.lora_int_id]

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # Same caching and fallback logic as the sync variant above.
            tokenizer = (await get_lora_tokenizer_async(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers[lora_request.lora_int_id]
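

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch of the non-LoRA path, assuming the
# Hugging Face model id below is a placeholder for any reachable tokenizer.
# With enable_lora=False the cache has capacity 0, so encode() always uses
# the base tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    group = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",  # placeholder model id
        enable_lora=False,
        max_num_seqs=8,
        max_input_length=32,
    )
    token_ids = group.encode("Hello, world!")
    print(token_ids)
    # A prompt tokenizing to more than max_input_length (32) tokens would
    # make encode() raise ValueError("Input too long.", <actual>, 32).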