1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- from typing import Optional, Type
- from aphrodite.common.config import (ModelConfig, ParallelConfig,
- SchedulerConfig, TokenizerPoolConfig)
- from aphrodite.executor.ray_utils import ray
- from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
- from .tokenizer_group import TokenizerGroup
- if ray:
- from aphrodite.transformers_utils.tokenizer_group.ray_tokenizer_group import ( # noqa E501
- RayTokenizerGroupPool)
- else:
- RayTokenizerGroupPool = None # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
                                scheduler_config: SchedulerConfig,
                                parallel_config: ParallelConfig,
                                enable_lora: bool):
    """Create a tokenizer group from the engine configuration objects.

    Tokenizer settings are gathered from ``model_config`` and
    ``scheduler_config`` and forwarded as keyword arguments to
    ``get_tokenizer_group``, together with
    ``parallel_config.tokenizer_pool_config``, which decides the concrete
    tokenizer group class.

    Args:
        model_config: Source of ``tokenizer`` (passed as ``tokenizer_id``),
            ``tokenizer_mode``, ``trust_remote_code`` and
            ``tokenizer_revision`` (passed as ``revision``).
        scheduler_config: Source of ``max_num_seqs``.
        parallel_config: Source of ``tokenizer_pool_config`` (may be None).
        enable_lora: Forwarded unchanged as ``enable_lora``.

    Returns:
        The tokenizer group produced by ``get_tokenizer_group``.
    """
    tokenizer_kwargs = {
        "tokenizer_id": model_config.tokenizer,
        "enable_lora": enable_lora,
        "max_num_seqs": scheduler_config.max_num_seqs,
        # Always passed explicitly as None from here; presumably this means
        # "no input-length limit" -- confirm in the tokenizer group classes.
        "max_input_length": None,
        "tokenizer_mode": model_config.tokenizer_mode,
        "trust_remote_code": model_config.trust_remote_code,
        "revision": model_config.tokenizer_revision,
    }
    pool_config = parallel_config.tokenizer_pool_config
    return get_tokenizer_group(pool_config, **tokenizer_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                        **init_kwargs) -> BaseTokenizerGroup:
    """Build the tokenizer group selected by ``tokenizer_pool_config``.

    The implementation class is chosen by the first matching rule:

    * ``tokenizer_pool_config`` is None -> ``TokenizerGroup``.
    * ``pool_type`` is a class deriving from ``BaseTokenizerGroup`` ->
      that class is used as-is.
    * ``pool_type == "ray"`` -> ``RayTokenizerGroupPool`` (needs ``ray``).

    The chosen class is then instantiated through its ``from_config``
    classmethod, receiving ``tokenizer_pool_config`` and ``init_kwargs``.

    Args:
        tokenizer_pool_config: Pool configuration, or None for the default
            (non-pooled) tokenizer group.
        **init_kwargs: Passed through unchanged to ``from_config``.

    Returns:
        The constructed tokenizer group.

    Raises:
        ImportError: ``pool_type`` is ``"ray"`` but the Ray pool class could
            not be imported.
        ValueError: ``pool_type`` matches none of the rules above.
    """
    chosen_cls: Type[BaseTokenizerGroup]
    if tokenizer_pool_config is None:
        chosen_cls = TokenizerGroup
    else:
        pool_type = tokenizer_pool_config.pool_type
        if (isinstance(pool_type, type)
                and issubclass(pool_type, BaseTokenizerGroup)):
            # A user-supplied group class bypasses the string lookup.
            chosen_cls = pool_type
        elif pool_type != "ray":
            raise ValueError(f"Unknown pool type: {pool_type}")
        elif RayTokenizerGroupPool is None:
            # The ray import guard at module level left this as None.
            raise ImportError(
                "RayTokenizerGroupPool is not available. Please install "
                "the ray package to use the Ray tokenizer group pool.")
        else:
            chosen_cls = RayTokenizerGroupPool
    return chosen_cls.from_config(tokenizer_pool_config, **init_kwargs)
# Public API of this module. ``init_tokenizer_from_configs`` is a public
# factory defined here, so it is listed next to ``get_tokenizer_group`` to
# keep the declared API consistent with what the module actually offers.
__all__ = [
    "AnyTokenizer",
    "get_tokenizer_group",
    "init_tokenizer_from_configs",
    "BaseTokenizerGroup",
]
|