1
0

tokenizer.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import os
  2. import warnings
  3. from pathlib import Path
  4. from typing import Optional, Union
  5. import huggingface_hub
  6. from loguru import logger
  7. from transformers import (AutoTokenizer, PreTrainedTokenizer,
  8. PreTrainedTokenizerFast)
  9. from aphrodite.common.envs import APHRODITE_USE_MODELSCOPE
  10. from aphrodite.common.utils import make_async
  11. from aphrodite.lora.request import LoRARequest
  12. from aphrodite.transformers_utils.tokenizers import (BaichuanTokenizer,
  13. MistralTokenizer)
  14. from aphrodite.transformers_utils.utils import check_gguf_file
  15. AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
  16. MistralTokenizer]
  17. def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
  18. """Get tokenizer with cached properties.
  19. This will patch the tokenizer object in place.
  20. By default, transformers will recompute multiple tokenizer properties
  21. each time they are called, leading to a significant slowdown. This
  22. function caches these properties for faster access."""
  23. tokenizer_all_special_ids = set(tokenizer.all_special_ids)
  24. tokenizer_all_special_tokens_extended = (
  25. tokenizer.all_special_tokens_extended)
  26. tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
  27. tokenizer_len = len(tokenizer)
  28. class CachedTokenizer(tokenizer.__class__): # type: ignore
  29. @property
  30. def all_special_ids(self):
  31. return tokenizer_all_special_ids
  32. @property
  33. def all_special_tokens(self):
  34. return tokenizer_all_special_tokens
  35. @property
  36. def all_special_tokens_extended(self):
  37. return tokenizer_all_special_tokens_extended
  38. def __len__(self):
  39. return tokenizer_len
  40. CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
  41. tokenizer.__class__ = CachedTokenizer
  42. return tokenizer
  43. def get_tokenizer(
  44. tokenizer_name: Union[str, Path],
  45. *args,
  46. tokenizer_mode: str = "auto",
  47. trust_remote_code: bool = False,
  48. revision: Optional[str] = None,
  49. download_dir: Optional[str] = None,
  50. **kwargs,
  51. ) -> AnyTokenizer:
  52. """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
  53. """
  54. if APHRODITE_USE_MODELSCOPE:
  55. # download model from ModelScope hub,
  56. # lazy import so that modelscope is not required for normal use.
  57. # pylint: disable=C.
  58. from modelscope.hub.snapshot_download import snapshot_download
  59. # Only set the tokenizer here, model will be downloaded on the workers.
  60. if not os.path.exists(tokenizer_name):
  61. tokenizer_path = snapshot_download(
  62. model_id=tokenizer_name,
  63. cache_dir=download_dir,
  64. revision=revision,
  65. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  66. # Ignore weights - we only need the tokenizer.
  67. ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
  68. tokenizer_name = tokenizer_path
  69. if tokenizer_mode == "slow":
  70. if kwargs.get("use_fast", False):
  71. raise ValueError(
  72. "Cannot use the fast tokenizer in slow tokenizer mode.")
  73. kwargs["use_fast"] = False
  74. if "truncation_side" not in kwargs:
  75. kwargs["truncation_side"] = "left"
  76. # Separate model folder from file path for GGUF models
  77. is_gguf = check_gguf_file(tokenizer_name)
  78. if is_gguf:
  79. kwargs["gguf_file"] = Path(tokenizer_name).name
  80. tokenizer_name = Path(tokenizer_name).parent
  81. # if tokenizer is from official mistral org
  82. is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
  83. if is_from_mistral_org and tokenizer_mode != "mistral":
  84. warnings.warn(
  85. 'It is strongly recommended to run mistral models with '
  86. '`--tokenizer_mode "mistral"` to ensure correct '
  87. 'encoding and decoding.',
  88. FutureWarning,
  89. stacklevel=2)
  90. if tokenizer_mode == "mistral":
  91. tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
  92. revision=revision)
  93. else:
  94. try:
  95. tokenizer = AutoTokenizer.from_pretrained(
  96. tokenizer_name,
  97. *args,
  98. trust_remote_code=trust_remote_code,
  99. revision=revision,
  100. **kwargs,
  101. )
  102. except ValueError as e:
  103. # If the error pertains to the tokenizer class not existing or not
  104. # currently being imported,
  105. # suggest using the --trust-remote-code flag.
  106. if not trust_remote_code and (
  107. "does not exist or is not currently imported." in str(e)
  108. or "requires you to execute the tokenizer file" in str(e)):
  109. err_msg = ("Failed to load the tokenizer. If the tokenizer "
  110. "is a custom tokenizer not yet available in the "
  111. "HuggingFace transformers library, consider "
  112. "setting `trust_remote_code=True` in LLM or using "
  113. "the `--trust-remote-code` flag in the CLI.")
  114. raise RuntimeError(err_msg) from e
  115. else:
  116. raise e
  117. except AttributeError as e:
  118. if "BaichuanTokenizer" in str(e):
  119. # This is for the error "'BaichuanTokenizer' object has no
  120. # attribute 'sp_model'".
  121. tokenizer = BaichuanTokenizer.from_pretrained(
  122. tokenizer_name,
  123. *args,
  124. trust_remote_code=trust_remote_code,
  125. revision=revision,
  126. **kwargs,
  127. )
  128. else:
  129. raise e
  130. if not isinstance(tokenizer, PreTrainedTokenizerFast):
  131. logger.warning(
  132. "Using a slow tokenizer. This might cause a significant "
  133. "slowdown. Consider using a fast tokenizer instead.")
  134. tokenizer = get_cached_tokenizer(tokenizer)
  135. return tokenizer
  136. def get_lora_tokenizer(lora_request: LoRARequest, *args,
  137. **kwargs) -> Optional[AnyTokenizer]:
  138. if lora_request is None:
  139. return None
  140. try:
  141. tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
  142. except OSError as e:
  143. # No tokenizer was found in the LoRA folder,
  144. # use base model tokenizer
  145. logger.warning(f"No tokenizer found in {lora_request.lora_path}, "
  146. "using base model tokenizer instead. "
  147. f"(Exception: {str(e)})")
  148. tokenizer = None
  149. return tokenizer
  150. get_lora_tokenizer_async = make_async(get_lora_tokenizer)