Ver código fonte

`entropy_deviation` -> `surprisal_deviation`

AlpinDale 1 ano atrás
pai
commit
022380e896
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      aphrodite/modeling/layers/sampler.py

+ 2 - 2
aphrodite/modeling/layers/sampler.py

@@ -418,7 +418,7 @@ def _apply_typical_sampling(
     neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
     neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
 
 
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
     surprisal_deviations = (neg_entropy - shifted_logits).abs()
-    _, indices = torch.sort(entropy_deviation)
+    _, indices = torch.sort(surprisal_deviation)
     reordered_probs = probs.gather(-1, sorted_indices)
     reordered_probs = probs.gather(-1, sorted_indices)
     typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
     typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_ps
     
     
@@ -667,4 +667,4 @@ def _sample(
         assert sample_idx == num_tokens
         assert sample_idx == num_tokens
         category_start_idx += num_tokens
         category_start_idx += num_tokens
 
 
-    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]