__init__.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from typing import Optional, Type
  2. from aphrodite.common.config import (ModelConfig, ParallelConfig,
  3. SchedulerConfig, TokenizerPoolConfig)
  4. from aphrodite.executor.ray_utils import ray
  5. from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
  6. from .tokenizer_group import TokenizerGroup
  7. if ray:
  8. from aphrodite.transformers_utils.tokenizer_group.ray_tokenizer_group import ( # noqa E501
  9. RayTokenizerGroupPool)
  10. else:
  11. RayTokenizerGroupPool = None # type: ignore
  12. def init_tokenizer_from_configs(model_config: ModelConfig,
  13. scheduler_config: SchedulerConfig,
  14. parallel_config: ParallelConfig,
  15. enable_lora: bool):
  16. init_kwargs = dict(tokenizer_id=model_config.tokenizer,
  17. enable_lora=enable_lora,
  18. max_num_seqs=scheduler_config.max_num_seqs,
  19. max_input_length=None,
  20. tokenizer_mode=model_config.tokenizer_mode,
  21. trust_remote_code=model_config.trust_remote_code,
  22. revision=model_config.tokenizer_revision)
  23. return get_tokenizer_group(parallel_config.tokenizer_pool_config,
  24. **init_kwargs)
  25. def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
  26. **init_kwargs) -> BaseTokenizerGroup:
  27. tokenizer_cls: Type[BaseTokenizerGroup]
  28. if tokenizer_pool_config is None:
  29. tokenizer_cls = TokenizerGroup
  30. elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
  31. tokenizer_pool_config.pool_type, BaseTokenizerGroup):
  32. tokenizer_cls = tokenizer_pool_config.pool_type
  33. elif tokenizer_pool_config.pool_type == "ray":
  34. if RayTokenizerGroupPool is None:
  35. raise ImportError(
  36. "RayTokenizerGroupPool is not available. Please install "
  37. "the ray package to use the Ray tokenizer group pool.")
  38. tokenizer_cls = RayTokenizerGroupPool
  39. else:
  40. raise ValueError(
  41. f"Unknown pool type: {tokenizer_pool_config.pool_type}")
  42. return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
  43. __all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]