|
@@ -70,13 +70,13 @@ class Sampler(nn.Module):
|
|
logits.div_(t.unsqueeze(dim=1))
|
|
logits.div_(t.unsqueeze(dim=1))
|
|
|
|
|
|
# Apply top-p, top-k, and top-a truncation.
|
|
# Apply top-p, top-k, and top-a truncation.
|
|
- top_ps, top_ks, top_as = _get_top_ap_top_k(input_metadata, self.vocab_size)
|
|
|
|
|
|
+ top_ps, top_ks, top_as = _get_top_a_top_p_top_k(input_metadata, self.vocab_size)
|
|
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
|
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
|
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
|
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
|
do_top_k = any(k != self.vocab_size for k in top_ks)
|
|
do_top_k = any(k != self.vocab_size for k in top_ks)
|
|
do_top_a = any(a > _SAMPLING_EPS for a in top_as)
|
|
do_top_a = any(a > _SAMPLING_EPS for a in top_as)
|
|
if do_top_p or do_top_k or do_top_a:
|
|
if do_top_p or do_top_k or do_top_a:
|
|
- logits = _apply_top_ap_top_k(logits, top_ps, top_ks, top_as)
|
|
|
|
|
|
+ logits = _apply_top_a_top_p_top_k(logits, top_ps, top_ks, top_as)
|
|
|
|
|
|
# We use float32 for probabilities and log probabilities.
|
|
# We use float32 for probabilities and log probabilities.
|
|
# Compute the probabilities.
|
|
# Compute the probabilities.
|
|
@@ -248,7 +248,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
|
return temperatures
|
|
return temperatures
|
|
|
|
|
|
|
|
|
|
-def _get_top_ap_top_k(
|
|
|
|
|
|
+def _get_top_a_top_p_top_k(
|
|
input_metadata: InputMetadata,
|
|
input_metadata: InputMetadata,
|
|
vocab_size: int,
|
|
vocab_size: int,
|
|
) -> Tuple[List[float], List[int], List[float]]:
|
|
) -> Tuple[List[float], List[int], List[float]]:
|
|
@@ -269,7 +269,7 @@ def _get_top_ap_top_k(
|
|
return top_ps, top_ks, top_as
|
|
return top_ps, top_ks, top_as
|
|
|
|
|
|
|
|
|
|
-def _apply_top_ap_top_k(
|
|
|
|
|
|
+def _apply_top_a_top_p_top_k(
|
|
logits: torch.Tensor,
|
|
logits: torch.Tensor,
|
|
top_ps: List[float],
|
|
top_ps: List[float],
|
|
top_ks: List[int],
|
|
top_ks: List[int],
|