base.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
"""Assets for testing. vLLM conveniently hosts a public S3 bucket of assets
that we can reuse."""
  3. import os
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import Optional
  7. from aphrodite.connections import global_http_connection
  8. def get_default_cache_root():
  9. return os.getenv(
  10. "XDG_CACHE_HOME",
  11. os.path.join(os.path.expanduser("~"), ".cache"),
  12. )
  13. vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
  14. APHRODITE_ASSETS_CACHE = os.path.expanduser(
  15. os.getenv(
  16. "APHRODITE_ASSETS_CACHE",
  17. os.path.join(get_default_cache_root(), "aphrodite", "assets"),
  18. ))
  19. APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT",
  20. 5))
  21. def get_cache_dir() -> Path:
  22. """Get the path to the cache for storing downloaded assets."""
  23. path = Path(APHRODITE_ASSETS_CACHE)
  24. path.mkdir(parents=True, exist_ok=True)
  25. return path
  26. @lru_cache
  27. def get_vllm_public_assets(filename: str,
  28. s3_prefix: Optional[str] = None) -> Path:
  29. """
  30. Download an asset file from ``s3://vllm-public-assets``
  31. and return the path to the downloaded file.
  32. """
  33. asset_directory = get_cache_dir() / "vllm_public_assets"
  34. asset_directory.mkdir(parents=True, exist_ok=True)
  35. asset_path = asset_directory / filename
  36. if not asset_path.exists():
  37. if s3_prefix is not None:
  38. filename = s3_prefix + "/" + filename
  39. global_http_connection.download_file(
  40. f"{vLLM_S3_BUCKET_URL}/{filename}",
  41. asset_path,
  42. timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
  43. return asset_path