util.py 827 B

12345678910111213141516171819202122
  1. from typing import List
  2. from typing import Sequence as GenericSequence
  3. from typing import Union
  4. from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
  5. SequenceGroupOutput)
  6. def create_output_by_sequence_group(
  7. outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
  8. num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
  9. """Helper method which transforms a 2d list organized by
  10. [step][sequence group] into [sequence group][step].
  11. """
  12. output_by_sequence_group: List[List[SequenceGroupOutput]] = [
  13. [] for _ in range(num_seq_groups)
  14. ]
  15. for step in outputs:
  16. for i, sequence_group_output in enumerate(step):
  17. output_by_sequence_group[i].append(sequence_group_output)
  18. return output_by_sequence_group