# policy.py
  1. from typing import List
  2. from aphrodite.common.sequence import SequenceGroup
  3. class Policy:
  4. def get_priority(
  5. self,
  6. now: float,
  7. seq_group: SequenceGroup,
  8. ) -> float:
  9. raise NotImplementedError
  10. def sort_by_priority(
  11. self,
  12. now: float,
  13. seq_groups: List[SequenceGroup],
  14. ) -> List[SequenceGroup]:
  15. return sorted(
  16. seq_groups,
  17. key=lambda seq_group: self.get_priority(now, seq_group),
  18. reverse=True,
  19. )
  20. class FCFS(Policy):
  21. def get_priority(
  22. self,
  23. now: float,
  24. seq_group: SequenceGroup,
  25. ) -> float:
  26. return now - seq_group.arrival_time
  27. class PolicyFactory:
  28. _POLICY_REGISTRY = {
  29. 'fcfs': FCFS,
  30. }
  31. @classmethod
  32. def get_policy(cls, policy_name: str, **kwargs) -> Policy:
  33. return cls._POLICY_REGISTRY[policy_name](**kwargs)