
Merge pull request #27 from StefanGliga/samplers-next

feat: Add Eta, Epsilon and Locally Typical sampling
Stefan Gligorijevic 1 year ago
parent
commit
985503e899

+ 27 - 1
aphrodite/common/sampling_params.py

@@ -50,7 +50,18 @@ class SamplingParams:
        top_a: Float that controls the cutoff for Top-A sampling.
            Exact cutoff is top_a*max_prob**2. Must be in [0,inf], 0 to disable.
        tfs: Float that controls the cumulative approximate curvature of the
-            distribution to retain for Tail Free Sampling
+            distribution to retain for Tail Free Sampling.
+            Must be in (0, 1]. Set to 1 to disable.
+        eta_cutoff: Float that controls the cutoff threshold for Eta sampling
+            (a form of entropy-adaptive truncation sampling). The threshold
+            is computed as min(eta, sqrt(eta) * exp(-entropy(probs))).
+            Specified in units of 1e-4. Set to 0 to disable.
+        epsilon_cutoff: Float that controls the cutoff threshold for Epsilon
+            sampling (simple probability threshold truncation).
+            Specified in units of 1e-4. Set to 0 to disable.
+        typical_p: Float that controls the cumulative probability of tokens
+            closest in surprise to the expected surprise to consider.
+            Must be in (0, 1]. Set to 1 to disable.
        use_beam_search: Whether to use beam search instead of sampling.
        length_penalty: Float that penalizes sequences based on their length.
            Used in beam search.
@@ -88,6 +99,9 @@ class SamplingParams:
        top_k: int = -1,
        top_a: float = 0.0,
        tfs: float = 1.0,
+        eta_cutoff: float = 0.0,
+        epsilon_cutoff: float = 0.0,
+        typical_p: float = 1.0,
        use_beam_search: bool = False,
        length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
@@ -109,6 +123,9 @@ class SamplingParams:
        self.top_k = top_k
        self.top_a = top_a
        self.tfs = tfs
+        self.eta_cutoff = eta_cutoff
+        self.epsilon_cutoff = epsilon_cutoff
+        self.typical_p = typical_p
        self.use_beam_search = use_beam_search
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
@@ -164,6 +181,12 @@ class SamplingParams:
             raise ValueError(f"top_a must be in [0, 1], got {self.top_a}.")
             raise ValueError(f"top_a must be in [0, 1], got {self.top_a}.")
         if not 0.0 < self.tfs <= 1.0:
         if not 0.0 < self.tfs <= 1.0:
             raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
             raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
+        if not 0.0 <= self.epsilon_cutoff <= 1000.0:
+            raise ValueError(f"epsilon_cutoff must be in [0, 1000], got {self.epsilon_cutoff}.")
+        if not self.eta_cutoff >= 0:
+            raise ValueError(f"eta_cutoff must be non negative, got {self.eta_cutoff}.")
+        if not 0.0 <= self.typical_p <= 1.0:
+            raise ValueError(f"typical_p must be in (0, 1], got {self.typical_p}.")
         if self.max_tokens < 1:
         if self.max_tokens < 1:
             raise ValueError(
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -224,6 +247,9 @@ class SamplingParams:
                 f"top_k={self.top_k}, "
                 f"top_k={self.top_k}, "
                 f"top_a={self.top_a}, "
                 f"top_a={self.top_a}, "
                 f"tfs={self.tfs}, "
                 f"tfs={self.tfs}, "
+                f"eta_cutoff={self.eta_cutoff}, "
+                f"epsilon_cutoff={self.epsilon_cutoff}, "
+                f"typical_p={self.typical_p}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
                 f"early_stopping={self.early_stopping}, "

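For reference, a minimal usage sketch of the new parameters (a sketch only: the aphrodite LLM entry point is assumed, and the model name and cutoff values are illustrative, not defaults from this diff):

    from aphrodite import LLM, SamplingParams

    # eta_cutoff and epsilon_cutoff are specified in units of 1e-4,
    # so epsilon_cutoff=9.0 drops tokens with probability below 9e-4.
    params = SamplingParams(
        temperature=1.0,
        eta_cutoff=300.0,    # entropy-adaptive truncation; 0.0 disables
        epsilon_cutoff=9.0,  # fixed probability floor; 0.0 disables
        typical_p=0.95,      # locally typical sampling; 1.0 disables
    )
    llm = LLM(model="some/model")  # placeholder model name
    print(llm.generate("The quick brown fox", params))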
+ 3 - 0
aphrodite/endpoints/openai/api_server.py

@@ -222,6 +222,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
            temperature=request.temperature,
            top_p=request.top_p,
            tfs=request.tfs,
+            eta_cutoff=request.eta_cutoff,
+            epsilon_cutoff=request.epsilon_cutoff,
+            typical_p=request.typical_p,
            stop=request.stop,
            stop_token_ids=request.stop_token_ids,
            max_tokens=request.max_tokens,

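On the API side, the new fields simply ride along in the request body. A hedged request sketch (the local host/port and the chat completions route are assumptions about the server's defaults, and the field values are illustrative):

    import requests

    resp = requests.post(
        "http://localhost:2242/v1/chat/completions",  # port is an assumption
        json={
            "model": "some/model",  # placeholder
            "messages": [{"role": "user", "content": "Say hi."}],
            "eta_cutoff": 300.0,     # units of 1e-4; 0 disables
            "epsilon_cutoff": 9.0,   # units of 1e-4; 0 disables
            "typical_p": 0.95,       # 1.0 disables
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])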
+ 6 - 0
aphrodite/endpoints/openai/protocol.py

@@ -58,6 +58,9 @@ class ChatCompletionRequest(BaseModel):
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
+    eta_cutoff: Optional[float] = 0.0
+    epsilon_cutoff: Optional[float] = 0.0
+    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -83,6 +86,9 @@ class CompletionRequest(BaseModel):
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
+    eta_cutoff: Optional[float] = 0.0
+    epsilon_cutoff: Optional[float] = 0.0
+    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None

+ 3 - 0
aphrodite/endpoints/protocol.py

@@ -10,6 +10,9 @@ class SamplingParams(BaseModel):
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    tfs: float = Field(1.0, alias="tfs")
+    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
+    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
+    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")

+ 109 - 0
aphrodite/modeling/layers/sampler.py

@@ -59,12 +59,29 @@ class Sampler(nn.Module):

        logits = _apply_logits_processors(input_metadata, logits, output_tokens)

+        # Apply Eta sampling, as described in https://arxiv.org/abs/2210.15191
+        eta_cutoffs = _get_eta_cutoffs(input_metadata)
+        assert len(eta_cutoffs) == logits.shape[0]
+        if any(eta > _SAMPLING_EPS for eta in eta_cutoffs):
+            logits = _apply_eta_cutoff(logits, eta_cutoffs)
+
+        # Apply Locally typical sampling, as described in https://arxiv.org/abs/2202.00666
+        typical_ps = _get_typical_ps(input_metadata)
+        assert len(typical_ps) == logits.shape[0]
+        if any(typ_p < 1.0 - _SAMPLING_EPS for typ_p in typical_ps):
+            logits = _apply_typical_sampling(logits, typical_ps)
+
        # Apply Tail Free Sampling, as described in https://www.trentonbricken.com/Tail-Free-Sampling/
        tfss = _get_tfs(input_metadata)
        assert len(tfss) == logits.shape[0]
        if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
            logits = _apply_tfs(logits, tfss)

+        epsilon_cutoffs = _get_epsilon_cutoffs(input_metadata)
+        assert len(epsilon_cutoffs) == logits.shape[0]
+        if any(epsilon > _SAMPLING_EPS for epsilon in epsilon_cutoffs):
+            logits = _apply_epsilon_cutoff(logits, epsilon_cutoffs)
+
        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
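Both new cutoffs implement truncation sampling from https://arxiv.org/abs/2210.15191: epsilon is a fixed probability floor, while eta adapts the floor to the entropy of the distribution, eps = min(eta, sqrt(eta) * exp(-H(p))), so flat (high-entropy) distributions get a lower cutoff than peaked ones. A standalone toy computation of the eta threshold (illustrative, not part of the diff):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -3.0]])
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()

    eta = torch.tensor([300.0]) * 1e-4                # user-facing 300 -> 0.03
    neg_entropy = (probs * log_probs).nansum(dim=-1)  # -H(p)
    eps = torch.min(eta, torch.sqrt(eta) * torch.exp(neg_entropy))
    print(probs, eps)  # tokens with prob < eps get masked to -inf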
@@ -285,6 +302,33 @@ def _get_tfs(input_metadata: InputMetadata) -> List[float]:
    return tfss


+def _get_eta_cutoffs(input_metadata: InputMetadata) -> List[float]:
+    eta_cutoffs: List[float] = []
+    for seq_group in input_metadata.seq_groups:
+        seq_ids, sampling_params = seq_group
+        eta_cutoff = sampling_params.eta_cutoff
+        eta_cutoffs += [eta_cutoff] * len(seq_ids)
+    return eta_cutoffs
+
+
+def _get_epsilon_cutoffs(input_metadata: InputMetadata) -> List[float]:
+    epsilon_cutoffs: List[float] = []
+    for seq_group in input_metadata.seq_groups:
+        seq_ids, sampling_params = seq_group
+        epsilon_cutoff = sampling_params.epsilon_cutoff
+        epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
+    return epsilon_cutoffs
+
+
+def _get_typical_ps(input_metadata: InputMetadata) -> List[float]:
+    typical_ps: List[float] = []
+    for seq_group in input_metadata.seq_groups:
+        seq_ids, sampling_params = seq_group
+        typical_p = sampling_params.typical_p
+        typical_ps += [typical_p] * len(seq_ids)
+    return typical_ps
+
+
def _apply_top_a_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
@@ -347,6 +391,71 @@ def _apply_tfs(
    return logits


+
+def _apply_eta_cutoff(
+    logits: torch.Tensor,
+    eta_cutoffs: List[float],
+) -> torch.Tensor:
+    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
+    shifted_logits = torch.log_softmax(logits, dim=-1)
+    probs = shifted_logits.exp()
+
+    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
+    eps = torch.min(eta, torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
+
+    eta_mask = probs < eps
+
+    if torch.all(eta_mask):  # guard against nulling out all the logits
+        topk_prob, _ = torch.max(probs, dim=-1, keepdim=True)
+        eta_mask = probs < topk_prob
+
+    logits[eta_mask] = -float("inf")
+    return logits
+
+
+def _apply_epsilon_cutoff(
+    logits: torch.Tensor,
+    epsilon_cutoffs: List[float],
+) -> torch.Tensor:
+    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
+    probs = logits.softmax(dim=-1)
+
+    eps_mask = probs < (eps * 1e-4)
+
+    if torch.all(eps_mask):  # guard against nulling out all the logits
+        topk_prob, _ = torch.max(probs, dim=-1, keepdim=True)
+        eps_mask = probs < topk_prob
+
+    logits[eps_mask] = -float("inf")
+    return logits
+
+
+def _apply_typical_sampling(
+    logits: torch.Tensor,
+    typical_ps: List[float],
+) -> torch.Tensor:
+    typ_p = torch.tensor(typical_ps, dtype=logits.dtype, device=logits.device)
+    shifted_logits = torch.log_softmax(logits, dim=-1)
+    probs = shifted_logits.exp()
+
+    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
+
+    surprisal_deviations = (neg_entropy - shifted_logits).abs()
+    _, indices = torch.sort(surprisal_deviations)
+    reordered_probs = probs.gather(-1, indices)
+    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
+
+    min_tokens_to_keep = 1
+    # Keep at least min_tokens_to_keep
+    typ_mask_sorted[..., :min_tokens_to_keep] = 0
+
+    typ_mask = typ_mask_sorted.scatter(
+        1, indices, typ_mask_sorted
+    )
+    logits[typ_mask] = -float("inf")
+    return logits
+
+
def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
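To see the typical_p transform end to end, here is a self-contained mirror of _apply_typical_sampling on a toy vocabulary (illustrative only; it reimplements the function above rather than importing it):

    import torch

    def typical_mask(logits: torch.Tensor, typ_p: float) -> torch.Tensor:
        # Rank tokens by |surprisal - expected surprisal| and mask the
        # tail once cumulative probability reaches typ_p.
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        neg_entropy = (probs * log_probs).nansum(dim=-1, keepdim=True)
        deviations = (neg_entropy - log_probs).abs()
        _, indices = torch.sort(deviations)
        reordered = probs.gather(-1, indices)
        mask_sorted = reordered.cumsum(dim=-1) >= typ_p
        mask_sorted[..., :1] = False  # always keep the most typical token
        return mask_sorted.scatter(-1, indices, mask_sorted)

    logits = torch.tensor([[3.0, 2.5, 1.0, 0.0, -2.0]])
    print(typical_mask(logits, 0.9))  # True marks tokens excluded from sampling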