detokenizer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. from typing import Dict, List, Optional, Tuple, Union
  2. from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
  3. from aphrodite.common.sequence import (Logprob, SamplingParams, Sequence,
  4. SequenceGroup)
  5. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
  6. BaseTokenizerGroup)
  # Placeholder token id, used e.g. for marking rejected tokens in speculative
  # decoding. Logprob entries carrying this id are never detokenized.
  INVALID_TOKEN_ID: int = -1
  class Detokenizer:
      """Provides methods to decode the output of a model into text."""

      def __init__(self, tokenizer_group: BaseTokenizerGroup):
          self.tokenizer_group = tokenizer_group

      def get_tokenizer_for_seq(self,
                                sequence: Sequence) -> "PreTrainedTokenizer":
          """Returns the HF tokenizer to use for a given sequence.

          The tokenizer is obtained from the tokenizer group using the
          sequence's ``lora_request``.
          """
          return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

      def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                         prompt_logprobs: List[Optional[Dict[
                                             int, Logprob]]],
                                         position_offset: int) -> None:
          """Decodes the logprobs for the prompt of a sequence group in place.

          For each prompt position, every candidate token whose
          ``decoded_token`` is still ``None`` (and whose id is not
          ``INVALID_TOKEN_ID``) is given the text it would produce at that
          position. Only the candidate equal to the actual prompt token
          advances the incremental detokenization state, so later positions
          are decoded against the real prompt.

          Args:
              seq_group: The sequence group to decode.
              prompt_logprobs: The logprobs to decode, one entry per prompt
                  position; ``None`` or empty entries are skipped.
              position_offset: Offset of the first index of the logprobs
                  relative to the start of the sequence (for chunked prefill).

          Returns:
              None. The ``Logprob`` objects inside ``prompt_logprobs`` are
              updated in place.
          """
          prms = seq_group.sampling_params
          # We can pick any sequence for the prompt.
          seq = seq_group.get_seqs()[0]
          # Only prompt, without the generated token.
          all_token_ids = seq.get_token_ids()
          prompt_token_ids = all_token_ids[:-1]
          tokenizer = self.get_tokenizer_for_seq(seq)
          prefix_offset = 0
          read_offset = 0
          next_iter_prefix_offset = 0
          next_iter_read_offset = 0
          next_iter_tokens: List[str] = []
          prev_tokens: Optional[List[str]] = None

          for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                  prompt_logprobs):
              # Absolute token position equals the index in the logprobs
              # list plus the offset of the entire logprobs list relative
              # to the start of the sequence.
              token_position = token_position_in_logprob + position_offset
              if not prompt_logprobs_for_token:
                  continue

              for token_id, sample_logprob in prompt_logprobs_for_token.items():
                  if (sample_logprob.decoded_token is None
                          and token_id != INVALID_TOKEN_ID):
                      prompt_token_ids_with_token = (
                          prompt_token_ids[:token_position] + [token_id])
                      (new_tokens, new_text, new_prefix_offset,
                       new_read_offset) = detokenize_incrementally(
                           tokenizer=tokenizer,
                           all_input_ids=prompt_token_ids_with_token,
                           prev_tokens=prev_tokens,
                           prefix_offset=prefix_offset,
                           read_offset=read_offset,
                           skip_special_tokens=prms.skip_special_tokens,
                           spaces_between_special_tokens=prms.
                           spaces_between_special_tokens,
                       )
                      sample_logprob.decoded_token = new_text

                      # Use the offsets & prev tokens corresponding to
                      # real tokens to ensure detokenization is consistent
                      # with the actual prompt.
                      if token_id == all_token_ids[token_position]:
                          next_iter_prefix_offset = new_prefix_offset
                          next_iter_read_offset = new_read_offset
                          next_iter_tokens = new_tokens

              # Advance to the next token position.
              prefix_offset = next_iter_prefix_offset
              read_offset = next_iter_read_offset
              if prev_tokens is None:
                  # Copy rather than alias: if a later position does not
                  # refresh next_iter_tokens, the extend() below would
                  # otherwise extend the list with itself and corrupt the
                  # token history.
                  prev_tokens = next_iter_tokens.copy()
              else:
                  prev_tokens.extend(next_iter_tokens)

      def decode_sequence_inplace(self, seq: Sequence,
                                  prms: SamplingParams) -> int:
          """Decodes the new token for a sequence. In-place operation.

          Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
          and ``seq.output_text``, and fills ``decoded_token`` on the entries
          of ``seq.output_logprobs[-1]``.

          Args:
              seq: The sequence to decode. Its last token id is treated as
                  the token generated this iteration.
              prms: The sampling parameters used to generate the sequence.

          Returns:
              The number of characters added to the output text.
          """
          all_input_ids = seq.get_token_ids()
          token_id_generated_this_iteration = all_input_ids[-1]
          tokenizer = self.get_tokenizer_for_seq(seq)

          # Convert prompt token IDs to tokens if necessary.
          # Do it here so that we don't have to repeat this
          # computation for each logprob.
          if seq.tokens is None:
              (seq.tokens, seq.prefix_offset,
               seq.read_offset) = convert_prompt_ids_to_tokens(
                   tokenizer=tokenizer,
                   prompt_ids=all_input_ids[:-1],
                   skip_special_tokens=prms.skip_special_tokens,
               )

          (new_tokens, new_decoded_token_text, prefix_offset,
           read_offset) = detokenize_incrementally(
               tokenizer=tokenizer,
               all_input_ids=all_input_ids,
               prev_tokens=seq.tokens,
               prefix_offset=seq.prefix_offset,
               read_offset=seq.read_offset,
               skip_special_tokens=prms.skip_special_tokens,
               spaces_between_special_tokens=prms.spaces_between_special_tokens,
           )

          # Decode logprobs
          logprobs = seq.output_logprobs[-1]
          if logprobs:
              previous_token_ids = all_input_ids[:-1]
              for token_id, sample_logprob in logprobs.items():
                  # If the token was generated this iteration,
                  # use the provided text.
                  if token_id == token_id_generated_this_iteration:
                      sample_logprob.decoded_token = new_decoded_token_text
                      continue

                  if (sample_logprob.decoded_token is None
                          and token_id != INVALID_TOKEN_ID):
                      # Decode the alternative as if it replaced the generated
                      # token, against the same (not yet extended) state.
                      all_input_ids_with_logprob = (previous_token_ids +
                                                    [token_id])
                      (_, new_text, _, _) = detokenize_incrementally(
                          tokenizer=tokenizer,
                          all_input_ids=all_input_ids_with_logprob,
                          prev_tokens=seq.tokens,
                          prefix_offset=seq.prefix_offset,
                          read_offset=seq.read_offset,
                          skip_special_tokens=prms.skip_special_tokens,
                          spaces_between_special_tokens=prms.
                          spaces_between_special_tokens,
                      )
                      sample_logprob.decoded_token = new_text

          seq.tokens.extend(new_tokens)
          seq.prefix_offset = prefix_offset
          seq.read_offset = read_offset
          seq.output_text += new_decoded_token_text

          return len(new_decoded_token_text)
  143. def _replace_none_with_empty(tokens: List[Optional[str]]):
  144. for i, token in enumerate(tokens):
  145. if token is None:
  146. tokens[i] = ""
  147. def _convert_tokens_to_string_with_added_encoders(
  148. tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
  149. output_tokens: List[str],
  150. skip_special_tokens: bool,
  151. spaces_between_special_tokens: bool,
  152. ) -> str:
  153. # Adapted from
  154. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
  155. # NOTE(woosuk): The following code is slow because it runs a for loop over
  156. # the output_tokens. In Python, running a for loop over a list can be slow
  157. # even when the loop body is very simple.
  158. sub_texts: List[str] = []
  159. current_sub_text: List[str] = []
  160. all_special_tokens = set(tokenizer.all_special_tokens)
  161. for token in output_tokens:
  162. if skip_special_tokens and token in all_special_tokens:
  163. continue
  164. if token in tokenizer.get_added_vocab():
  165. if current_sub_text:
  166. sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
  167. sub_texts.append(sub_text)
  168. current_sub_text = []
  169. sub_texts.append(token)
  170. else:
  171. current_sub_text.append(token)
  172. if current_sub_text:
  173. sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
  174. sub_texts.append(sub_text)
  175. if spaces_between_special_tokens:
  176. return " ".join(sub_texts)
  177. else:
  178. return "".join(sub_texts)
  # Number of already-decoded tokens kept before the read offset as "prefix"
  # context when incremental detokenization starts. 5 is an arbitrary value
  # that should work for all tokenizers (bigger = more conservative).
  INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: int = 5
  def convert_prompt_ids_to_tokens(
      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
      prompt_ids: List[int],
      skip_special_tokens: bool = False,
  ) -> Tuple[List[str], int, int]:
      """Convert the tail of a prompt to tokens and seed detokenization offsets.

      Only the last ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` ids are
      converted, which is all incremental detokenization needs. The two extra
      ids leave some slack in case special tokens are present.

      Args:
          tokenizer: Tokenizer used for ``convert_ids_to_tokens``.
          prompt_ids: The prompt token ids.
          skip_special_tokens: Forwarded to ``convert_ids_to_tokens``.

      Returns:
          ``(tokens, prefix_offset, read_offset)``. ``read_offset`` equals
          ``len(tokens)``. ``prefix_offset`` is
          ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET`` tokens earlier,
          clamped at 0. ``None`` tokens are replaced by ``""``.
      """
      tail_ids = prompt_ids[-(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2):]
      tokens = tokenizer.convert_ids_to_tokens(
          tail_ids, skip_special_tokens=skip_special_tokens)

      read_offset = len(tokens)
      if read_offset > INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET:
          prefix_offset = read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
      else:
          prefix_offset = 0

      # Guards against out-of-vocab prompt token ids, which map to None.
      _replace_none_with_empty(tokens)  # type: ignore[arg-type]
      return tokens, prefix_offset, read_offset
  # Based on
  # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
  # under Apache 2.0 license
  def detokenize_incrementally(
      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
      all_input_ids: List[int],
      prev_tokens: Optional[List[str]],
      prefix_offset: int,
      read_offset: int,
      skip_special_tokens: bool = False,
      spaces_between_special_tokens: bool = True,
  ) -> Tuple[List[str], str, int, int]:
      """Detokenizes the input ids incrementally and returns the new tokens
      and the new text.

      If `prev_tokens` is None, this function will convert the input ids to
      tokens and return the tokens and the new text. Otherwise, it will return
      the new tokens and the new text.

      This function will also return the new prefix offset and the new read
      offset to be used in the next iteration.

      The offsets are necessary to defeat cleanup algorithms in the decode
      which decide to add a space or not depending on the surrounding ids.

      Args:
          tokenizer: The tokenizer to use.
          all_input_ids: The input ids. The last id is the new token id.
          prev_tokens: The previous tokens. If None, this function will convert
              the input ids to tokens and return the tokens and the new text.
          prefix_offset: The prefix offset.
          read_offset: The read offset.
          skip_special_tokens: Whether to skip special tokens.
          spaces_between_special_tokens: Whether to add spaces between special
              tokens.

      Returns:
          ``(new_tokens, new_text, prefix_offset, read_offset)``.
          ``new_tokens`` is every token on the first iteration (``prev_tokens``
          was None), otherwise only the token(s) for the new id. Ids outside
          ``[0, len(tokenizer))`` decode to ``""``. If the decoded text does
          not grow past the prefix, or ends in U+FFFD (a possibly incomplete
          byte sequence), ``new_text`` is ``""`` and the offsets are not
          advanced.
      """
      new_token_id = all_input_ids[-1]
      # This is the first iteration for this sequence
      is_first_iter = prev_tokens is None
      if is_first_iter:
          (prev_tokens, prefix_offset,
           read_offset) = convert_prompt_ids_to_tokens(
               tokenizer,
               all_input_ids[:-1],
               skip_special_tokens=skip_special_tokens)
      assert prev_tokens is not None

      # Only ids inside [0, len(tokenizer)) are converted. Negative ids (such
      # as the INVALID_TOKEN_ID placeholder) would otherwise reach
      # convert_ids_to_tokens; any out-of-range id decodes to "".
      if 0 <= new_token_id < len(tokenizer):
          # Put new_token_id in a list so skip_special_tokens is respected
          new_tokens = tokenizer.convert_ids_to_tokens(
              [new_token_id], skip_special_tokens=skip_special_tokens)
          if isinstance(new_tokens, str):
              new_tokens = [new_tokens]
      else:
          new_tokens = [""]
      output_tokens = prev_tokens + new_tokens

      # If this is the first iteration, return all tokens.
      if is_first_iter:
          new_tokens = output_tokens

      # The prefix text is necessary only to defeat cleanup algorithms in
      # the decode which decide to add a space or not depending on the
      # surrounding ids.
      if tokenizer.is_fast or not tokenizer.get_added_vocab():
          prefix_text = tokenizer.convert_tokens_to_string(
              output_tokens[prefix_offset:read_offset])
          new_text = tokenizer.convert_tokens_to_string(
              output_tokens[prefix_offset:])
      else:
          prefix_text = _convert_tokens_to_string_with_added_encoders(
              tokenizer,
              output_tokens[prefix_offset:read_offset],
              skip_special_tokens=skip_special_tokens,
              spaces_between_special_tokens=spaces_between_special_tokens,
          )
          new_text = _convert_tokens_to_string_with_added_encoders(
              tokenizer,
              output_tokens[prefix_offset:],
              skip_special_tokens=skip_special_tokens,
              spaces_between_special_tokens=spaces_between_special_tokens,
          )

      if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
          # utf-8 char at the end means it's a potential unfinished byte
          # sequence from byte fallback tokenization.
          # If it's in the middle, it's probably a real invalid id generated
          # by the model
          return new_tokens, "", prefix_offset, read_offset

      new_text = new_text[len(prefix_text):]
      return new_tokens, new_text, read_offset, len(output_tokens)