__init__.py 1.3 KB

123456789101112131415161718192021222324252627282930313233
  1. from typing import Optional
  2. from aphrodite.common.config import TokenizerPoolConfig
  3. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
  4. BaseTokenizerGroup)
  5. from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
  6. TokenizerGroup)
  7. from aphrodite.engine.ray_tools import ray
  8. if ray:
  9. from aphrodite.transformers_utils.tokenizer_group.ray_tokenizer_group import ( # noqa: E501
  10. RayTokenizerGroupPool)
  11. else:
  12. RayTokenizerGroupPool = None
  13. def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
  14. **init_kwargs) -> BaseTokenizerGroup:
  15. if tokenizer_pool_config is None:
  16. return TokenizerGroup(**init_kwargs)
  17. if tokenizer_pool_config.pool_type == "ray":
  18. if RayTokenizerGroupPool is None:
  19. raise ImportError(
  20. "RayTokenizerGroupPool is not available. Please install "
  21. "the ray package to use the Ray tokenizer group pool.")
  22. return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
  23. **init_kwargs)
  24. else:
  25. raise ValueError(f"Unknown tokenizer pool type: "
  26. f"{tokenizer_pool_config.pool_type}")
  27. __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]