custom_cache_manager.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. from loguru import logger
  3. from triton.runtime.cache import (FileCacheManager, default_cache_dir,
  4. default_dump_dir, default_override_dir)
  5. def maybe_set_triton_cache_manager() -> None:
  6. """Set environment variable to tell Triton to use a
  7. custom cache manager"""
  8. cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
  9. if cache_manger is None:
  10. manager = "aphrodite.triton_utils.custom_cache_manager:CustomCacheManager" # noqa: E501
  11. logger.debug(f"Setting Triton cache manager to: {manager}")
  12. os.environ["TRITON_CACHE_MANAGER"] = manager
  13. class CustomCacheManager(FileCacheManager):
  14. """Re-implements Triton's cache manager, ensuring that a
  15. unique cache directory is created for each process. This is
  16. needed to avoid collisions when running with tp>1 and
  17. using multi-processing as the distributed backend.
  18. Note this issue was fixed by triton-lang/triton/pull/4295,
  19. but the fix is not yet included in triton==v3.0.0. However,
  20. it should be included in the subsequent version.
  21. """
  22. def __init__(self, key, override=False, dump=False):
  23. self.key = key
  24. self.lock_path = None
  25. if dump:
  26. self.cache_dir = default_dump_dir()
  27. self.cache_dir = os.path.join(self.cache_dir, self.key)
  28. self.lock_path = os.path.join(self.cache_dir, "lock")
  29. os.makedirs(self.cache_dir, exist_ok=True)
  30. elif override:
  31. self.cache_dir = default_override_dir()
  32. self.cache_dir = os.path.join(self.cache_dir, self.key)
  33. else:
  34. # create cache directory if it doesn't exist
  35. self.cache_dir = os.getenv("TRITON_CACHE_DIR",
  36. "").strip() or default_cache_dir()
  37. if self.cache_dir:
  38. self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
  39. self.cache_dir = os.path.join(self.cache_dir, self.key)
  40. self.lock_path = os.path.join(self.cache_dir, "lock")
  41. os.makedirs(self.cache_dir, exist_ok=True)
  42. else:
  43. raise RuntimeError("Could not create or locate cache dir")