test_detokenize.py
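"""Tests for Aphrodite's incremental detokenization and for decoding sampled
and prompt logprobs through the Detokenizer."""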

from typing import Any, Dict, List, Optional

import pytest
from transformers import AutoTokenizer

from aphrodite.common.sequence import (Logprob, SamplingParams, Sequence,
                                       SequenceGroup)
from aphrodite.transformers_utils.detokenizer import (Detokenizer,
                                                      detokenize_incrementally)
from aphrodite.transformers_utils.tokenizer_group import get_tokenizer_group

TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情"
]
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool, starting_index: int):
    """Decode token ids one step at a time and accumulate the streamed text."""
    decoded_text = ""
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(starting_index, len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
            prev_tokens,
            offset,
            token_offset,
            skip_special_tokens=skip_special_tokens)
        decoded_text += text
        if prev_tokens is None:
            prev_tokens = new_tokens
        else:
            prev_tokens += new_tokens
    return decoded_text


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, with_prompt,
                          skip_special_tokens):
    """Verify incremental (streaming) detokenization matches the reference text."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    if with_prompt:
        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
        prompt_input_ids = truth_tokens[:len(truth) // 2]
        generated_input_ids = truth_tokens[len(truth) // 2:]
        all_input_ids = prompt_input_ids + generated_input_ids
        starting_index = len(prompt_input_ids)
        prompt = tokenizer.decode(prompt_input_ids,
                                  skip_special_tokens=skip_special_tokens)
        generated = truth[len(prompt):]
    else:
        generated = truth
        starting_index = 0
        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
        if skip_special_tokens:
            if tokenizer.bos_token_id is not None:
                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
                starting_index += 1
            all_input_ids = all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer,
        all_input_ids,
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == generated

    # Decoding a lone out-of-vocabulary id (len(tokenizer)) should yield no text.
    decoded_text = _run_incremental_decode(
        tokenizer, [len(tokenizer)],
        skip_special_tokens=skip_special_tokens,
        starting_index=starting_index)

    assert decoded_text == ''


@pytest.fixture
def detokenizer(tokenizer_name: str) -> Detokenizer:
    init_kwargs = dict(
        tokenizer_id=tokenizer_name,
        enable_lora=False,
        max_num_seqs=100,
        max_input_length=None,
        tokenizer_mode="auto",
        trust_remote_code=False,
        revision=None,
    )

    tokenizer_group = get_tokenizer_group(
        None,
        **init_kwargs,
    )

    return Detokenizer(tokenizer_group)


@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
                                       tokenizer_name: str) -> List[int]:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
    return complete_sequence_token_ids


def create_sequence(prompt_token_ids=None):
    prompt_token_ids = prompt_token_ids or [1]
    return Sequence(
        seq_id=0,
        inputs={
            "prompt": "<s>",
            "prompt_token_ids": prompt_token_ids,
        },
        block_size=16,
    )


def create_dummy_logprobs(
        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
    return [{
        token_id: Logprob(logprob=0.0),
        token_id + 1: Logprob(logprob=0.1)
    } for token_id in complete_sequence_token_ids]


def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    # logprob for the first prompt token is None.
    logprobs: List[Optional[Dict[int, Any]]] = [None]
    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
    return logprobs


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
def test_decode_sequence_logprobs(complete_sequence: str,
                                  complete_sequence_token_ids: List[int],
                                  detokenizer: Detokenizer,
                                  skip_special_tokens: bool):
    """Verify Detokenizer decodes logprobs correctly."""
    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                     logprobs=2)

    # Run sequentially.
    seq = create_sequence()
    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    sequential_logprobs_text_chosen_token: List[str] = []
    sequential_logprobs_text_other_token: List[str] = []
    for new_token, logprobs in zip(complete_sequence_token_ids,
                                   dummy_logprobs):
        seq.append_token_id(new_token, logprobs)
        detokenizer.decode_sequence_inplace(seq, sampling_params)
        sequential_logprobs_text_chosen_token.append(
            seq.output_logprobs[-1][new_token].decoded_token)
        sequential_logprobs_text_other_token.append(
            seq.output_logprobs[-1][new_token + 1].decoded_token)

    sequential_result = seq.output_text

    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
    assert sequential_result != "".join(sequential_logprobs_text_other_token)

    if skip_special_tokens:
        # Text for logprobs for the chosen token should be the same as the
        # generated text. Note that this will only be true if we skip
        # special tokens.
        assert sequential_result == complete_sequence


@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly."""
    sampling_params = SamplingParams(skip_special_tokens=True,
                                     prompt_logprobs=1)

    # Run sequentially.
    seq = create_sequence(complete_sequence_token_ids)
    seq_group = SequenceGroup(request_id="1",
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
    detokenizer.decode_prompt_logprobs_inplace(seq_group,
                                               dummy_logprobs,
                                               position_offset=0)
    # The first prompt logprob is None, so decoded_prompt_logprobs does not
    # contain an entry for the first token.
    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
        1:]  # type: ignore
    token_ids = complete_sequence_token_ids
    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
    text = text_full[len(text_first):]

    # Text for logprobs for the chosen token should be the same as the
    # prompt text. Note that the first logprob is None.
    assert text == "".join([
        logprobs[token_id].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
    assert text != "".join([
        logprobs[token_id + 1].decoded_token
        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
    ])
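

# NOTE: `aphrodite_runner` and `example_prompts` below are pytest fixtures that
# are not defined in this file; they are assumed to be provided by the test
# suite's conftest.py (or an equivalent plugin).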
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
        aphrodite_runner,
        model,
        chunked_prefill_token_size: int,
        example_prompts,
):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    with aphrodite_runner(model,
                          dtype="half",
                          max_logprobs=5,
                          gpu_memory_utilization=0.5,
                          enable_chunked_prefill=enable_chunked_prefill,
                          max_num_batched_tokens=max_num_batched_tokens,
                          max_num_seqs=max_num_seqs) as aphrodite_model:

        aphrodite_sampling_params = SamplingParams(max_tokens=10,
                                                   logprobs=5,
                                                   prompt_logprobs=5,
                                                   temperature=0.0)
        aphrodite_results = aphrodite_model.model.generate(
            example_prompts, sampling_params=aphrodite_sampling_params)

        for idx, result in enumerate(aphrodite_results):
            assert result.prompt_logprobs is not None
            assert result.prompt_logprobs[0] is None

            # Compare the detokenized prompt ids to the original prompt.
            generated_string = ""
            for (prompt_token,
                 prompt_logprobs) in zip(result.prompt_token_ids[1:],
                                         result.prompt_logprobs[1:]):
                # prompt_logprobs maps token_id -> Logprob. Select the entry
                # for the actual prompt token; its decoded_token is the piece
                # of the detokenized string at this position.
                generated_string += prompt_logprobs[prompt_token].decoded_token

            assert generated_string == example_prompts[idx], (
                "Detokenized prompt logprobs do not match original prompt")