base.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """Assets for testing. vLLM conveniently has a bucket of public assets
  2. we can use."""
  3. import os
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import Optional
  7. from aphrodite.connections import global_http_connection
  8. def get_default_cache_root():
  9. return os.getenv(
  10. "XDG_CACHE_HOME",
  11. os.path.join(os.path.expanduser("~"), ".cache"),
  12. )
  13. vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
  14. APHRODITE_ASSETS_CACHE = os.path.expanduser(
  15. os.getenv(
  16. "APHRODITE_ASSETS_CACHE",
  17. os.path.join(get_default_cache_root(), "aphrodite", "assets"),
  18. ))
  19. APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 5))
  20. def get_cache_dir() -> Path:
  21. """Get the path to the cache for storing downloaded assets."""
  22. path = Path(APHRODITE_ASSETS_CACHE)
  23. path.mkdir(parents=True, exist_ok=True)
  24. return path
  25. @lru_cache
  26. def get_vllm_public_assets(filename: str,
  27. s3_prefix: Optional[str] = None) -> Path:
  28. """
  29. Download an asset file from ``s3://vllm-public-assets``
  30. and return the path to the downloaded file.
  31. """
  32. asset_directory = get_cache_dir() / "vllm_public_assets"
  33. asset_directory.mkdir(parents=True, exist_ok=True)
  34. asset_path = asset_directory / filename
  35. if not asset_path.exists():
  36. if s3_prefix is not None:
  37. filename = s3_prefix + "/" + filename
  38. global_http_connection.download_file(
  39. f"{vLLM_S3_BUCKET_URL}/{filename}",
  40. asset_path,
  41. timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
  42. return asset_path