|
@@ -6,7 +6,7 @@ from typing import Dict
|
|
class LogitsProcessor(ABC):
|
|
class LogitsProcessor(ABC):
|
|
|
|
|
|
@abstractmethod
|
|
@abstractmethod
|
|
- def __call__(self, logits: torch.tensor) -> torch.tensor:
|
|
|
|
|
|
+ def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
|
|
@@ -30,7 +30,7 @@ class BiasLogitsProcessor(LogitsProcessor):
|
|
self.values = torch.tensor(list(self.biases.values()),
|
|
self.values = torch.tensor(list(self.biases.values()),
|
|
dtype=torch.long)
|
|
dtype=torch.long)
|
|
|
|
|
|
- def __call__(self, logits):
|
|
|
|
|
|
+ def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> torch.Tensor:
|
|
if not self.biases:
|
|
if not self.biases:
|
|
return logits
|
|
return logits
|
|
|
|
|