detokenizer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. from typing import List, Dict, Optional, Tuple, Union
  2. from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
  3. from aphrodite.common.sequence import (
  4. Sequence,
  5. Logprob,
  6. SequenceGroup,
  7. SamplingParams,
  8. )
  9. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
  10. BaseTokenizerGroup, )
  # Sentinel token id that is not a real vocabulary entry.
  # Used e.g. for marking rejected tokens in spec decoding. Logprob entries
  # keyed by this id are never detokenized by the methods below.
  INVALID_TOKEN_ID = -1


  class Detokenizer:
      """Provides methods to decode the output of a model into text."""

      def __init__(self, tokenizer_group: BaseTokenizerGroup):
          # Tokenizers are looked up per sequence through this group.
          self.tokenizer_group = tokenizer_group

      def get_tokenizer_for_seq(self,
                                sequence: Sequence) -> "PreTrainedTokenizer":
          """Returns the HF tokenizer to use for a given sequence."""
          return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)

      def decode_prompt_logprobs_inplace(
          self,
          seq_group: SequenceGroup,
          prompt_logprobs: List[Optional[Dict[int, Logprob]]],
      ) -> None:
          """Decodes the logprobs for the prompt of a sequence group.

          In-place operation: ``decoded_token`` is filled in on every
          ``Logprob`` entry that has not been decoded yet and whose key is
          not ``INVALID_TOKEN_ID``. Nothing is returned.

          Args:
              seq_group: The sequence group to decode.
              prompt_logprobs: The logprobs to decode, one (possibly empty)
                  mapping of token id to ``Logprob`` per prompt position.
          """
          prms = seq_group.sampling_params
          # We can pick any sequence for the prompt.
          seq = next(iter(seq_group.seqs_dict.values()))
          # Only prompt, without the generated token.
          all_token_ids = seq.get_token_ids()
          prompt_token_ids = all_token_ids[:-1]
          tokenizer = self.get_tokenizer_for_seq(seq)

          # Incremental detokenization state. It only ever advances along
          # the *real* prompt tokens, never along alternative candidates.
          prev_tokens: Optional[List[str]] = None
          prefix_offset = 0
          read_offset = 0

          for token_position, prompt_logprobs_for_token in enumerate(
                  prompt_logprobs):
              if not prompt_logprobs_for_token:
                  continue

              # Result of decoding the real prompt token at this position,
              # as (new_tokens, new_prefix_offset, new_read_offset).
              real_token_state = None
              for token_id, sample_logprob in (
                      prompt_logprobs_for_token.items()):
                  if (sample_logprob.decoded_token is not None
                          or token_id == INVALID_TOKEN_ID):
                      continue
                  prompt_token_ids_with_token = (
                      prompt_token_ids[:token_position] + [token_id])
                  (
                      new_tokens,
                      new_text,
                      new_prefix_offset,
                      new_read_offset,
                  ) = detokenize_incrementally(
                      tokenizer=tokenizer,
                      all_input_ids=prompt_token_ids_with_token,
                      prev_tokens=prev_tokens,
                      prefix_offset=prefix_offset,
                      read_offset=read_offset,
                      skip_special_tokens=prms.skip_special_tokens,
                      spaces_between_special_tokens=prms.
                      spaces_between_special_tokens,
                  )
                  sample_logprob.decoded_token = new_text
                  # Use the offsets & prev tokens corresponding to the
                  # real token to ensure detokenization stays consistent
                  # with the actual prompt.
                  if token_id == all_token_ids[token_position]:
                      real_token_state = (new_tokens, new_prefix_offset,
                                          new_read_offset)

              if real_token_state is None:
                  # The real token was not decoded at this position (its
                  # entry was already decoded, or it is absent). Re-using
                  # the previous position's tokens/offsets would append
                  # stale tokens, so drop the incremental state; the next
                  # decoded position re-seeds itself from the prompt ids.
                  prev_tokens = None
                  prefix_offset = 0
                  read_offset = 0
                  continue

              # Advance to the next token position.
              new_tokens, prefix_offset, read_offset = real_token_state
              if prev_tokens is None:
                  # Copy so later in-place extends never alias another list.
                  prev_tokens = list(new_tokens)
              else:
                  prev_tokens.extend(new_tokens)

      def decode_sequence_inplace(self, seq: Sequence,
                                  prms: SamplingParams) -> int:
          """Decodes the new token for a sequence. In-place operation.

          Updates ``seq.tokens``, ``seq.prefix_offset``, ``seq.read_offset``
          and ``seq.output_text``, and fills in ``decoded_token`` on the
          entries of ``seq.output_logprobs[-1]``.

          Args:
              seq: The sequence to decode.
              prms: The sampling parameters used to generate the sequence.

          Returns:
              The number of characters added to the output text.
          """
          all_input_ids = seq.get_token_ids()
          token_id_generated_this_iteration = all_input_ids[-1]
          tokenizer = self.get_tokenizer_for_seq(seq)

          # Convert prompt token IDs to tokens if necessary.
          # Do it here so that we don't have to repeat this
          # computation for each logprob.
          if seq.tokens is None:
              (
                  seq.tokens,
                  seq.prefix_offset,
                  seq.read_offset,
              ) = convert_prompt_ids_to_tokens(
                  tokenizer=tokenizer,
                  prompt_ids=all_input_ids[:-1],
                  skip_special_tokens=prms.skip_special_tokens,
              )

          (
              new_tokens,
              new_decoded_token_text,
              prefix_offset,
              read_offset,
          ) = detokenize_incrementally(
              tokenizer=tokenizer,
              all_input_ids=all_input_ids,
              prev_tokens=seq.tokens,
              prefix_offset=seq.prefix_offset,
              read_offset=seq.read_offset,
              skip_special_tokens=prms.skip_special_tokens,
              spaces_between_special_tokens=prms.spaces_between_special_tokens,
          )

          # Decode logprobs
          logprobs = seq.output_logprobs[-1]
          if logprobs:
              previous_tokens = all_input_ids[:-1]
              for token_id, sample_logprob in logprobs.items():
                  # If the token was generated this iteration,
                  # use the provided text.
                  if token_id == token_id_generated_this_iteration:
                      sample_logprob.decoded_token = new_decoded_token_text
                      continue

                  if (sample_logprob.decoded_token is None
                          and token_id != INVALID_TOKEN_ID):
                      # Decode the alternative as if it had been appended
                      # instead of the sampled token. The sequence state is
                      # only read here, not advanced.
                      all_input_ids_with_logprob = previous_tokens + [token_id]
                      (_, new_text, _, _) = detokenize_incrementally(
                          tokenizer=tokenizer,
                          all_input_ids=all_input_ids_with_logprob,
                          prev_tokens=seq.tokens,
                          prefix_offset=seq.prefix_offset,
                          read_offset=seq.read_offset,
                          skip_special_tokens=prms.skip_special_tokens,
                          spaces_between_special_tokens=prms.
                          spaces_between_special_tokens,
                      )
                      sample_logprob.decoded_token = new_text

          # Commit the state for the sampled token only after all
          # alternatives were decoded against the pre-update state.
          seq.tokens.extend(new_tokens)
          seq.prefix_offset = prefix_offset
          seq.read_offset = read_offset
          seq.output_text += new_decoded_token_text

          return len(new_decoded_token_text)
  def _convert_tokens_to_string_with_added_encoders(
      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
      output_tokens: List[str],
      skip_special_tokens: bool,
      spaces_between_special_tokens: bool,
  ) -> str:
      """Joins tokens into text, keeping added-vocabulary tokens verbatim.

      Runs of ordinary tokens are converted with
      ``tokenizer.convert_tokens_to_string``; tokens found in
      ``tokenizer.get_added_vocab()`` are emitted unchanged as their own
      piece.

      Args:
          tokenizer: The tokenizer that produced ``output_tokens``.
          output_tokens: The tokens to convert.
          skip_special_tokens: Whether to drop tokens listed in
              ``tokenizer.all_special_tokens``.
          spaces_between_special_tokens: Whether to join the resulting
              pieces with a single space instead of nothing.

      Returns:
          The concatenated text.
      """
      # Adapted from
      # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
      # NOTE: The following code is slow because it runs a for loop over
      # the output_tokens. In Python, running a for loop over a list can be slow
      # even when the loop body is very simple.
      sub_texts: List[str] = []
      current_sub_text: List[str] = []
      all_special_tokens = set(tokenizer.all_special_tokens)
      # Loop-invariant: fetch the added vocabulary once, not once per token.
      added_vocab = tokenizer.get_added_vocab()
      for token in output_tokens:
          if skip_special_tokens and token in all_special_tokens:
              continue
          if token in added_vocab:
              # Flush the pending run of ordinary tokens before emitting
              # the added token as its own piece.
              if current_sub_text:
                  sub_text = tokenizer.convert_tokens_to_string(
                      current_sub_text)
                  sub_texts.append(sub_text)
                  current_sub_text = []
              sub_texts.append(token)
          else:
              current_sub_text.append(token)
      if current_sub_text:
          sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
          sub_texts.append(sub_text)
      separator = " " if spaces_between_special_tokens else ""
      return separator.join(sub_texts)
  # Number of trailing tokens kept as context between the prefix offset and
  # the read offset when incremental detokenization is seeded.
  # 5 is an arbitrary value that should work for all
  # tokenizers (bigger = more conservative).
  INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


  def convert_prompt_ids_to_tokens(
      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
      prompt_ids: List[int],
      skip_special_tokens: bool = False,
  ) -> Tuple[List[str], int, int]:
      """Seeds incremental detokenization from the tail of a prompt.

      Only the last ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2`` prompt
      ids are converted to token strings; the rest of the prompt is not
      needed for incremental detokenization.

      Returns:
          A tuple ``(tokens, prefix_offset, read_offset)``. ``tokens`` holds
          the converted tail, ``read_offset`` is its length, and
          ``prefix_offset`` lies ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET``
          tokens before it, clamped at 0.
      """
      # Take a couple of ids beyond the offset in case some of them are
      # special tokens that get dropped during conversion.
      tail_length = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
      tail_ids = prompt_ids[-tail_length:]
      tail_tokens = tokenizer.convert_ids_to_tokens(
          tail_ids, skip_special_tokens=skip_special_tokens)

      read_offset = len(tail_tokens)
      prefix_offset = read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
      if prefix_offset < 0:
          prefix_offset = 0
      return tail_tokens, prefix_offset, read_offset
  # Based on
  # https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
  # under Apache 2.0 license
  def detokenize_incrementally(
      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
      all_input_ids: List[int],
      prev_tokens: Optional[List[str]],
      prefix_offset: int,
      read_offset: int,
      skip_special_tokens: bool = False,
      spaces_between_special_tokens: bool = True,
  ) -> Tuple[List[str], str, int, int]:
      """Detokenizes the input ids incrementally and returns the new tokens
      and the new text.

      If `prev_tokens` is None, this function will convert the input ids to
      tokens and return the tokens and the new text. Otherwise, it will
      return the new tokens and the new text.

      This function will also return the new prefix offset and the new read
      offset to be used in the next iteration.

      The offsets are necessary to defeat cleanup algorithms in the decode
      which decide to add a space or not depending on the surrounding ids.

      Args:
          tokenizer: The tokenizer to use.
          all_input_ids: The input ids. The last id is the new token id.
          prev_tokens: The previous tokens. If None, this function will
              convert the input ids to tokens and return the tokens and the
              new text.
          prefix_offset: The prefix offset.
          read_offset: The read offset.
          skip_special_tokens: Whether to skip special tokens.
          spaces_between_special_tokens: Whether to add spaces between
              special tokens.

      Returns:
          A tuple ``(new_tokens, new_text, prefix_offset, read_offset)``.
          When no text can be emitted yet, ``new_text`` is ``""`` and the
          offsets are returned unchanged; otherwise the new prefix offset is
          the old read offset and the new read offset is the total number
          of tokens.
      """
      new_token_id = all_input_ids[-1]
      # This is the first iteration for this sequence
      is_first_iter = prev_tokens is None
      if is_first_iter:
          (prev_tokens, prefix_offset,
           read_offset) = convert_prompt_ids_to_tokens(
               tokenizer,
               all_input_ids[:-1],
               skip_special_tokens=skip_special_tokens)

      # If the new token id is out of bounds, use an empty token. This
      # covers ids past the vocabulary as well as negative sentinel ids,
      # which must never reach the tokenizer.
      if 0 <= new_token_id < len(tokenizer):
          # Put new_token_id in a list so skip_special_tokens is respected
          new_tokens = tokenizer.convert_ids_to_tokens(
              [new_token_id], skip_special_tokens=skip_special_tokens)
      else:
          new_tokens = [""]
      output_tokens = prev_tokens + new_tokens

      # If this is the first iteration, return all tokens.
      if is_first_iter:
          new_tokens = output_tokens

      # The prefix text is necessary only to defeat cleanup algorithms in
      # the decode which decide to add a space or not depending on the
      # surrounding ids.
      if tokenizer.is_fast or not tokenizer.get_added_vocab():
          prefix_text = tokenizer.convert_tokens_to_string(
              output_tokens[prefix_offset:read_offset])
          new_text = tokenizer.convert_tokens_to_string(
              output_tokens[prefix_offset:])
      else:
          prefix_text = _convert_tokens_to_string_with_added_encoders(
              tokenizer,
              output_tokens[prefix_offset:read_offset],
              skip_special_tokens=skip_special_tokens,
              spaces_between_special_tokens=spaces_between_special_tokens,
          )
          new_text = _convert_tokens_to_string_with_added_encoders(
              tokenizer,
              output_tokens[prefix_offset:],
              skip_special_tokens=skip_special_tokens,
              spaces_between_special_tokens=spaces_between_special_tokens,
          )

      if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
          # A replacement char at the end means it's a potential unfinished
          # byte sequence from byte fallback tokenization.
          # If it's in the middle, it's probably a real invalid id generated
          # by the model
          return new_tokens, "", prefix_offset, read_offset

      # Only the text beyond the already-emitted prefix is new.
      new_text = new_text[len(prefix_text):]
      return new_tokens, new_text, read_offset, len(output_tokens)
  287. return new_tokens, "", prefix_offset, read_offset
  288. new_text = new_text[len(prefix_text):]
  289. return new_tokens, new_text, read_offset, len(output_tokens)