Bläddra i källkod

fix: validate `n` in the sampling params (#1075)

AlpinDale 2 månader sedan
förälder
incheckning
814c850d89
1 ändrade filer med 6 tillägg och 1 borttagningar
  1. 6 1
      aphrodite/common/sampling_params.py

+ 6 - 1
aphrodite/common/sampling_params.py

@@ -397,9 +397,14 @@ class SamplingParams(
         self._all_stop_token_ids = set(self.stop_token_ids)
 
     def _verify_args(self) -> None:
+        if not isinstance(self.n, int):
+            raise ValueError(f"n must be an int, but is of "
+                             f"type {type(self.n)}")
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
-        assert isinstance(self.best_of, int)
+        if not isinstance(self.best_of, int):
+            raise ValueError(f'best_of must be an int, but is of '
+                             f'type {type(self.best_of)}')
         if self.best_of < self.n:
             raise ValueError(f"best_of must be greater than or equal to n, "
                              f"got n={self.n} and best_of={self.best_of}.")