|
@@ -51,6 +51,7 @@ class Sampler(nn.Module):
|
|
|
# containing the sampled token ids and probabilities. This is used by
|
|
|
# speculative decoding.
|
|
|
self.include_gpu_probs_tensor = False
|
|
|
+ self.should_modify_greedy_probs_inplace = False
|
|
|
|
|
|
def _init_sampling_tensors(
|
|
|
self,
|
|
@@ -222,8 +223,7 @@ class Sampler(nn.Module):
|
|
|
This is used by speculative decoding, which requires that the sampling
|
|
|
method be encoded into the probability distribution.
|
|
|
"""
|
|
|
- # Modify greedy probs if include_gpu_probs_tensor is set.
|
|
|
- return self.include_gpu_probs_tensor
|
|
|
+ return self.should_modify_greedy_probs_inplace
|
|
|
|
|
|
|
|
|
def _get_bin_counts_and_mask(
|