瀏覽代碼

Merge pull request #397 from 50h100a/pr_samplerasserts

Missed .items() and assert
50h100a 11 月之前
父節點
當前提交
f663d3fccc
共有 1 個文件被更改,包括 2 次插入1 次删除
  1. 2 1
      aphrodite/modeling/layers/sampler.py

+ 2 - 1
aphrodite/modeling/layers/sampler.py

@@ -662,6 +662,7 @@ def _sample_with_torch(
     corresponding to sampling_metadata.seq_groups."""
     assert sampling_metadata.seq_groups is not None
     assert sampling_metadata.categorized_sample_indices is not None
+    assert sampling_metadata.seq_data is not None
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -708,7 +709,7 @@ def _sample_with_torch(
 
     # GPU<->CPU sync happens in the loop below.
 
-    for sampling_type, metadata in sample_metadata:
+    for sampling_type, metadata in sample_metadata.items():
         seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)