|
@@ -58,9 +58,11 @@ class SamplingMetadata:
|
|
|
hidden_states = execute_model(...)
|
|
|
logits = hidden_states[sampling_metadata.selected_token_indices]
|
|
|
sample(logits)
|
|
|
+
|
|
|
def sample(logits):
|
|
|
# Use categorized_sample_indices for sampling....
|
|
|
```
|
|
|
+
|
|
|
Args:
|
|
|
seq_groups: List of batched sequence groups.
|
|
|
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
|
|
@@ -141,6 +143,7 @@ def _prepare_seq_groups(
|
|
|
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
|
|
|
SamplingType, List[Tuple[int, int]]], int]:
|
|
|
"""Prepare sequence groups and indices for sampling.
|
|
|
+
|
|
|
Args:
|
|
|
seq_group_metadata_list: A list of sequence group to batch.
|
|
|
prompt_lens: A list of prompt lens per sequence group.
|
|
@@ -149,6 +152,7 @@ def _prepare_seq_groups(
|
|
|
of entire prompt tokens, and it could be shorter.
|
|
|
device: A device to use for random number generator,
|
|
|
`SequenceGroupToSample.generator`.
|
|
|
+
|
|
|
Returns:
|
|
|
seq_groups: A list of sequence group to sample.
|
|
|
selected_token_indices: See the definition from `SamplingMetadata`.
|
|
@@ -215,6 +219,7 @@ def _prepare_seq_groups(
|
|
|
"""
|
|
|
This blocks computes selected_token_indices which is used in the
|
|
|
following way.
|
|
|
+
|
|
|
hidden_states = model(...)
|
|
|
logits = hidden_states[selected_token_indices]
|
|
|
"""
|
|
@@ -232,6 +237,7 @@ def _prepare_seq_groups(
|
|
|
"""
|
|
|
This block computes categorized_sample_indices which is used in the
|
|
|
following way.
|
|
|
+
|
|
|
hidden_states = model(...)
|
|
|
logits = hidden_states[selected_token_indices]
|
|
|
def sample(logits):
|
|
@@ -274,6 +280,7 @@ def _prepare_seq_groups(
|
|
|
@dataclass
|
|
|
class SamplingTensors:
|
|
|
"""Tensors for sampling."""
|
|
|
+
|
|
|
temperatures: torch.Tensor
|
|
|
top_ps: torch.Tensor
|
|
|
top_ks: torch.Tensor
|
|
@@ -286,9 +293,6 @@ class SamplingTensors:
|
|
|
eta_cutoffs: torch.Tensor
|
|
|
epsilon_cutoffs: torch.Tensor
|
|
|
typical_ps: torch.Tensor
|
|
|
- dynatemp_mins: torch.Tensor
|
|
|
- dynatemp_maxs: torch.Tensor
|
|
|
- dynatemp_exps: torch.Tensor
|
|
|
smoothing_factors: torch.Tensor
|
|
|
smoothing_curves: torch.Tensor
|
|
|
sampling_seeds: torch.Tensor
|
|
@@ -308,7 +312,12 @@ class SamplingTensors:
|
|
|
extra_seeds_to_generate: int = 0,
|
|
|
extra_entropy: Optional[Tuple[int, ...]] = None
|
|
|
) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
|
|
|
- bool, bool, bool, bool]:
|
|
|
+ bool, bool]:
|
|
|
+ """
|
|
|
+ extra_seeds_to_generate: extra seeds to generate using the
|
|
|
+ user-defined seed for each sequence.
|
|
|
+ extra_entropy: extra entropy to use when generating seeds.
|
|
|
+ """
|
|
|
prompt_tokens: List[List[int]] = []
|
|
|
output_tokens: List[List[int]] = []
|
|
|
top_ks: List[int] = []
|
|
@@ -323,20 +332,15 @@ class SamplingTensors:
|
|
|
eta_cutoffs: List[float] = []
|
|
|
epsilon_cutoffs: List[float] = []
|
|
|
typical_ps: List[float] = []
|
|
|
- dynatemp_mins: List[float] = []
|
|
|
- dynatemp_maxs: List[float] = []
|
|
|
- dynatemp_exps: List[float] = []
|
|
|
smoothing_factors: List[float] = []
|
|
|
smoothing_curves: List[float] = []
|
|
|
sampling_seeds: List[int] = []
|
|
|
sample_indices: List[int] = []
|
|
|
prompt_best_of: List[int] = []
|
|
|
- do_temperatures = False
|
|
|
do_penalties = False
|
|
|
- do_topks = False
|
|
|
- do_topps = False
|
|
|
- do_topas = False
|
|
|
- do_minps = False
|
|
|
+ do_top_p_top_k = False
|
|
|
+ do_top_as = False
|
|
|
+ do_min_p = False
|
|
|
do_tfss = False
|
|
|
do_eta_cutoffs = False
|
|
|
do_epsilon_cutoffs = False
|
|
@@ -356,38 +360,37 @@ class SamplingTensors:
|
|
|
f = sampling_params.frequency_penalty
|
|
|
r = sampling_params.repetition_penalty
|
|
|
top_p = sampling_params.top_p
|
|
|
- # k should not be greater than the vocab size
|
|
|
- top_k = min(sampling_params.top_k, vocab_size)
|
|
|
- top_k = vocab_size if top_k == -1 else top_k
|
|
|
top_a = sampling_params.top_a
|
|
|
min_p = sampling_params.min_p
|
|
|
tfs = sampling_params.tfs
|
|
|
eta_cutoff = sampling_params.eta_cutoff
|
|
|
epsilon_cutoff = sampling_params.epsilon_cutoff
|
|
|
typical_p = sampling_params.typical_p
|
|
|
- dynatemp_min = sampling_params.dynatemp_min
|
|
|
- dynatemp_max = sampling_params.dynatemp_max
|
|
|
- dynatemp_exp = sampling_params.dynatemp_exponent
|
|
|
smoothing_factor = sampling_params.smoothing_factor
|
|
|
smoothing_curve = sampling_params.smoothing_curve
|
|
|
seed = sampling_params.seed
|
|
|
|
|
|
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
|
|
|
|
|
|
- if do_temperatures is False and temperature > _SAMPLING_EPS:
|
|
|
- do_temperatures = True
|
|
|
+ # k should not be greater than the vocab size.
|
|
|
+ top_k = min(sampling_params.top_k, vocab_size)
|
|
|
+ top_k = vocab_size if top_k == -1 else top_k
|
|
|
+ if temperature < _SAMPLING_EPS:
|
|
|
+ # NOTE: Zero temperature means deterministic sampling
|
|
|
+ # (i.e., greedy sampling or beam search).
|
|
|
+ # Set the temperature to 1 to avoid division by zero.
|
|
|
+ temperature = 1.0
|
|
|
+ if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
|
|
|
+ or top_k != vocab_size):
|
|
|
+ do_top_p_top_k = True
|
|
|
+ if do_top_as is False and top_a > 0.0:
|
|
|
+ do_top_as = True
|
|
|
+ if not do_min_p and min_p > _SAMPLING_EPS:
|
|
|
+ do_min_p = True
|
|
|
if not do_penalties and (abs(p) >= _SAMPLING_EPS
|
|
|
or abs(f) >= _SAMPLING_EPS
|
|
|
or abs(r - 1.0) >= _SAMPLING_EPS):
|
|
|
do_penalties = True
|
|
|
- if do_topks is False and top_k != vocab_size:
|
|
|
- do_topks = True
|
|
|
- if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
|
|
|
- do_topps = True
|
|
|
- if do_topas is False and top_a > 0.0:
|
|
|
- do_topas = True
|
|
|
- if do_minps is False and min_p > _SAMPLING_EPS:
|
|
|
- do_minps = True
|
|
|
if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
|
|
|
do_tfss = True
|
|
|
if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
|
|
@@ -403,8 +406,8 @@ class SamplingTensors:
|
|
|
is_prompt = seq_group.is_prompt
|
|
|
if (seq_group.is_prompt
|
|
|
and sampling_params.prompt_logprobs is not None):
|
|
|
- # For tokens in the prompt that we only need to get their
|
|
|
- # logprobs
|
|
|
+ # For tokens in the prompt that we only need to get
|
|
|
+ # their logprobs
|
|
|
subquery_len = seq_group.subquery_len
|
|
|
assert subquery_len is not None
|
|
|
prefill_len = len(seq_group.prompt_logprob_indices)
|
|
@@ -420,9 +423,6 @@ class SamplingTensors:
|
|
|
eta_cutoffs += [0] * prefill_len
|
|
|
epsilon_cutoffs += [0] * prefill_len
|
|
|
typical_ps += [1] * prefill_len
|
|
|
- dynatemp_mins += [dynatemp_min] * prefill_len
|
|
|
- dynatemp_maxs += [dynatemp_max] * prefill_len
|
|
|
- dynatemp_exps += [dynatemp_exp] * prefill_len
|
|
|
smoothing_factors += [smoothing_factor] * prefill_len
|
|
|
smoothing_curves += [smoothing_curve] * prefill_len
|
|
|
prompt_tokens.extend([] for _ in range(prefill_len))
|
|
@@ -435,23 +435,20 @@ class SamplingTensors:
|
|
|
seq_data = seq_group.seq_data[seq_id]
|
|
|
prompt_tokens.append(seq_data.prompt_token_ids)
|
|
|
output_tokens.append(seq_data.output_token_ids)
|
|
|
- temperatures += [temperature] * len(seq_ids)
|
|
|
- top_ps += [top_p] * len(seq_ids)
|
|
|
- top_ks += [top_k] * len(seq_ids)
|
|
|
- top_as += [top_a] * len(seq_ids)
|
|
|
- min_ps += [min_p] * len(seq_ids)
|
|
|
- presence_penalties += [p] * len(seq_ids)
|
|
|
- frequency_penalties += [f] * len(seq_ids)
|
|
|
- repetition_penalties += [r] * len(seq_ids)
|
|
|
- tfss += [tfs] * len(seq_ids)
|
|
|
- eta_cutoffs += [eta_cutoff] * len(seq_ids)
|
|
|
- epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
|
|
|
- typical_ps += [typical_p] * len(seq_ids)
|
|
|
- dynatemp_mins += [dynatemp_min] * len(seq_ids)
|
|
|
- dynatemp_maxs += [dynatemp_max] * len(seq_ids)
|
|
|
- dynatemp_exps += [dynatemp_exp] * len(seq_ids)
|
|
|
- smoothing_factors += [smoothing_factor] * len(seq_ids)
|
|
|
- smoothing_curves += [smoothing_curve] * len(seq_ids)
|
|
|
+ temperatures += [temperature] * len(seq_ids)
|
|
|
+ top_ps += [top_p] * len(seq_ids)
|
|
|
+ top_ks += [top_k] * len(seq_ids)
|
|
|
+ top_as += [top_a] * len(seq_ids)
|
|
|
+ min_ps += [min_p] * len(seq_ids)
|
|
|
+ presence_penalties += [p] * len(seq_ids)
|
|
|
+ frequency_penalties += [f] * len(seq_ids)
|
|
|
+ repetition_penalties += [r] * len(seq_ids)
|
|
|
+ tfss += [tfs] * len(seq_ids)
|
|
|
+ eta_cutoffs += [eta_cutoff] * len(seq_ids)
|
|
|
+ epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
|
|
|
+ typical_ps += [typical_p] * len(seq_ids)
|
|
|
+ smoothing_factors += [smoothing_factor] * len(seq_ids)
|
|
|
+ smoothing_curves += [smoothing_curve] * len(seq_ids)
|
|
|
|
|
|
if is_prompt:
|
|
|
prompt_best_of.append(sampling_params.best_of)
|
|
@@ -474,13 +471,12 @@ class SamplingTensors:
|
|
|
sampling_tensors = SamplingTensors.from_lists(
|
|
|
temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
|
|
|
frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
|
|
|
- epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
|
|
|
- dynatemp_exps, smoothing_factors, smoothing_curves, sampling_seeds,
|
|
|
- sample_indices, prompt_tokens, output_tokens, vocab_size,
|
|
|
- extra_seeds_to_generate, device, dtype)
|
|
|
- return (sampling_tensors, do_temperatures, do_penalties, do_topks,
|
|
|
- do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
|
|
|
- do_epsilon_cutoffs, do_typical_ps, do_quadratic)
|
|
|
+ epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves,
|
|
|
+ sampling_seeds, sample_indices, prompt_tokens, output_tokens,
|
|
|
+ vocab_size, extra_seeds_to_generate, device, dtype)
|
|
|
+ return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
|
|
|
+ do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
|
|
|
+ do_typical_ps, do_quadratic)
|
|
|
|
|
|
@classmethod
|
|
|
def from_lists(cls, temperatures: List[float], top_ps: List[float],
|
|
@@ -489,9 +485,7 @@ class SamplingTensors:
|
|
|
frequency_penalties: List[float],
|
|
|
repetition_penalties: List[float], tfss: List[float],
|
|
|
eta_cutoffs: List[float], epsilon_cutoffs: List[float],
|
|
|
- typical_ps: List[float], dynatemp_mins: List[float],
|
|
|
- dynatemp_maxs: List[float], dynatemp_exps: List[float],
|
|
|
- smoothing_factors: List[float],
|
|
|
+ typical_ps: List[float], smoothing_factors: List[float],
|
|
|
smoothing_curves: List[float], sampling_seeds: List[int],
|
|
|
sample_indices: List[int], prompt_tokens: List[List[int]],
|
|
|
output_tokens: List[List[int]], vocab_size: int,
|
|
@@ -513,38 +507,52 @@ class SamplingTensors:
|
|
|
for tokens in output_tokens
|
|
|
]
|
|
|
|
|
|
- temperatures_t = torch.tensor(temperatures,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- top_ps_t = torch.tensor(top_ps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- top_ks_t = torch.tensor(top_ks,
|
|
|
- device="cpu",
|
|
|
- dtype=torch.int,
|
|
|
- pin_memory=pin_memory)
|
|
|
+ temperatures_t = torch.tensor(
|
|
|
+ temperatures,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ top_ps_t = torch.tensor(
|
|
|
+ top_ps,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
top_as_t = torch.tensor(top_as,
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
|
pin_memory=pin_memory)
|
|
|
- min_ps_t = torch.tensor(min_ps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- presence_penalties_t = torch.tensor(presence_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- frequency_penalties_t = torch.tensor(frequency_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- repetition_penalties_t = torch.tensor(repetition_penalties,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
+ min_ps_t = torch.tensor(
|
|
|
+ min_ps,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ presence_penalties_t = torch.tensor(
|
|
|
+ presence_penalties,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ frequency_penalties_t = torch.tensor(
|
|
|
+ frequency_penalties,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ repetition_penalties_t = torch.tensor(
|
|
|
+ repetition_penalties,
|
|
|
+ device="cpu",
|
|
|
+ dtype=dtype,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ top_ks_t = torch.tensor(
|
|
|
+ top_ks,
|
|
|
+ device="cpu",
|
|
|
+ dtype=torch.int,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
tfss_t = torch.tensor(tfss,
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
@@ -561,18 +569,6 @@ class SamplingTensors:
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
|
pin_memory=pin_memory)
|
|
|
- dynatemp_mins_t = torch.tensor(dynatemp_mins,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
- dynatemp_exps_t = torch.tensor(dynatemp_exps,
|
|
|
- device="cpu",
|
|
|
- dtype=dtype,
|
|
|
- pin_memory=pin_memory)
|
|
|
smoothing_factors_t = torch.tensor(smoothing_factors,
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
@@ -581,18 +577,24 @@ class SamplingTensors:
|
|
|
device="cpu",
|
|
|
dtype=dtype,
|
|
|
pin_memory=pin_memory)
|
|
|
- sample_indices_t = torch.tensor(sample_indices,
|
|
|
- device="cpu",
|
|
|
- dtype=torch.int,
|
|
|
- pin_memory=pin_memory)
|
|
|
- prompt_tensor = torch.tensor(prompt_padded_tokens,
|
|
|
- device=device,
|
|
|
- dtype=torch.long,
|
|
|
- pin_memory=pin_memory)
|
|
|
- output_tensor = torch.tensor(output_padded_tokens,
|
|
|
- device=device,
|
|
|
- dtype=torch.long,
|
|
|
- pin_memory=pin_memory)
|
|
|
+ sample_indices_t = torch.tensor(
|
|
|
+ sample_indices,
|
|
|
+ device="cpu",
|
|
|
+ dtype=torch.long,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ prompt_tensor = torch.tensor(
|
|
|
+ prompt_padded_tokens,
|
|
|
+ device="cpu",
|
|
|
+ dtype=torch.long,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
+ output_tensor = torch.tensor(
|
|
|
+ output_padded_tokens,
|
|
|
+ device="cpu",
|
|
|
+ dtype=torch.long,
|
|
|
+ pin_memory=pin_memory,
|
|
|
+ )
|
|
|
# need to transpose and make contiguous to
|
|
|
# copy the tensor correctly.
|
|
|
# [batch_size, n_seeds] -> [n_seeds, batch_size]
|
|
@@ -602,6 +604,7 @@ class SamplingTensors:
|
|
|
dtype=torch.long,
|
|
|
pin_memory=pin_memory,
|
|
|
).T.contiguous()
|
|
|
+
|
|
|
# Because the memory is pinned, we can do non-blocking
|
|
|
# transfer to device.
|
|
|
|
|
@@ -613,6 +616,7 @@ class SamplingTensors:
|
|
|
if not extra_seeds_gpu.numel():
|
|
|
extra_seeds_gpu = None
|
|
|
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
|
|
|
+
|
|
|
return cls(
|
|
|
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
|
|
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
|
@@ -629,9 +633,6 @@ class SamplingTensors:
|
|
|
eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
|
|
|
epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
|
|
|
non_blocking=True),
|
|
|
- dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
|
|
|
- dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
|
|
|
- dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
|
|
|
smoothing_factors=smoothing_factors_t.to(device=device,
|
|
|
non_blocking=True),
|
|
|
smoothing_curves=smoothing_curves_t.to(device=device,
|