
feat: add stream_options for chat completions

AlpinDale 7 months ago
commit 1d7f5c45b0
2 changed files with 48 additions and 11 deletions
  1. aphrodite/endpoints/openai/protocol.py (+15 -1)
  2. aphrodite/endpoints/openai/serving_chat.py (+33 -10)
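
For context, `stream_options` mirrors the OpenAI streaming API: when `include_usage` is set on a streaming request, the server emits a trailing chunk that reports token usage. A minimal client-side sketch, assuming a locally running Aphrodite deployment and an `openai` Python client new enough to accept `stream_options`; the base URL, API key, and model name are placeholders:

```python
from openai import OpenAI

# Placeholders: point the client at a local Aphrodite deployment.
client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model name
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    # New in this commit: ask for a trailing chunk that carries token usage.
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.usage is not None:
        print(chunk.usage)  # final chunk: empty choices, populated usage
    elif chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
```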

aphrodite/endpoints/openai/protocol.py (+15 -1)

@@ -60,7 +60,11 @@ class UsageInfo(BaseModel):
 
 class ResponseFormat(BaseModel):
     # type must be "json_object" or "text"
-    type: str = Literal["text", "json_object"]
+    type: Literal["text", "json_object"]
+
+
+class StreamOptions(BaseModel):
+    include_usage: Optional[bool]
 
 
 class FunctionDefinition(BaseModel):
@@ -102,6 +106,7 @@ class ChatCompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     include_stop_str_in_output: Optional[bool] = False
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
@@ -205,6 +210,15 @@ class ChatCompletionRequest(BaseModel):
             logits_processors=logits_processors,
         )
 
+    @model_validator(mode='before')
+    @classmethod
+    def validate_stream_options(cls, values):
+        if (values.get('stream_options') is not None
+                and not values.get('stream')):
+            raise ValueError(
+                "stream_options can only be set if stream is true")
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
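
Because the new validator runs with `mode='before'`, it sees the raw request dict before field parsing, hence `values.get(...)` rather than attribute access. A self-contained sketch of the same pattern under Pydantic v2; `ChatRequestSketch` is a local stand-in for the real `ChatCompletionRequest`:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class StreamOptions(BaseModel):
    # The diff declares include_usage without a default, making it required;
    # a default is added here only to keep the sketch forgiving.
    include_usage: Optional[bool] = None


class ChatRequestSketch(BaseModel):
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, values):
        # mode="before" sees the raw input dict, so use .get(), not attributes.
        if (values.get("stream_options") is not None
                and not values.get("stream")):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values


ChatRequestSketch(stream=True, stream_options={"include_usage": True})  # accepted

try:
    ChatRequestSketch(stream_options={"include_usage": True})
except ValidationError as err:
    print(err)  # "stream_options can only be set if stream is true"
```

The `ValueError` raised inside the validator surfaces to callers wrapped in a Pydantic `ValidationError`, which the serving layer turns into an HTTP error response.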

aphrodite/endpoints/openai/serving_chat.py (+33 -10)

@@ -150,6 +150,9 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
 
@@ -179,6 +182,9 @@ class OpenAIServingChat(OpenAIServing):
                                     choices=[choice_data],
                                     logprobs=None,
                                     model=model_name)
+                                if (request.stream_options and
+                                        request.stream_options.include_usage):
+                                    chunk.usage = None
                                 data = chunk.model_dump_json(
                                     exclude_unset=True)
                                 yield f"data: {data}\n\n"
@@ -230,17 +236,14 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
                     else:
                         # Send the finish response for each request.n only once
                         prompt_tokens = len(res.prompt_token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=previous_num_tokens[i],
-                            total_tokens=prompt_tokens +
-                            previous_num_tokens[i],
-                        )
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=delta_message,
@@ -253,12 +256,32 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
-                        if final_usage is not None:
-                            chunk.usage = final_usage
-                        data = chunk.model_dump_json(exclude_unset=True,
-                                                     exclude_none=True)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            chunk.usage = None
+                        data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
                         finish_reason_sent[i] = True
+
+                    if (request.stream_options
+                            and request.stream_options.include_usage):
+                        final_usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=previous_num_tokens[i],
+                            total_tokens=prompt_tokens +
+                            previous_num_tokens[i],
+                        )
+
+                        final_usage_chunk = ChatCompletionStreamResponse(
+                            id=request_id,
+                            object=chunk_object_type,
+                            created=created_time,
+                            choices=[],
+                            model=model_name,
+                            usage=final_usage)
+                        final_usage_data = (final_usage_chunk.model_dump_json(
+                            exclude_unset=True, exclude_none=True))
+                        yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             data = self.create_streaming_error_response(str(e))
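
The repeated `chunk.usage = None` assignments are not redundant: with `exclude_unset=True`, a Pydantic field serializes only once it has been explicitly set, so assigning `None` is what makes every content chunk carry `"usage": null` while `include_usage` is on. A minimal demonstration of that behavior under Pydantic v2, with `ChunkSketch` standing in for `ChatCompletionStreamResponse`:

```python
from typing import List, Optional

from pydantic import BaseModel


class ChunkSketch(BaseModel):
    # Stand-in for ChatCompletionStreamResponse; fields trimmed for brevity.
    id: str
    choices: List[dict] = []
    usage: Optional[dict] = None


chunk = ChunkSketch(id="chatcmpl-123", choices=[{"delta": {"content": "hi"}}])
print(chunk.model_dump_json(exclude_unset=True))
# -> {"id":"chatcmpl-123","choices":[{"delta":{"content":"hi"}}]}

chunk.usage = None  # explicit assignment marks the field as "set"
print(chunk.model_dump_json(exclude_unset=True))
# -> {"id":"chatcmpl-123","choices":[{"delta":{"content":"hi"}}],"usage":null}
```

The separate `final_usage_chunk` is serialized with `exclude_none=True` instead, so it arrives as a trailing chunk with an empty `choices` list and the populated `UsageInfo`.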