util.py 641 B

123456789101112131415161718
  1. from typing import List
  2. from aphrodite.common.sequence import SamplerOutput
  3. def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
  4. num_seq_groups: int):
  5. """Helper method which transforms a 2d list organized by
  6. [step][sequence group] into [sequence group][step].
  7. """
  8. output_by_sequence_group: List[List[SamplerOutput]] = [
  9. [] for _ in range(num_seq_groups)
  10. ]
  11. for step in sampler_outputs:
  12. for i, sequence_group_output in enumerate(step):
  13. output_by_sequence_group[i].append(sequence_group_output)
  14. return output_by_sequence_group