stop_checker.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from typing import Callable, Optional
  2. from transformers import PreTrainedTokenizer
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import Sequence, SequenceStatus
  5. class StopChecker:
  6. """AphroditeEngine helper class which separates out the logic involving
  7. stop checking. This checks things such as: whether the eos token was
  8. emitted, whether the max_tokens has been consumed, whether a stop string
  9. has been emitted, or if we have exceeded the max model len.
  10. """
  11. def __init__(self, max_model_len: int,
  12. get_tokenizer_for_seq: Callable[[Sequence],
  13. PreTrainedTokenizer]):
  14. self.max_model_len = max_model_len
  15. self.get_tokenizer_for_seq = get_tokenizer_for_seq
  16. def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
  17. sampling_params: SamplingParams) -> None:
  18. """Stop the finished sequences.
  19. new_char_count is the number of chars added to the
  20. sequence's output text for the newly generated token
  21. """
  22. # Check if the minimum number of tokens has been generated yet;
  23. # skip the stop string/token checks if not
  24. if seq.get_output_len() < sampling_params.min_tokens:
  25. return
  26. # Check if the sequence has generated the EOS token.
  27. if ((not sampling_params.ignore_eos)
  28. and seq.get_last_token_id() == seq.eos_token_id):
  29. seq.status = SequenceStatus.FINISHED_STOPPED
  30. return
  31. # Check if a stop token was encountered.
  32. # This assumes a single token produced per step.
  33. last_token_id = seq.get_last_token_id()
  34. if last_token_id in sampling_params.stop_token_ids:
  35. if new_char_count and (
  36. not sampling_params.include_stop_str_in_output):
  37. # Remove last token
  38. seq.output_text = seq.output_text[:-new_char_count]
  39. seq.status = SequenceStatus.FINISHED_STOPPED
  40. seq.stop_reason = last_token_id
  41. return
  42. # Check if any stop strings are matched.
  43. stop_str = self._check_stop_strings(seq, new_char_count,
  44. sampling_params)
  45. if stop_str is not None:
  46. seq.status = SequenceStatus.FINISHED_STOPPED
  47. seq.stop_reason = stop_str
  48. return
  49. # Check if the sequence has reached max_model_len.
  50. if seq.get_len() > self.max_model_len:
  51. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  52. return
  53. # Check if the sequence has reached max_tokens.
  54. if seq.get_output_len() == sampling_params.max_tokens:
  55. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  56. return
  57. @staticmethod
  58. def _check_stop_strings(seq: Sequence, new_char_count: int,
  59. sampling_params: SamplingParams) -> Optional[str]:
  60. """Check if any stop strings are matched and truncate sequence
  61. output text accordingly.
  62. Returns the stop string if matched or else None.
  63. """
  64. if not new_char_count:
  65. return None
  66. for stop_str in sampling_params.stop:
  67. stop_string_len = len(stop_str)
  68. # Avoid searching already-searched text.
  69. stop_index = seq.output_text.find(
  70. stop_str, -new_char_count - stop_string_len)
  71. if stop_index == -1:
  72. continue
  73. if sampling_params.include_stop_str_in_output:
  74. # Truncate to end of stop string.
  75. stop_index += stop_string_len
  76. if stop_index >= len(seq.output_text):
  77. # No truncation required.
  78. return stop_str
  79. # Truncate the output text to either the beginning
  80. # or end of the stop string.
  81. seq.output_text = seq.output_text[:stop_index]
  82. return stop_str
  83. return None