selector.py

import enum
import os
from functools import lru_cache
from typing import Optional, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import is_cpu, is_hip

# Name of the environment variable that can force a specific attention
# backend (e.g. "XFORMERS"); see _which_attn_to_use below.
APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()


@lru_cache(maxsize=None)
def get_attn_backend(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> Type[AttentionBackend]:
    """Determines the attention backend to use and imports its class."""
    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention backend.")
        from aphrodite.attention.backends.flash_attn import \
            FlashAttentionBackend  # noqa: F401
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import \
            XFormersBackend  # noqa: F401
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import \
            ROCmFlashAttentionBackend  # noqa: F401
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is enforced for the Flashinfer backend.")
        from aphrodite.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError("Invalid attention backend.")


def _which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    """Returns which flash attention backend to use."""
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # Not Instinct series (e.g. NAVI) GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    if block_size % 16 != 0:
        logger.info("Cannot use FlashAttention-2 backend for block size not "
                    "divisible by 16.")
        return _Backend.XFORMERS

    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
        return _Backend.XFORMERS

    if sliding_window is not None:
        logger.info(
            "Cannot use FlashAttention-2 backend due to sliding window.")
        return _Backend.XFORMERS

    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
            "package is not found. `pip install vllm-flash-attn` for better "
            "performance.")
        return _Backend.XFORMERS

    backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var.upper()]

    # Default case.
    return _Backend.FLASH_ATTN
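

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Running this
# file directly resolves a backend class for a hypothetical model
# configuration; the dimensions below (32 heads, head size 128, block size 16)
# are example values only. Note that APHRODITE_ATTENTION_BACKEND is consulted
# last in _which_attn_to_use, so the override only applies once the CPU/HIP,
# compute-capability, dtype, and KV-cache checks above have passed.
if __name__ == "__main__":
    backend_cls = get_attn_backend(
        num_heads=32,
        head_size=128,
        num_kv_heads=32,
        sliding_window=None,
        dtype=torch.float16,
        kv_cache_dtype=None,
        block_size=16,
    )
    print(f"Selected attention backend: {backend_cls.__name__}")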