# tokenizer.py

from typing import List, Tuple, Union, Optional

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

from aphrodite.common.logger import init_logger
from aphrodite.lora.request import LoRARequest
from aphrodite.common.utils import make_async, LRUCache

logger = init_logger(__name__)

# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
            and tokenizer_name != _FAST_LLAMA_TOKENIZER):
        logger.info(
            "For some LLaMA V1 models, initializing the fast tokenizer may "
            "take a long time. To reduce the initialization time, consider "
            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer.")
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return tokenizer
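

# Illustrative usage sketch (a minimal example, not part of the original
# module). It shows how `tokenizer_mode` maps onto HuggingFace's `use_fast`
# flag; the checkpoint reuses the fast LLaMA tokenizer defined above.
def _example_get_tokenizer() -> None:
    # "auto" (the default) prefers the fast, Rust-backed tokenizer when one
    # is available for the checkpoint.
    fast = get_tokenizer(_FAST_LLAMA_TOKENIZER)
    # "slow" forces use_fast=False and loads the pure-Python tokenizer,
    # which also triggers the slow-tokenizer warning emitted above.
    slow = get_tokenizer(_FAST_LLAMA_TOKENIZER, tokenizer_mode="slow")
    logger.info(f"fast: {type(fast).__name__}, slow: {type(slow).__name__}")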


def get_lora_tokenizer(lora_request: LoRARequest, *args,
                       **kwargs) -> Optional[PreTrainedTokenizer]:
    """Gets the tokenizer shipped with a LoRA adapter, if any.

    Returns None when no LoRA request is given, or when the adapter
    directory does not contain a tokenizer, so callers fall back to the
    base model tokenizer.
    """
    if lora_request is None:
        return None
    try:
        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
                                  **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            f"No tokenizer found in {lora_request.lora_local_path}, "
            "using base model tokenizer instead. "
            f"(Exception: {str(e)})")
        tokenizer = None
    return tokenizer


get_lora_tokenizer_async = make_async(get_lora_tokenizer)


class TokenizerGroup:
    """A group of tokenizers that can be used for LoRA adapters."""

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        if enable_lora:
            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
        else:
            self.lora_tokenizers = None

    def encode(
        self,
        prompt: str,
        request_id: Optional[str] = None,  # pylint: disable=unused-argument
        lora_request: Optional[LoRARequest] = None
    ) -> List[int]:
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(
        self,
        prompt: str,
        request_id: Optional[str] = None,  # pylint: disable=unused-argument
        lora_request: Optional[LoRARequest] = None
    ) -> List[int]:
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            tokenizer = (get_lora_tokenizer(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers.get(lora_request.lora_int_id)

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id not in self.lora_tokenizers:
            tokenizer = (await get_lora_tokenizer_async(
                lora_request, **self.tokenizer_config) or self.tokenizer)
            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
            return tokenizer
        else:
            return self.lora_tokenizers.get(lora_request.lora_int_id)
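

# Illustrative usage sketch (a minimal example, not part of the original
# module). A TokenizerGroup wraps one base tokenizer; with enable_lora=True
# it also keeps an LRU cache of at most `max_num_seqs` per-adapter
# tokenizers, keyed by LoRARequest.lora_int_id.
def _example_tokenizer_group() -> List[int]:
    group = TokenizerGroup(tokenizer_id=_FAST_LLAMA_TOKENIZER,
                           enable_lora=False,
                           max_num_seqs=8,
                           max_input_length=None)
    # With lora_request=None (or enable_lora=False) the base tokenizer is
    # used; otherwise the adapter tokenizer is loaded and cached.
    return group.encode("Hello, world!")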


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE: The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in tokenizer.get_added_vocab():
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    else:
        return "".join(sub_texts)


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    if prev_tokens is None:
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids, skip_special_tokens=skip_special_tokens)
        output_tokens = new_tokens
        # 5 is an arbitrary value that should work for all
        # tokenizers (bigger = more conservative).
        # Subtract 1 extra to account for the generated token.
        prefix_offset = max(len(output_tokens) - 6, 0)
        # If the first new token is a special token, we can't skip 1 extra token
        if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
            read_offset = max(len(output_tokens), 0)
        else:
            read_offset = max(len(output_tokens) - 1, 0)
    else:
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        output_tokens = prev_tokens + new_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens)
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens)

    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # A trailing "�" indicates a potentially unfinished byte sequence
        # from byte fallback tokenization.
        # If it appears in the middle, it's probably a real invalid id
        # generated by the model.
        new_text = new_text[len(prefix_text):]
        return new_tokens, new_text, read_offset, len(output_tokens)
    else:
        return new_tokens, "", prefix_offset, read_offset
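

# Illustrative usage sketch (a minimal example, not part of the original
# module). It streams text for a growing list of token ids by threading the
# (prev_tokens, prefix_offset, read_offset) state returned by
# `detokenize_incrementally` back into the next call.
def _example_streaming_detokenization(
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        token_ids: List[int]) -> str:
    text = ""
    all_ids: List[int] = []
    prev_tokens: Optional[List[str]] = None
    prefix_offset = 0
    read_offset = 0
    for token_id in token_ids:
        all_ids.append(token_id)
        new_tokens, new_text, prefix_offset, read_offset = (
            detokenize_incrementally(tokenizer,
                                     all_ids,
                                     prev_tokens,
                                     prefix_offset=prefix_offset,
                                     read_offset=read_offset))
        # Accumulate tokens so later calls only convert the newest id.
        prev_tokens = (new_tokens if prev_tokens is None else
                       prev_tokens + new_tokens)
        # new_text is empty while a multi-byte character is still incomplete.
        text += new_text
    return text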