selector.py

import enum
import os
from functools import lru_cache
from typing import Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import is_cpu, is_hip
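
# Environment variable that, when set to the name of a _Backend member
# (e.g. "XFORMERS"), overrides the automatic backend choice made in
# _which_attn_to_use.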
APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
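    """Resolve and import the attention backend class for the given dtype.

    The result is cached via lru_cache, so the selection logic and the
    backend import run at most once per dtype.
    """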
    backend = _which_attn_to_use(dtype)
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention backend.")
        from aphrodite.attention.backends.flash_attn import \
            FlashAttentionBackend  # noqa: F401
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import \
            XFormersBackend  # noqa: F401
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import \
            ROCmFlashAttentionBackend  # noqa: F401
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    else:
        raise ValueError("Invalid attention backend.")


def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Returns which attention backend to use."""
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # Not Instinct-series GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention backend because the flash_attn package "
            "is not found. Please install it for better performance.")
        return _Backend.XFORMERS
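
    # Allow an explicit override via the environment variable.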
    backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]

    # Default case.
    return _Backend.FLASH_ATTN
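

# Usage sketch: callers typically resolve the backend class once during model
# setup (the exact call site is an assumption, not defined in this module):
#
#     attn_backend = get_attn_backend(torch.float16)
#
# Note that the APHRODITE_ATTENTION_BACKEND override is only consulted when
# the hardware/dtype checks above would otherwise select FLASH_ATTN; earlier
# returns (CPU, ROCm, pre-Ampere GPUs, unsupported dtypes) take precedence.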