浏览代码

Fix LogitProcessor infrastructure (#26)

50h100a 1 年之前
父节点
当前提交
a86b934469
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      aphrodite/common/logits.py

+ 2 - 2
aphrodite/common/logits.py

@@ -6,7 +6,7 @@ from typing import Dict
 class LogitsProcessor(ABC):
 class LogitsProcessor(ABC):
 
 
     @abstractmethod
     @abstractmethod
-    def __call__(self, logits: torch.tensor) -> torch.tensor:
+    def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
         pass
         pass
 
 
 
 
@@ -30,7 +30,7 @@ class BiasLogitsProcessor(LogitsProcessor):
         self.values = torch.tensor(list(self.biases.values()),
         self.values = torch.tensor(list(self.biases.values()),
                                    dtype=torch.long)
                                    dtype=torch.long)
 
 
-    def __call__(self, logits):
+    def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
         if not self.biases:
         if not self.biases:
             return logits
             return logits