policy.py 970 B

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from collections import deque
  2. from typing import Deque
  3. from aphrodite.common.sequence import SequenceGroup
  class Policy:
      """Base class for ordering sequence groups by a numeric priority.

      Subclasses override ``get_priority``; ``sort_by_priority`` uses it to
      order groups so that higher priorities come first.
      """

      def get_priority(
          self,
          now: float,
          seq_group: SequenceGroup,
      ) -> float:
          """Return the priority of ``seq_group`` at time ``now``.

          Larger values mean higher priority (``sort_by_priority`` sorts in
          descending order). The base class defines no policy of its own.

          Raises:
              NotImplementedError: always; subclasses must override this.
          """
          raise NotImplementedError

      def sort_by_priority(
          self,
          now: float,
          seq_groups: Deque[SequenceGroup],
      ) -> Deque[SequenceGroup]:
          """Return a new deque of ``seq_groups`` in descending priority order.

          ``get_priority`` is evaluated once per group with the same ``now``.
          The input container is not modified. Groups with equal priority keep
          their original relative order, because ``sorted`` remains stable
          even when ``reverse=True``.
          """

          def _priority_of(group):
              return self.get_priority(now, group)

          ordered = sorted(seq_groups, key=_priority_of, reverse=True)
          return deque(ordered)
  class FCFS(Policy):
      """First-come-first-served policy.

      A group's priority is the time elapsed since it arrived. Earlier
      arrivals therefore get larger values and are placed first by the
      descending sort in ``Policy.sort_by_priority``.
      """

      def get_priority(
          self,
          now: float,
          seq_group: SequenceGroup,
      ) -> float:
          """Return how long ``seq_group`` has been waiting as of ``now``.

          NOTE(review): assumes ``now`` and ``metrics.arrival_time`` come
          from the same clock/units — confirm against callers.
          """
          arrived_at = seq_group.metrics.arrival_time
          waited = now - arrived_at
          return waited
  class PolicyFactory:
      """Creates ``Policy`` instances from a registered name."""

      # Maps a policy name to the Policy subclass that implements it.
      _POLICY_REGISTRY = {'fcfs': FCFS}

      @classmethod
      def get_policy(cls, policy_name: str, **kwargs) -> Policy:
          """Instantiate the policy registered under ``policy_name``.

          Args:
              policy_name: Key into ``_POLICY_REGISTRY`` (e.g. ``'fcfs'``).
              **kwargs: Forwarded unchanged to the policy's constructor.

          Returns:
              A new instance of the registered policy class.

          Raises:
              KeyError: if ``policy_name`` is not registered. The message
                  lists the registered names. The exception type is kept as
                  KeyError so existing callers that catch it still work.
          """
          # Wrap only the lookup, so a KeyError raised inside a policy's
          # constructor is not misreported as an unknown policy name.
          try:
              policy_cls = cls._POLICY_REGISTRY[policy_name]
          except KeyError:
              available = ", ".join(
                  repr(name) for name in sorted(cls._POLICY_REGISTRY))
              raise KeyError(
                  f"Unknown policy {policy_name!r}; "
                  f"available policies: {available}") from None
          return policy_cls(**kwargs)