Prechádzať zdrojové kódy

Fix LogitProcessor infrastructure (#26)

50h100a 1 rok pred
rodič
commit
a86b934469
1 zmenil súbory, kde vykonal 2 pridanie a 2 odobranie
  1. 2 2
      aphrodite/common/logits.py

+ 2 - 2
aphrodite/common/logits.py

@@ -6,7 +6,7 @@ from typing import Dict
 class LogitsProcessor(ABC):
 
     @abstractmethod
-    def __call__(self, logits: torch.tensor) -> torch.tensor:
+    def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
         pass
 
 
@@ -30,7 +30,7 @@ class BiasLogitsProcessor(LogitsProcessor):
         self.values = torch.tensor(list(self.biases.values()),
                                    dtype=torch.long)
 
-    def __call__(self, logits):
+    def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
         if not self.biases:
             return logits