|
@@ -400,6 +400,7 @@ def _apply_temperature(
|
|
normalized_entropies.pow_(dynatemp_exps))
|
|
normalized_entropies.pow_(dynatemp_exps))
|
|
|
|
|
|
temperatures[dynatemp_mask] = dyn_temp
|
|
temperatures[dynatemp_mask] = dyn_temp
|
|
|
|
+ temperatures[temperatures == 0.0] = 1.0
|
|
logits.div_(temperatures.unsqueeze_(dim=1))
|
|
logits.div_(temperatures.unsqueeze_(dim=1))
|
|
return logits
|
|
return logits
|
|
|
|
|
|
@@ -556,12 +557,10 @@ def _sample(
|
|
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
|
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
|
sample_metadata = {}
|
|
sample_metadata = {}
|
|
|
|
|
|
- # Counterintiutively, having two loops here is actually faster.
|
|
|
|
|
|
+ # Counterintuitively, having two loops here is actually faster.
|
|
# The first loop can run without waiting on GPU<->CPU sync.
|
|
# The first loop can run without waiting on GPU<->CPU sync.
|
|
- for sampling_type in SamplingType:
|
|
|
|
- sample_indices = categorized_sample_indices[sampling_type]
|
|
|
|
- num_tokens = len(sample_indices)
|
|
|
|
- if num_tokens == 0:
|
|
|
|
|
|
+ for sampling_type, sample_indices in categorized_sample_indices.items():
|
|
|
|
+ if len(sample_indices) == 0:
|
|
continue
|
|
continue
|
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
|
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
|
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
|
@@ -585,11 +584,8 @@ def _sample(
|
|
|
|
|
|
# GPU<->CPU sync happens in the loop below.
|
|
# GPU<->CPU sync happens in the loop below.
|
|
|
|
|
|
- for sampling_type in SamplingType:
|
|
|
|
- if sampling_type not in sample_metadata:
|
|
|
|
- continue
|
|
|
|
- seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
|
|
|
|
- sampling_type]
|
|
|
|
|
|
+ for sampling_type, metadata in sample_metadata.items():
|
|
|
|
+ seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
|
|
if sampling_type == SamplingType.GREEDY:
|
|
if sampling_type == SamplingType.GREEDY:
|
|
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
|
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
|
elif sampling_type == SamplingType.RANDOM:
|
|
elif sampling_type == SamplingType.RANDOM:
|