Kaynağa Gözat

Merge pull request #24 from 50h100a/new_samplers

feat: Added top_a and repetition_penalty samplers.
Stefan Gligorijevic 1 yıl önce
ebeveyn
işleme
e107f0f009

+ 11 - 0
aphrodite/common/sampling_params.py

@@ -75,9 +75,11 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        top_a: float = 0.0,
         tfs: float = 1.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
@@ -94,9 +96,11 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.top_a = top_a
         self.tfs = tfs
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
@@ -138,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 1.0 <= self.repetition_penalty:
+            raise ValueError("repetition_penalty must be in [1, inf), got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
@@ -146,6 +153,8 @@ class SamplingParams:
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
+        if not 0.0 <= self.top_a <= 1.0:
+            raise ValueError(f"top_a must be in [0, 1], got {self.top_a}.")
         if not 0.0 < self.tfs <= 1.0:
             raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
         if self.max_tokens < 1:
@@ -202,9 +211,11 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
+                f"top_a={self.top_a}, "
                 f"tfs={self.tfs}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "

+ 52 - 30
aphrodite/modeling/layers/sampler.py

@@ -50,12 +50,12 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+        logits = _apply_penalties(logits, output_tokens,
+                                  presence_penalties, frequency_penalties, repetition_penalties,
+                                  self.vocab_size)
         
         logits = _apply_logits_processors(input_metadata, logits, output_tokens)
 
@@ -75,13 +75,14 @@ class Sampler(nn.Module):
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))
 
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        # Apply top-p, top-k, and top-a truncation.
+        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        do_top_a = any(a > _SAMPLING_EPS for a in top_as)
+        if do_top_p or do_top_k or do_top_a:
+            logits = _apply_top_a_top_p_top_k(logits, top_ps, top_ks, top_as)
 
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
@@ -142,13 +143,13 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
+        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
+        repetition_penalties += [sampling_params.repetition_penalty] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties
 
 
 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
@@ -180,14 +181,16 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
+    vocab_size: int,
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        if (abs(presence_penalties[i]) < _SAMPLING_EPS and
+            abs(frequency_penalties[i]) < _SAMPLING_EPS and
+            repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
             continue
         break
     else:
@@ -218,11 +221,21 @@ def _apply_penalties(
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
+    repetition_penalties = torch.tensor(repetition_penalties,
+                                      dtype=logits.dtype,
+                                      device=logits.device)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    presence_mask = (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * presence_mask
+
+    # Effectively: If token is present and logit is positive, divide logit by rep_pen.
+    #              If token is present and logit is negative, multiply logit by rep_pen.
+    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
+    logits += logits * (repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)
+
     return logits
 
 
@@ -241,22 +254,26 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures
 
 
-def _get_top_p_top_k(
+def _get_top_a_top_p_top_k(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    top_as: List[float] = []
     for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        top_ps += [top_p] * len(seq_ids)
+
+        top_ps += [sampling_params.top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        top_as += [sampling_params.top_a] * len(seq_ids)
+
+    return top_ps, top_ks, top_as
+
 
 
 def _get_tfs(input_metadata: InputMetadata) -> List[float]:
@@ -268,26 +285,31 @@ def _get_tfs(input_metadata: InputMetadata) -> List[float]:
     return tfss
 
 
-def _apply_top_p_top_k(
+def _apply_top_a_top_p_top_k(
     logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
+    top_as: List[float],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    ts_p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    ts_k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    ts_a = torch.tensor(top_as, dtype=logits.dtype, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
-    # Apply top-p.
+    # Apply top-p and top-a.
     probs_sort = logits_sort.softmax(dim=-1)
     probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
-
+    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * ts_a
+    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1)) # Cull logits below the top-a threshold
+    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(dim=1)) # Cull logits above the top-p summation threshold
+    top_ap_mask[:, 0] = False # Guarantee at least one token is pickable
+    logits_sort[top_ap_mask] = -float("inf")
+    
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    top_k_mask = top_k_mask >= ts_k.unsqueeze(dim=1)
     logits_sort[top_k_mask] = -float("inf")
 
     # Re-sort the probabilities.