
Fix the logits processor for logit_bias in the OpenAI endpoints.

50h100a 11 months ago
parent
commit
35b4aa7da5

+ 17 - 25
aphrodite/endpoints/openai/protocol.py

@@ -5,10 +5,10 @@ from typing import Dict, List, Literal, Optional, Union
 
 from pydantic import (AliasChoices, BaseModel, Field, model_validator,
                       root_validator)
-import torch
 
 from aphrodite.common.utils import random_uuid
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.logits_processor import BiasLogitsProcessor
 
 
 class ErrorResponse(BaseModel):
@@ -103,23 +103,18 @@ class ChatCompletionRequest(BaseModel):
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
 
-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
 
-        logits_processors = None
+        logits_processors = []
         if self.logit_bias:
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in self.logit_bias.items():
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    bias = min(100, max(-100, bias))
-                    logits[int(token_id)] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+            biases = {
+                int(tok): min(100, max(float(bias), -100))
+                for tok, bias in self.logit_bias.items()
+                if 0 < int(tok) < vocab_size
+            }
+            logits_processors.append(BiasLogitsProcessor(biases))
 
         return SamplingParams(
             n=self.n,
@@ -230,21 +225,18 @@ class CompletionRequest(BaseModel):
     guided_regex: Optional[str] = None
     guided_choice: Optional[List[str]] = None
 
-    def to_sampling_params(self) -> SamplingParams:
+    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
         echo_without_generation = self.echo and self.max_tokens == 0
 
-        logits_processors = None
+        logits_processors = []
         if self.logit_bias:
+            biases = {
+                int(tok): min(100, max(float(bias), -100))
+                for tok, bias in self.logit_bias.items()
+                if 0 < int(tok) < vocab_size
+            }
+            logits_processors.append(BiasLogitsProcessor(biases))
 
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in self.logit_bias.items():
-                    bias = min(100, max(-100, bias))
-                    logits[int(token_id)] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
         return SamplingParams(
             n=self.n,
             max_tokens=self.max_tokens if not echo_without_generation else 1,

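The diff replaces the per-request inline closure with a shared BiasLogitsProcessor imported from aphrodite.common.logits_processor. That class is not shown in this commit; the sketch below is an assumption of its minimal shape, reproducing the additive behavior of the removed closure with the same (token_ids, logits) call signature:

import torch
from typing import Dict, List

class BiasLogitsProcessor:
    """Adds a fixed bias to selected token logits at every sampling step.

    Minimal sketch; the real class in aphrodite/common/logits_processor.py
    may differ. Biases are assumed pre-clamped to [-100, 100] by the caller.
    """

    def __init__(self, biases: Dict[int, float]):
        self.biases = biases  # token id -> bias

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # token_ids (the tokens generated so far) are unused here, but the
        # engine invokes every logits processor with this signature.
        for token_id, bias in self.biases.items():
            logits[token_id] += bias
        return logits
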
+ 2 - 1
aphrodite/endpoints/openai/serving_chat.py

@@ -62,7 +62,8 @@ class OpenAIServingChat(OpenAIServing):
         try:
             token_ids = self._validate_prompt_and_tokenize(request,
                                                            prompt=prompt)
-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(
+                self.tokenizer.vocab_size)
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(

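Both serving layers now pass the tokenizer's vocabulary size so protocol.py can drop out-of-range token ids before building the processor. A hypothetical client-side request that exercises the new path (model name, API key, and port are placeholders; 2242 is assumed as the server port):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Say hello"}],
    # Keys are token ids as strings, per the OpenAI API. The server now
    # clamps each value to [-100, 100], so 250 is applied as 100.
    logit_bias={"1234": 250},
)
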
+ 2 - 1
aphrodite/endpoints/openai/serving_completions.py

@@ -122,7 +122,8 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators = []
         try:
-            sampling_params = request.to_sampling_params()
+            sampling_params = request.to_sampling_params(
+                self.tokenizer.vocab_size)
             lora_request = self._maybe_get_lora(request)
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
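To see what the dict comprehension introduced in protocol.py does, here is a self-contained illustration with made-up token ids and an arbitrary vocabulary size:

vocab_size = 32000  # illustrative; the real value comes from the tokenizer
logit_bias = {"0": 5.0, "5": 250.0, "7": -512.0, "999999": 10.0}
biases = {
    int(tok): min(100, max(float(bias), -100))
    for tok, bias in logit_bias.items()
    if 0 < int(tok) < vocab_size
}
# 250 clamps to 100 and -512 clamps to -100; id 999999 is out of range,
# and because the lower bound is strict, token id 0 is filtered out too.
assert biases == {5: 100.0, 7: -100.0}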