123456789101112131415161718 |
- from typing import List
- from aphrodite.common.sequence import SamplerOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
                                    num_seq_groups: int):
    """Regroup step-major sampler outputs into sequence-group-major lists.

    The input is organized as ``[step][sequence group]``: each element of
    ``sampler_outputs`` is one step, and iterating that step yields one entry
    per sequence group, in group order. The result is the transpose,
    organized as ``[sequence group][step]``.

    Args:
        sampler_outputs: The per-step outputs, in step order.
        num_seq_groups: How many per-group lists to build. Groups that never
            receive an entry come back as empty lists. A step that yields
            more than ``num_seq_groups`` entries triggers ``IndexError``.

    Returns:
        A list of ``num_seq_groups`` lists. List ``g`` holds entry ``g`` of
        every step, preserving step order.
    """
    # NOTE(review): the elements stored here are whatever iterating a step
    # yields (presumably per-sequence-group outputs), not SamplerOutput
    # objects themselves -- confirm the precise element type upstream.
    per_group: List[list] = [[] for _ in range(num_seq_groups)]

    for step_output in sampler_outputs:
        # Entry position within a step identifies its sequence group.
        for group_idx, group_entry in enumerate(step_output):
            per_group[group_idx].append(group_entry)

    return per_group
|