test_multi_step.py

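"""Tests for MultiStepOutputProcessor.

These tests verify that multi-step decoding appends sampled token ids to a
sequence correctly, respects max_tokens, and handles the eos token id both
with and without ignore_eos.
"""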
import random
from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SequenceOutput, SequenceStatus)
from aphrodite.common.utils import Counter
from aphrodite.engine.output_processor.multi_step import (
    MultiStepOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker
from aphrodite.processing.scheduler import Scheduler
from aphrodite.transformers_utils.detokenizer import Detokenizer

from ...core.utils import create_seq_group


@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
    """Verify multi-step decoding appends token ids correctly.

    We append token ids and verify all the token ids were appended correctly.
    Note that ignore_eos=True.
    """
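    # All engine collaborators are mocked out; only the output processor's
    # token-appending behavior is exercised here.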
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=1024,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=seq_output_len +
                                       num_new_tokens,
                                       ignore_eos=True),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))

    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
    output_processor.process_outputs(seq_group, outputs)
    assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
                             seq_output_len: int, max_tokens: int):
    """Verify tokens after max_tokens are dropped and not appended to the
    sequence.
    """
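    # With the parametrized values, max_tokens = seq_output_len + 3, so only
    # the first 3 of the 5 to 8 newly sampled tokens may be appended; the rest
    # are expected to be dropped.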
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=max_tokens),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))

    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to not go over max tokens in len.
    assert seq.get_len() == seq_prompt_len + max_tokens

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                               seq_output_len: int, seed: int):
    """Verify the eos token id is included in the sequence, but subsequent
    tokens are dropped (not appended to sequence).
    """
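    # The eos token id is planted at a random position among the new tokens
    # (seeded per test case); everything after that position should be dropped.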
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space.
            max_tokens=seq_output_len + num_new_tokens),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id

    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to not go beyond provided eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:eos_index + 1]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                              seq_output_len: int, seed: int):
    """When sampling parameters dictate that we should ignore the eos token
    id, ensure all token ids are appended even if the eos token id is emitted.
    """
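    # Same setup as test_respects_eos_token_id, except ignore_eos=True, so the
    # planted eos token must not terminate the sequence.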
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=[scheduler],
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space.
            max_tokens=seq_output_len + num_new_tokens,
            ignore_eos=True,
        ),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id

    outputs = [
        CompletionSequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to go beyond eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
                                             seq_output_len]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


def mock_tokenizer(eos_token_id=1000):
    tokenizer = MagicMock(spec=PreTrainedTokenizer)
    tokenizer.eos_token_id = eos_token_id
    return tokenizer
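

# The assertions above all reduce to simple truncation arithmetic. The helper
# below is an illustrative sketch only (it is not used by the tests and is not
# part of the aphrodite API): given the tokens sampled in one multi-step pass,
# it computes which of them are expected to survive a max_tokens cap and an
# eos cutoff, mirroring the expected_appended_tokens expressions above.
def _expected_surviving_tokens(new_token_ids, seq_output_len, max_tokens,
                               eos_token_id=None, ignore_eos=False):
    # Cap the new tokens by the budget remaining under max_tokens.
    budget = max_tokens - seq_output_len
    kept = new_token_ids[:budget]
    # Keep the eos token itself but drop everything after it, unless
    # ignore_eos is set.
    if not ignore_eos and eos_token_id is not None and eos_token_id in kept:
        kept = kept[:kept.index(eos_token_id) + 1]
    return kept


# For example, _expected_surviving_tokens(list(range(5)), seq_output_len=128,
# max_tokens=131) == [0, 1, 2], matching test_respects_max_tokens.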