stop_checker.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from typing import Callable, Optional
  2. from transformers import PreTrainedTokenizer
  3. from aphrodite.common.sampling_params import SamplingParams
  4. from aphrodite.common.sequence import Sequence, SequenceStatus
  5. from aphrodite.lora.request import LoRARequest
  6. class StopChecker:
  7. """AphroditeEngine helper class which separates out the logic involving
  8. stop checking. This checks things such as: whether the eos token was
  9. emitted, whether the max_tokens has been consumed, whether a stop string
  10. has been emitted, or if we have exceeded the max model len.
  11. """
  12. def __init__(self, max_model_len: int,
  13. get_tokenizer_for_seq: Callable[[Sequence],
  14. PreTrainedTokenizer]):
  15. # Do not use it directly, but use `self._get_max_model_len`.
  16. self._max_model_len = max_model_len
  17. self.get_tokenizer_for_seq = get_tokenizer_for_seq
  18. def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
  19. if lora_req and lora_req.long_lora_max_len:
  20. return lora_req.long_lora_max_len
  21. else:
  22. return self._max_model_len
  23. def maybe_stop_sequence(
  24. self,
  25. seq: Sequence,
  26. new_char_count: int,
  27. sampling_params: SamplingParams,
  28. lora_req: Optional[LoRARequest] = None,
  29. ) -> None:
  30. """Stop the finished sequences.
  31. new_char_count is the number of chars added to the
  32. sequence's output text for the newly generated token
  33. """
  34. # Check if the minimum number of tokens has been generated yet;
  35. # skip the stop string/token checks if not
  36. if seq.get_output_len() < sampling_params.min_tokens:
  37. return
  38. # Check if the sequence has generated the EOS token.
  39. if ((not sampling_params.ignore_eos)
  40. and seq.get_last_token_id() == seq.eos_token_id):
  41. # Remove the last EOS token unless explicitly specified
  42. # This prevents unintended exposure of the EOS token
  43. if new_char_count and (
  44. not sampling_params.include_stop_str_in_output):
  45. seq.output_text = seq.output_text[:-new_char_count]
  46. seq.status = SequenceStatus.FINISHED_STOPPED
  47. return
  48. # Check if a stop token was encountered.
  49. # This assumes a single token produced per step.
  50. last_token_id = seq.get_last_token_id()
  51. if last_token_id in sampling_params.stop_token_ids:
  52. if new_char_count and (
  53. not sampling_params.include_stop_str_in_output):
  54. # Remove last token
  55. seq.output_text = seq.output_text[:-new_char_count]
  56. seq.status = SequenceStatus.FINISHED_STOPPED
  57. seq.stop_reason = last_token_id
  58. return
  59. # Check if any stop strings are matched.
  60. stop_str = self._check_stop_strings(seq, new_char_count,
  61. sampling_params)
  62. if stop_str is not None:
  63. seq.status = SequenceStatus.FINISHED_STOPPED
  64. seq.stop_reason = stop_str
  65. return
  66. # Check if the sequence has reached max_model_len.
  67. if seq.get_len() > self._get_max_model_len(lora_req):
  68. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  69. return
  70. # Check if the sequence has reached max_tokens.
  71. if seq.get_output_len() == sampling_params.max_tokens:
  72. seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
  73. return
  74. @staticmethod
  75. def _check_stop_strings(seq: Sequence, new_char_count: int,
  76. sampling_params: SamplingParams) -> Optional[str]:
  77. """Check if any stop strings are matched and truncate sequence
  78. output text accordingly.
  79. Returns the stop string if matched or else None.
  80. """
  81. if not new_char_count:
  82. return None
  83. for stop_str in sampling_params.stop:
  84. stop_string_len = len(stop_str)
  85. # Avoid searching already-searched text.
  86. stop_index = seq.output_text.find(
  87. stop_str, -new_char_count - stop_string_len)
  88. if stop_index == -1:
  89. continue
  90. if sampling_params.include_stop_str_in_output:
  91. # Truncate to end of stop string.
  92. stop_index += stop_string_len
  93. if stop_index >= len(seq.output_text):
  94. # No truncation required.
  95. return stop_str
  96. # Truncate the output text to either the beginning
  97. # or end of the stop string.
  98. seq.output_text = seq.output_text[:stop_index]
  99. return stop_str
  100. return None