12345678910111213141516 |
- from typing import List
- from aphrodite.common.sequence import SamplerOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
                                    num_seq_groups: int):
    """Regroup per-step sampler outputs so they are indexed by sequence group.

    The input is laid out as ``[step][sequence group]``; the result is laid
    out as ``[sequence group][step]``.

    Args:
        sampler_outputs: One entry per step. Iterating an entry yields that
            step's outputs, one per sequence group, in group order.
        num_seq_groups: Number of per-group lists to create in the result.

    Returns:
        A list of ``num_seq_groups`` lists. The i-th list holds, in step
        order, the i-th output of every step that has one.

    Raises:
        IndexError: If any step yields more than ``num_seq_groups`` outputs.
    """
    # Build each bucket separately so no two groups share the same list.
    grouped = [list() for _ in range(num_seq_groups)]

    # Flatten the nested iteration lazily into (group position, output)
    # pairs; step order, then group order, is preserved.
    positioned_outputs = (
        (group_position, group_output)
        for step_outputs in sampler_outputs
        for group_position, group_output in enumerate(step_outputs)
    )
    for group_position, group_output in positioned_outputs:
        grouped[group_position].append(group_output)

    return grouped
|