util.py 831 B

12345678910111213141516171819202122
  1. from typing import List
  2. from typing import Sequence as GenericSequence
  3. from typing import Union
  4. from aphrodite.common.sequence import PoolerOutput, SequenceGroupOutput
  5. from aphrodite.modeling.layers.sampler import SamplerOutput
  def create_output_by_sequence_group(
          outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
          num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
      """Regroup per-step outputs so they are keyed by sequence group.

      The input is laid out as ``[step][sequence group]``; the result is laid
      out as ``[sequence group][step]`` (a transpose of the 2-D structure).

      Args:
          outputs: One entry per step. Each entry is iterated in order, and
              its ``i``-th item is treated as the output for sequence group
              ``i`` at that step.
          num_seq_groups: Number of sequence groups; this is the length of
              the returned outer list.

      Returns:
          A list with ``num_seq_groups`` inner lists. Inner list ``i`` holds
          the ``i``-th item of every step, in step order. A step that yields
          more than ``num_seq_groups`` items raises ``IndexError``; a step
          that yields fewer simply contributes nothing to the remaining groups.
      """
      # Each group needs its own distinct list object (not ``[[]] * n``),
      # otherwise appends would be shared across groups.
      per_group: List[List[SequenceGroupOutput]] = [
          [] for _ in range(num_seq_groups)
      ]

      for step_outputs in outputs:
          for group_idx, group_output in enumerate(step_outputs):
              per_group[group_idx].append(group_output)

      return per_group