# tokenizer.py (7.9 KB)
  1. import os
  2. import tempfile
  3. from typing import Optional, Union
  4. from transformers import (AutoTokenizer, PreTrainedTokenizer,
  5. PreTrainedTokenizerFast, LlamaTokenizer)
  6. from transformers.convert_slow_tokenizer import import_protobuf
  7. from loguru import logger
  8. from aphrodite.lora.request import LoRARequest
  9. from aphrodite.common.utils import make_async
  10. from aphrodite.common.gguf import GGUFReader
  11. from aphrodite.transformers_utils.tokenizers import BaichuanTokenizer
  12. def convert_gguf_to_tokenizer(checkpoint):
  13. if os.path.isfile(checkpoint):
  14. result = GGUFReader(checkpoint)
  15. elif os.path.isdir(checkpoint):
  16. try:
  17. return AutoTokenizer.from_pretrained(checkpoint)
  18. except Exception:
  19. pass
  20. all_gguf_files = sorted([
  21. file for file in os.listdir(checkpoint)
  22. if os.path.splitext(file)[-1].lower() == ".gguf"
  23. ])
  24. # assume the tokenizer is always in the first shard
  25. result = GGUFReader(os.path.join(checkpoint, all_gguf_files[0]))
  26. else:
  27. raise RuntimeError(f"Cannot find any tokenizer with `{checkpoint}`")
  28. logger.log_once("INFO", "Converting tokenizer from GGUF...")
  29. # write vocab
  30. sentencepiece_model_pb2 = import_protobuf()
  31. vocab = sentencepiece_model_pb2.ModelProto()
  32. vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
  33. vocab.trainer_spec.model_type = 2 # BPE
  34. vocab.trainer_spec.vocab_size = vocab_size
  35. vocab.trainer_spec.byte_fallback = True
  36. vocab.normalizer_spec.remove_extra_whitespaces = False
  37. tokens = result.fields['tokenizer.ggml.tokens']
  38. scores = result.fields['tokenizer.ggml.scores']
  39. types = result.fields['tokenizer.ggml.token_type']
  40. for i in range(vocab_size):
  41. new_token = vocab.SentencePiece()
  42. new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
  43. encoding='utf-8')
  44. new_token.score = scores.parts[scores.data[i]]
  45. # llama.cpp tokentype is the same with sentencepiece token type
  46. new_token.type = int(types.parts[types.data[i]])
  47. vocab.pieces.append(new_token)
  48. with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file:
  49. temp_file.write(vocab.SerializeToString())
  50. temp_file_filename = temp_file.name
  51. tokenizer_args = {"vocab_file": temp_file_filename}
  52. if 'tokenizer.ggml.bos_token_id' in result.fields:
  53. tokenizer_args["bos_token"] = vocab.pieces[int(
  54. result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
  55. if 'tokenizer.ggml.eos_token_id' in result.fields:
  56. tokenizer_args["eos_token"] = vocab.pieces[int(
  57. result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
  58. if 'tokenizer.ggml.padding_token_id' in result.fields:
  59. tokenizer_args["pad_token"] = vocab.pieces[int(
  60. result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
  61. if 'tokenizer.ggml.unknown_token_id' in result.fields:
  62. tokenizer_args["unk_token"] = vocab.pieces[int(
  63. result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
  64. if 'tokenizer.ggml.add_bos_token' in result.fields:
  65. tokenizer_args["add_bos_token"] = bool(
  66. result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
  67. if 'tokenizer.ggml.add_eos_token' in result.fields:
  68. tokenizer_args["add_eos_token"] = bool(
  69. result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
  70. if 'tokenizer.chat_template' in result.fields:
  71. tokenizer_args["chat_template"] = str(
  72. bytes(result.fields['tokenizer.chat_template'].parts[-1]))
  73. tokenizer = LlamaTokenizer(**tokenizer_args)
  74. os.unlink(temp_file_filename)
  75. return tokenizer
  76. def get_cached_tokenizer(
  77. tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
  78. ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
  79. """Get tokenizer with cached properties.
  80. This will patch the tokenizer object in place.
  81. By default, transformers will recompute multiple tokenizer
  82. properties each time they are called, leading to a significant
  83. slowdown. This function caches these properties for faster
  84. access."""
  85. tokenizer_all_special_ids = set(tokenizer.all_special_ids)
  86. tokenizer_all_special_tokens_extended = (
  87. tokenizer.all_special_tokens_extended)
  88. tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
  89. tokenizer_len = len(tokenizer)
  90. class CachedTokenizer(tokenizer.__class__):
  91. @property
  92. def all_special_ids(self):
  93. return tokenizer_all_special_ids
  94. @property
  95. def all_special_tokens(self):
  96. return tokenizer_all_special_tokens
  97. @property
  98. def all_special_tokens_extended(self):
  99. return tokenizer_all_special_tokens_extended
  100. def __len__(self):
  101. return tokenizer_len
  102. CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
  103. tokenizer.__class__ = CachedTokenizer
  104. return tokenizer
  105. def get_tokenizer(
  106. tokenizer_name: str,
  107. *args,
  108. tokenizer_mode: str = "auto",
  109. trust_remote_code: bool = False,
  110. tokenizer_revision: Optional[str] = None,
  111. **kwargs,
  112. ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
  113. """Gets a tokenizer for the given model name via Huggingface."""
  114. if tokenizer_name.endswith("gguf"):
  115. return convert_gguf_to_tokenizer(tokenizer_name)
  116. if tokenizer_mode == "slow":
  117. if kwargs.get("use_fast", False):
  118. raise ValueError(
  119. "Cannot use the fast tokenizer in slow tokenizer mode.")
  120. kwargs["use_fast"] = False
  121. try:
  122. tokenizer = AutoTokenizer.from_pretrained(
  123. tokenizer_name,
  124. *args,
  125. trust_remote_code=trust_remote_code,
  126. tokenizer_revision=tokenizer_revision,
  127. **kwargs)
  128. except ValueError as e:
  129. # If the error pertains to the tokenizer class not existing or not
  130. # currently being imported, suggest using the --trust-remote-code flag.
  131. if (not trust_remote_code and
  132. ("does not exist or is not currently imported." in str(e)
  133. or "requires you to execute the tokenizer file" in str(e))):
  134. err_msg = (
  135. "Failed to load the tokenizer. If the tokenizer is a custom "
  136. "tokenizer not yet available in the HuggingFace transformers "
  137. "library, consider setting `trust_remote_code=True` in LLM "
  138. "or using the `--trust-remote-code` flag in the CLI.")
  139. raise RuntimeError(err_msg) from e
  140. else:
  141. raise e
  142. except AttributeError as e:
  143. if "BaichuanTokenizer" in str(e):
  144. # This is for the error "'BaichuanTokenizer' object has no
  145. # attribute 'sp_model'".
  146. tokenizer = BaichuanTokenizer.from_pretrained(
  147. tokenizer_name,
  148. *args,
  149. trust_remote_code=trust_remote_code,
  150. tokenizer_revision=tokenizer_revision,
  151. **kwargs)
  152. else:
  153. raise e
  154. if not isinstance(tokenizer, PreTrainedTokenizerFast):
  155. logger.warning(
  156. "Using a slow tokenizer. This might cause a significant "
  157. "slowdown. Consider using a fast tokenizer instead.")
  158. return get_cached_tokenizer(tokenizer)
  159. def get_lora_tokenizer(lora_request: LoRARequest, *args,
  160. **kwargs) -> Optional[PreTrainedTokenizer]:
  161. if lora_request is None:
  162. return None
  163. try:
  164. tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
  165. **kwargs)
  166. except OSError as e:
  167. # No tokenizer was found in the LoRA folder,
  168. # use base model tokenizer
  169. logger.warning(
  170. f"No tokenizer found in {lora_request.lora_local_path}, "
  171. "using base model tokenizer instead. "
  172. f"(Exception: {str(e)})")
  173. tokenizer = None
  174. return tokenizer
# Awaitable variant of get_lora_tokenizer, built with make_async.
# NOTE(review): presumably runs the blocking tokenizer load off the event
# loop; confirm against aphrodite.common.utils.make_async.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)