|
@@ -662,6 +662,7 @@ def _sample_with_torch(
|
|
|
corresponding to sampling_metadata.seq_groups."""
|
|
|
assert sampling_metadata.seq_groups is not None
|
|
|
assert sampling_metadata.categorized_sample_indices is not None
|
|
|
+ assert sampling_metadata.seq_data is not None
|
|
|
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
|
|
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
|
|
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
|
@@ -708,7 +709,7 @@ def _sample_with_torch(
|
|
|
|
|
|
# GPU<->CPU sync happens in the loop below.
|
|
|
|
|
|
- for sampling_type, metadata in sample_metadata:
|
|
|
+ for sampling_type, metadata in sample_metadata.items():
|
|
|
seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
|
|
|
if sampling_type == SamplingType.GREEDY:
|
|
|
sample_results = _greedy_sample(seq_groups, greedy_samples)
|