feat: enable prompt logprobs in OpenAI API (#720)

AlpinDale · 6 months ago
parent commit 61c7182491
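
This commit adds a `prompt_logprobs` request field to both the Completions and Chat Completions endpoints, so clients can request log-probabilities for the prompt tokens in addition to the generated ones. A minimal client sketch (not part of the commit; the base URL, port, and model name are assumptions to adjust for your deployment):

```python
# Hypothetical client call exercising the new `prompt_logprobs` field.
# Assumes an Aphrodite server on localhost:2242 serving "my-model".
import requests

resp = requests.post(
    "http://localhost:2242/v1/completions",
    json={
        "model": "my-model",        # hypothetical model name
        "prompt": "Hello, world",
        "max_tokens": 8,
        "prompt_logprobs": 1,       # new field introduced by this commit
    },
)
data = resp.json()
# Per the protocol change below, each completion choice now carries a
# `prompt_logprobs` list: one entry per prompt token (None for the first
# token), mapping token IDs to Logprob objects.
print(data["choices"][0]["prompt_logprobs"])
```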

aphrodite/endpoints/openai/protocol.py (+9 -2)

@@ -12,6 +12,7 @@ from typing_extensions import Annotated
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                               SamplingParams)
+from aphrodite.common.sequence import Logprob
 from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
 from aphrodite.endpoints.openai.logits_processors import get_logits_processors
@@ -144,6 +145,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: Optional[bool] = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     temperature_last: Optional[bool] = False
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
@@ -261,7 +263,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             best_of=self.best_of,
             top_k=self.top_k,
             top_a=self.top_a,
@@ -384,6 +387,7 @@ class CompletionRequest(OpenAIBaseModel):
     include_stop_str_in_output: Optional[bool] = False
     add_special_tokens: Optional[bool] = False
     temperature_last: Optional[bool] = False
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -469,9 +473,10 @@ class CompletionRequest(OpenAIBaseModel):
             max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             logprobs=self.logprobs,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
@@ -550,6 +555,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
             "to stop, None if the completion finished for some other reason "
             "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class CompletionResponse(OpenAIBaseModel):
@@ -645,6 +651,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
 
 
 class DeltaMessage(OpenAIBaseModel):
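
The two `to_sampling_params()` hunks above implement the same precedence: an explicit `prompt_logprobs` wins, otherwise the pre-existing `echo` fallback applies. Note that both use a truthiness check, so an explicit `0` also falls through to the fallback. A standalone restatement (not from the commit) of the completion-endpoint logic:

```python
from typing import Optional


def resolve_prompt_logprobs(prompt_logprobs: Optional[int],
                            logprobs: Optional[int],
                            echo: bool) -> Optional[int]:
    """Mirror of the precedence in CompletionRequest.to_sampling_params()."""
    if prompt_logprobs:  # truthiness: 0 falls through, matching the diff
        return prompt_logprobs
    return logprobs if echo else None


assert resolve_prompt_logprobs(3, None, False) == 3   # explicit value wins
assert resolve_prompt_logprobs(None, 5, True) == 5    # echo fallback
assert resolve_prompt_logprobs(None, 5, False) is None
assert resolve_prompt_logprobs(0, 5, True) == 5       # 0 is treated as unset
```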

aphrodite/endpoints/openai/serving_chat.py (+11 -0)

@@ -79,6 +79,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
@@ -496,6 +506,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )
 
         return response
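
The validation above rejects streaming requests that ask for prompt logprobs, since the values are only attached to the final, fully aggregated response. A sketch of what a client would see (base URL and model name are assumptions):

```python
import requests

resp = requests.post(
    "http://localhost:2242/v1/chat/completions",  # assumed local server
    json={
        "model": "my-model",   # hypothetical model name
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
        "prompt_logprobs": 1,  # > 0 while streaming -> rejected
    },
)
# Expect an OpenAI-style error payload, roughly:
# {"error": {"message": "Prompt_logprobs are not available when stream "
#                       "is enabled", ...}}
print(resp.json())
```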

aphrodite/endpoints/openai/serving_completions.py (+10 -0)

@@ -76,6 +76,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -357,6 +366,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     logprobs=logprobs,
                     finish_reason=output.finish_reason,
                     stop_reason=output.stop_reason,
+                    prompt_logprobs=final_res.prompt_logprobs,
                 )
                 choices.append(choice_data)
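
For reference, the `final_res.prompt_logprobs` attached to each choice above has the type declared in protocol.py, `List[Optional[Dict[int, Logprob]]]`: one entry per prompt token, with `None` for the first token, which has no preceding context to condition on. A sketch of how a serialized entry might look and be consumed (field values are illustrative, and the exact JSON serialization of `Logprob` is an assumption here):

```python
# Illustrative shape only; values are made up. In JSON, the integer token-ID
# keys come back as strings, and each Logprob is assumed to serialize with
# `logprob` plus optional `rank` and `decoded_token` fields.
example_choice = {
    "text": " world",
    "prompt_logprobs": [
        None,  # first prompt token has no logprob
        {"15496": {"logprob": -2.31, "rank": 4, "decoded_token": "Hello"}},
    ],
}

for position, entry in enumerate(example_choice["prompt_logprobs"]):
    if entry is None:
        continue
    for token_id, lp in entry.items():
        print(position, token_id, lp["logprob"], lp["decoded_token"])
```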