@@ -254,6 +254,7 @@ def _prepare_seq_groups(
 
             sample_obj.prompt_logprob_indices.clear()
             sample_obj.sample_indices.clear()
+        dry_sequence_breakerss = []
         sampling_params = seq_group_metadata.sampling_params
         is_prompt = seq_group_metadata.is_prompt
         generator: Optional[torch.Generator] = None
@@ -265,6 +266,7 @@ def _prepare_seq_groups(
         sample_indices: List[int] = \
             sample_obj.sample_indices if cache is not None else []
         do_sample = seq_group_metadata.do_sample
+        dry_sequence_breakerss.extend([sampling_params.dry_sequence_breakers] * len(seq_ids))
 
         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
@@ -375,6 +377,10 @@ class SamplingTensors:
     presence_penalties: torch.Tensor
     frequency_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
+    dry_multipliers: torch.Tensor
+    dry_bases: torch.Tensor
+    dry_allowed_lengths: torch.Tensor
+    dry_sequence_breakerss: torch.Tensor
     tfss: torch.Tensor
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
@@ -398,7 +404,7 @@ class SamplingTensors:
         extra_seeds_to_generate: int = 0,
         extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
-               bool, bool, bool]:
+               bool, bool, bool, bool]:
         """
         extra_seeds_to_generate: extra seeds to generate using the
             user-defined seed for each sequence.
@@ -415,6 +421,10 @@ class SamplingTensors:
         presence_penalties: List[float] = []
         frequency_penalties: List[float] = []
         repetition_penalties: List[float] = []
+        dry_multipliers: List[float] = []
+        dry_bases: List[float] = []
+        dry_allowed_lengths: List[int] = []
+        dry_sequence_breakerss: List[List[List[int]]] = []
         tfss: List[float] = []
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
@@ -424,6 +434,7 @@ class SamplingTensors:
         sampling_seeds: List[int] = []
         sample_indices: List[int] = []
         do_penalties = False
+        do_dries = False
         do_top_p_top_k = False
         do_top_as = False
         do_min_p = False
@@ -450,6 +461,10 @@ class SamplingTensors:
             p = sampling_params.presence_penalty
             f = sampling_params.frequency_penalty
             r = sampling_params.repetition_penalty
+            dry_multiplier = sampling_params.dry_multiplier
+            dry_base = sampling_params.dry_base
+            dry_allowed_length = sampling_params.dry_allowed_length
+            dry_sequence_breakers = sampling_params.dry_sequence_breakers
             top_p = sampling_params.top_p
             top_a = sampling_params.top_a
             min_p = sampling_params.min_p
@@ -479,6 +494,8 @@ class SamplingTensors:
                                           or abs(f) >= _SAMPLING_EPS
                                           or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True
+            if do_dries is False and dry_multiplier > _SAMPLING_EPS:
+                do_dries = True
             if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                 do_tfss = True
             if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
@@ -509,6 +526,10 @@ class SamplingTensors:
                 presence_penalties += [0] * prefill_len
                 frequency_penalties += [0] * prefill_len
                 repetition_penalties += [1] * prefill_len
+                dry_multipliers += [0] * prefill_len
+                dry_bases += [0] * prefill_len
+                dry_allowed_lengths += [0] * prefill_len
+                dry_sequence_breakerss += [[]] * prefill_len
                 tfss += [1] * prefill_len
                 eta_cutoffs += [0] * prefill_len
                 epsilon_cutoffs += [0] * prefill_len
@@ -528,6 +549,10 @@ class SamplingTensors:
                 presence_penalties += [p] * len(seq_ids)
                 frequency_penalties += [f] * len(seq_ids)
                 repetition_penalties += [r] * len(seq_ids)
+                dry_multipliers += [dry_multiplier] * len(seq_ids)
+                dry_bases += [dry_base] * len(seq_ids)
+                dry_allowed_lengths += [dry_allowed_length] * len(seq_ids)
+                dry_sequence_breakerss += [dry_sequence_breakers] * len(seq_ids)
                 tfss += [tfs] * len(seq_ids)
                 eta_cutoffs += [eta_cutoff] * len(seq_ids)
                 epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
@@ -576,12 +601,14 @@ class SamplingTensors:
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, temperature_lasts, top_ps, top_ks, top_as, min_ps,
             presence_penalties, frequency_penalties, repetition_penalties,
-            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
-            smoothing_curves, sampling_seeds, sample_indices, prompt_tokens,
-            output_tokens, vocab_size, extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
-                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
-                do_typical_ps, do_quadratic, do_temp_last)
+            dry_multipliers, dry_bases, dry_allowed_lengths,
+            dry_sequence_breakerss, tfss, eta_cutoffs, epsilon_cutoffs,
+            typical_ps, smoothing_factors, smoothing_curves, sampling_seeds,
+            sample_indices, prompt_tokens, output_tokens, vocab_size,
+            extra_seeds_to_generate, device, dtype)
+        return (sampling_tensors, do_penalties, do_dries, do_top_p_top_k,
+                do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
+                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_temp_last)
 
     @classmethod
     def from_lists(cls, temperatures: List[float],
@@ -589,7 +616,10 @@ class SamplingTensors:
                    top_ks: List[int], top_as: List[float],
                    min_ps: List[float], presence_penalties: List[float],
                    frequency_penalties: List[float],
-                   repetition_penalties: List[float], tfss: List[float],
+                   repetition_penalties: List[float],
+                   dry_multipliers: List[float], dry_bases: List[float],
+                   dry_allowed_lengths: List[int],
+                   dry_sequence_breakerss: List[List[List[int]]], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], sampling_seeds: List[int],
@@ -668,6 +698,30 @@ class SamplingTensors:
             dtype=dtype,
             pin_memory=pin_memory,
         )
+        dry_multipliers_t = torch.tensor(
+            dry_multipliers,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dry_bases_t = torch.tensor(
+            dry_bases,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        dry_allowed_lengths_t = torch.tensor(
+            dry_allowed_lengths,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        # NOTE: dry_sequence_breakerss is ragged (a list of breakers per
+        # sequence, each breaker a list of token IDs), so it cannot be
+        # converted with a plain torch.tensor(...) call like the fields
+        # above. The padded dry_sequence_breakerss_t tensor is built further
+        # below, right before the tensors are moved to the target device,
+        # using -1 as the padding value.
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
@@ -726,6 +780,16 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
+        max_breakers = max(len(breakers) for breakers in dry_sequence_breakerss)
+        max_breaker_length = max((len(breaker) for breakers in dry_sequence_breakerss for breaker in breakers), default=0)
+
+        dry_sequence_breakerss_t = torch.full((len(dry_sequence_breakerss), max_breakers, max_breaker_length),
+                                              -1, device="cpu", dtype=torch.long, pin_memory=pin_memory)
+
+        for i, breakers in enumerate(dry_sequence_breakerss):
+            for j, breaker in enumerate(breakers):
+                dry_sequence_breakerss_t[i, j, :len(breaker)] = torch.tensor(breaker, dtype=torch.long)
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
@@ -739,6 +803,13 @@ class SamplingTensors:
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
+            dry_multipliers=dry_multipliers_t.to(device=device,
+                                                 non_blocking=True),
+            dry_bases=dry_bases_t.to(device=device, non_blocking=True),
+            dry_allowed_lengths=dry_allowed_lengths_t.to(device=device,
+                                                         non_blocking=True),
+            dry_sequence_breakerss=dry_sequence_breakerss_t.to(device=device,
+                                                               non_blocking=True),
             tfss=tfss_t.to(device=device, non_blocking=True),
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
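For reference, a minimal standalone sketch of the padding scheme the new `from_lists` code uses for `dry_sequence_breakerss_t`; the breaker token IDs below are made up, only the shapes matter. One row per sequence, padded with -1 to the largest breaker count and the longest breaker in the batch:

    import torch

    # Two sequences: the first has two breakers, the second has one.
    dry_sequence_breakerss = [
        [[13], [198, 628]],  # hypothetical token IDs
        [[13]],
    ]

    max_breakers = max(len(bs) for bs in dry_sequence_breakerss)
    max_breaker_length = max(
        (len(b) for bs in dry_sequence_breakerss for b in bs), default=0)

    # -1 marks both missing breakers and the unused tail of short breakers.
    padded = torch.full(
        (len(dry_sequence_breakerss), max_breakers, max_breaker_length),
        -1, dtype=torch.long)
    for i, breakers in enumerate(dry_sequence_breakerss):
        for j, breaker in enumerate(breakers):
            padded[i, j, :len(breaker)] = torch.tensor(breaker, dtype=torch.long)

    print(padded.shape)  # torch.Size([2, 2, 2])
    print(padded[1])     # second sequence: [[13, -1], [-1, -1]]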
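Assuming the companion SamplingParams change exposes these fields as constructor arguments under the same names (this patch only confirms the attribute reads), a request enabling DRY might look like the sketch below. The breakers are lists of token IDs, which is what the padding code expects, and the numeric values are illustrative rather than taken from this patch:

    from aphrodite import SamplingParams  # assumed import path

    params = SamplingParams(
        dry_multiplier=0.8,  # > _SAMPLING_EPS, so do_dries is set to True
        dry_base=1.75,
        dry_allowed_length=2,
        dry_sequence_breakers=[[198], [628]],  # made-up token IDs
    )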