evictor.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import enum
  2. from abc import ABC, abstractmethod, abstractproperty
  3. from typing import OrderedDict
  4. from aphrodite.common.block import PhysicalTokenBlock
  class EvictionPolicy(enum.Enum):
      """Eviction strategies selectable through make_evictor.

      make_evictor inspects a member of this enum to decide which Evictor
      subclass to instantiate.
      """

      # Least-recently-used eviction.
      LRU = enum.auto()
  class Evictor(ABC):
      """Abstract interface for eviction of freed PhysicalTokenBlocks.

      The Evictor subclasses should be used by the BlockAllocator class to
      handle eviction of freed PhysicalTokenBlocks. Every member below is
      abstract, so a concrete subclass must override all of them.
      """

      @abstractmethod
      def __init__(self):
          """Initialise the evictor's internal bookkeeping."""

      @abstractmethod
      def __contains__(self, block_hash: int) -> bool:
          """Return whether a block with hash block_hash is held."""

      @abstractmethod
      def evict(self) -> PhysicalTokenBlock:
          """Runs the eviction algorithm and returns the evicted block."""

      @abstractmethod
      def add(self, block: PhysicalTokenBlock):
          """Adds block to the evictor, making it a candidate for eviction."""

      @abstractmethod
      def remove(self, block_hash: int) -> PhysicalTokenBlock:
          """Remove and return the block whose hash is block_hash.

          Caller is responsible for making sure that block_hash is contained
          in the evictor before calling remove. Should be used to "bring back"
          blocks that have been freed but not evicted yet.
          """

      # abc.abstractproperty is deprecated since Python 3.3; stacking
      # @property on top of @abstractmethod is the supported spelling and
      # keeps num_blocks an abstract read-only property.
      @property
      @abstractmethod
      def num_blocks(self) -> int:
          """Number of blocks currently held by the evictor."""
  class LRUEvictor(Evictor):
      """Evicts in a least-recently-used order.

      Ordering uses the last_accessed timestamp that's recorded in the
      PhysicalTokenBlock. If there are multiple blocks with the same
      last_accessed time, then the one with the largest num_hashed_tokens
      will be evicted. If two blocks each have the lowest last_accessed time
      and highest num_hashed_tokens value, then one will be chosen
      arbitrarily.
      """

      def __init__(self):
          # block_hash -> freed block, kept in insertion order.
          self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()

      def __contains__(self, block_hash: int) -> bool:
          """Return whether a block with hash block_hash is in the free table."""
          return block_hash in self.free_table

      def evict(self) -> PhysicalTokenBlock:
          """Pick a block to evict, drop it from the free table and return it.

          The returned block has its ``computed`` flag reset to False.

          Raises:
              ValueError: if the free table is empty.
          """
          if not self.free_table:
              raise ValueError("No usable cache memory left")

          candidates = iter(self.free_table.values())
          evicted_block = next(candidates)
          # The blocks with the lowest timestamps should be placed
          # consecutively at the start of the OrderedDict (this relies on the
          # order in which callers add blocks). Scan that leading run and keep
          # the block with the maximum number of hashed tokens.
          for block in candidates:
              if evicted_block.last_accessed < block.last_accessed:
                  break
              if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
                  evicted_block = block

          del self.free_table[evicted_block.block_hash]
          evicted_block.computed = False
          return evicted_block

      def add(self, block: PhysicalTokenBlock):
          """Register block as an eviction candidate, keyed by its block_hash."""
          self.free_table[block.block_hash] = block

      def remove(self, block_hash: int) -> PhysicalTokenBlock:
          """Remove and return the block stored under block_hash.

          Raises:
              ValueError: if no block with that hash is in the free table.
          """
          try:
              # Single lookup: pop both fetches and deletes the entry.
              return self.free_table.pop(block_hash)
          except KeyError:
              raise ValueError(
                  "Attempting to remove block that's not in the evictor"
              ) from None

      @property
      def num_blocks(self) -> int:
          """Number of blocks currently in the free table."""
          return len(self.free_table)
  def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
      """Build the Evictor implementation that matches eviction_policy.

      Returns an LRUEvictor for EvictionPolicy.LRU; any other value raises
      ValueError.
      """
      wants_lru = eviction_policy == EvictionPolicy.LRU
      if not wants_lru:
          raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
      return LRUEvictor()