# util.py (600 B)

  1. from typing import List
  2. from aphrodite.common.sequence import SamplerOutput
  3. def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
  4. num_seq_groups: int):
  5. """Helper method which transforms a 2d list organized by
  6. [step][sequence group] into [sequence group][step].
  7. """
  8. output_by_sequence_group = [[] for _ in range(num_seq_groups)]
  9. for step in sampler_outputs:
  10. for i, sequence_group_output in enumerate(step):
  11. output_by_sequence_group[i].append(sequence_group_output)
  12. return output_by_sequence_group