policy.py 985 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from collections import deque
  2. from typing import Deque
  3. from aphrodite.common.sequence import SequenceGroup
  4. class Policy:
  5. def get_priority(
  6. self,
  7. now: float,
  8. seq_group: SequenceGroup,
  9. ) -> float:
  10. raise NotImplementedError
  11. def sort_by_priority(
  12. self,
  13. now: float,
  14. seq_groups: Deque[SequenceGroup],
  15. ) -> Deque[SequenceGroup]:
  16. return deque(
  17. sorted(
  18. seq_groups,
  19. key=lambda seq_group: self.get_priority(now, seq_group),
  20. reverse=True,
  21. ))
  22. class FCFS(Policy):
  23. def get_priority(
  24. self,
  25. now: float,
  26. seq_group: SequenceGroup,
  27. ) -> float:
  28. return now - seq_group.metrics.arrival_time
  29. class PolicyFactory:
  30. _POLICY_REGISTRY = {
  31. 'fcfs': FCFS,
  32. }
  33. @classmethod
  34. def get_policy(cls, policy_name: str, **kwargs) -> Policy:
  35. return cls._POLICY_REGISTRY[policy_name](**kwargs)