|
@@ -48,6 +48,11 @@ class StopChecker:
|
|
|
# Check if the sequence has generated the EOS token.
|
|
|
if ((not sampling_params.ignore_eos)
|
|
|
and seq.get_last_token_id() == seq.eos_token_id):
|
|
|
+ # Remove the last EOS token unless explicitly specified
|
|
|
+ # This prevents unintended exposure of the EOS token
|
|
|
+ if new_char_count and (
|
|
|
+ not sampling_params.include_stop_str_in_output):
|
|
|
+ seq.output_text = seq.output_text[:-new_char_count]
|
|
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
|
|
return
|
|
|
|