
Merge branch 'new_samplers' of https://github.com/50h100a/aphrodite-engine into new_samplers

50h100a 1 year ago
parent
commit
633b99d266

+ 7 - 0
aphrodite/common/sampling_params.py

@@ -42,6 +42,8 @@ class SamplingParams:
             to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
+        tfs: Float that controls the cumulative approximate curvature of the
+            distribution to retain for Tail Free Sampling. Must be in (0, 1].
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -78,6 +80,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         top_a: float = 0.0,
+        tfs: float = 1.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -98,6 +101,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.top_a = top_a
+        self.tfs = tfs
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -151,6 +155,8 @@ class SamplingParams:
                              f"got {self.top_k}.")
         if not 0.0 <= self.top_a <= 1.0:
             raise ValueError(f"top_a must be in [0, 1], got {self.top_a}.")
+        if not 0.0 < self.tfs <= 1.0:
+            raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -210,6 +216,7 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"top_a={self.top_a}, "
+                f"tfs={self.tfs}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "

+ 2 - 0
aphrodite/endpoints/openai/api_server.py

@@ -221,6 +221,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             frequency_penalty=request.frequency_penalty,
             temperature=request.temperature,
             top_p=request.top_p,
+            tfs=request.tfs,
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             max_tokens=request.max_tokens,
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             frequency_penalty=request.frequency_penalty,
             temperature=request.temperature,
             top_p=request.top_p,
+            tfs=request.tfs,
             top_k=request.top_k,
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
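
Client-side, the forwarded field rides along in the JSON body like any other sampling knob. A hedged sketch (host, port, and model name are placeholders, not values from this diff):

    import requests

    payload = {
        "model": "my-model",            # placeholder
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "tfs": 0.9,                     # forwarded into SamplingParams.tfs above
    }
    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    print(resp.json())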

+ 2 - 0
aphrodite/endpoints/openai/protocol.py

@@ -57,6 +57,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    tfs: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -81,6 +82,7 @@ class CompletionRequest(BaseModel):
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
+    tfs: Optional[float] = 1.0
     n: Optional[int] = 1
     stream: Optional[bool] = False
     logprobs: Optional[int] = None

+ 2 - 1
aphrodite/endpoints/protocol.py

@@ -9,11 +9,12 @@ class SamplingParams(BaseModel):
     temperature: float = Field(1.0, alias="temperature")
     top_p: float = Field(1.0, alias="top_p")
     top_k: float = Field(-1, alias="top_k")
+    tfs: float = Field(1.0, alias="tfs")
     use_beam_search: bool = Field(False, alias="use_beam_search")
     length_penalty: float = Field(1.0, alias="length_penalty")
     early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
     stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
-    ignore_eos: bool Field(False, alias="ignore_eos")
+    ignore_eos: bool = Field(False, alias="ignore_eos")
     max_tokens: int = Field(16, alias="max_length")
     logprobs: Optional[int] = Field(None, alias="logprobs")
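
Besides threading tfs through, this hunk fixes a pre-existing syntax error: "ignore_eos: bool Field(...)" is not valid Python, since a pydantic field must be assigned with "=". A toy standalone check of the corrected pattern (class name hypothetical):

    from pydantic import BaseModel, Field

    class Params(BaseModel):
        ignore_eos: bool = Field(False, alias="ignore_eos")

    print(Params().ignore_eos)  # False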
 

+ 44 - 0
aphrodite/modeling/layers/sampler.py

@@ -59,6 +59,12 @@ class Sampler(nn.Module):
         
         logits = _apply_logits_processors(input_metadata, logits, output_tokens)
 
+        # Apply Tail Free Sampling, as described in https://www.trentonbricken.com/Tail-Free-Sampling/
+        tfss = _get_tfs(input_metadata)
+        assert len(tfss) == logits.shape[0]
+        if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
+            logits = _apply_tfs(logits, tfss)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
@@ -269,6 +275,16 @@ def _get_top_a_top_p_top_k(
     return top_ps, top_ks, top_as
 
 
+
+def _get_tfs(input_metadata: InputMetadata) -> List[float]:
+    tfss: List[float] = []
+    for seq_group in input_metadata.seq_groups:
+        seq_ids, sampling_params = seq_group
+        z = sampling_params.tfs
+        tfss += [z] * len(seq_ids)
+    return tfss
+
+
 def _apply_top_a_top_p_top_k(
     logits: torch.Tensor,
     top_ps: List[float],
@@ -302,6 +318,34 @@ def _apply_top_a_top_p_top_k(
                           index=torch.argsort(logits_idx, dim=-1))
     return logits
 
+def _apply_tfs(
+    logits: torch.Tensor,
+    tfss: List[float],
+) -> torch.Tensor:
+    z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
+    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
+    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
+
+    tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)
+
+    tfs_mask = torch.cat(
+            (
+                torch.zeros(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+                tfs_mask,
+                torch.ones(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+            ),
+            dim=-1,
+        )
+
+    logits_sort[tfs_mask] = -float("inf")
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+
+    return logits
+
 
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
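
To make the masking above concrete, here is a self-contained toy walk-through of the curvature CDF on a five-token vocabulary (torch only; the numbers are illustrative, not from the diff). The absolute second difference of the sorted softmax approximates curvature; once its normalized running sum exceeds z, the rest of the tail is masked, and the zeros/ones padding mirrors _apply_tfs: the top token is always kept and the last always dropped.

    import torch

    z = 0.95
    logits = torch.tensor([[5.0, 3.0, 2.0, 1.0, 0.0]])  # already sorted descending
    probs = logits.softmax(dim=-1)
    d2 = probs.diff().diff().abs()                      # shape [1, 3]
    cdf = torch.cumsum(d2 / d2.sum(dim=-1, keepdim=True), dim=-1)

    # Pad like _apply_tfs: never drop the first token, always drop the last.
    drop = torch.cat((torch.zeros(1, 1, dtype=torch.bool),
                      cdf > z,
                      torch.ones(1, 1, dtype=torch.bool)), dim=-1)
    print(logits.masked_fill(drop, float("-inf")))      # masked tokens can't be sampled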