tokenizer_group.py
from typing import List, Optional

from aphrodite.common.config import TokenizerPoolConfig
from aphrodite.common.utils import LRUCache
from aphrodite.lora.request import LoRARequest
from aphrodite.transformers_utils.tokenizer import (get_lora_tokenizer,
                                                    get_lora_tokenizer_async,
                                                    get_tokenizer)

from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup


class TokenizerGroup(BaseTokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        # One tokenizer per LoRA adapter, evicted least-recently-used;
        # capacity 0 effectively disables the cache when LoRA is off.
        self.lora_tokenizers = LRUCache[AnyTokenizer](
            capacity=max_num_seqs if enable_lora else 0)

    @classmethod
    def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
                    **init_kwargs) -> "TokenizerGroup":
        # The pool config is accepted for interface compatibility but unused:
        # this group runs in-process and needs no worker pool.
        return cls(**init_kwargs)

    def ping(self) -> bool:
        """Check if the tokenizer group is alive."""
        return True

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request."""
        return self.max_input_length

    def _raise_if_input_too_long(self,
                                 encoded_tokens: List[int],
                                 lora_request: Optional[LoRARequest] = None):
        input_length = len(encoded_tokens)
        if lora_request:
            # A long-context LoRA adapter may override the group-wide limit.
            max_input_length = (lora_request.long_lora_max_len
                                or self.max_input_length)
        else:
            max_input_length = self.max_input_length
        if max_input_length is not None and input_length > max_input_length:
            raise ValueError("Input too long.", input_length, max_input_length)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        tokenizer = self.get_lora_tokenizer(lora_request)
        ret = tokenizer.encode(prompt)
        self._raise_if_input_too_long(ret, lora_request)
        return ret

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        ret = tokenizer.encode(prompt)
        self._raise_if_input_too_long(ret, lora_request)
        return ret

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            # Fall back to the base tokenizer if the adapter ships none.
            tokenizer = (get_lora_tokenizer(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers[lora_request.lora_int_id]

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            tokenizer = (await get_lora_tokenizer_async(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers[lora_request.lora_int_id]
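

# A minimal usage sketch of the class above. The model id "gpt2" and the
# numeric values are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    group = TokenizerGroup(
        tokenizer_id="gpt2",    # any Hugging Face tokenizer id
        enable_lora=False,      # base tokenizer only; the LRU capacity is 0
        max_num_seqs=8,         # bounds the per-adapter tokenizer cache
        max_input_length=2048,  # None would disable the length check
    )
    token_ids = group.encode("Hello, world!")
    assert group.get_max_input_len() == 2048
    print(token_ids)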