__init__.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from typing import Optional, Type
  2. from aphrodite.common.config import TokenizerPoolConfig
  3. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
  4. BaseTokenizerGroup)
  5. from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
  6. TokenizerGroup)
  7. from aphrodite.executor.ray_utils import ray
  8. if ray:
  9. from aphrodite.transformers_utils.tokenizer_group.ray_tokenizer_group import ( # noqa: E501
  10. RayTokenizerGroupPool)
  11. else:
  12. RayTokenizerGroupPool = None
  13. def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
  14. **init_kwargs) -> BaseTokenizerGroup:
  15. tokenizer_cls: Type[BaseTokenizerGroup]
  16. if tokenizer_pool_config is None:
  17. tokenizer_cls = TokenizerGroup
  18. elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
  19. tokenizer_pool_config.pool_type, BaseTokenizerGroup):
  20. tokenizer_cls = tokenizer_pool_config.pool_type
  21. elif tokenizer_pool_config.pool_type == "ray":
  22. if RayTokenizerGroupPool is None:
  23. raise ImportError(
  24. "RayTokenizerGroupPool is not available. Please install "
  25. "the ray package to use the Ray tokenizer group pool.")
  26. tokenizer_cls = RayTokenizerGroupPool
  27. else:
  28. raise ValueError(
  29. f"Unknown pool type: {tokenizer_pool_config.pool_type}")
  30. return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
  31. __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]