selector.py

import enum
import os
from functools import lru_cache
from typing import Optional, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import is_cpu, is_hip

APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
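

# NOTE: `get_attn_backend` below is wrapped in `lru_cache(maxsize=None)`, so
# the backend is resolved (and the log messages below emitted) at most once
# per unique argument combination for the lifetime of the process.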
@lru_cache(maxsize=None)
def get_attn_backend(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> Type[AttentionBackend]:
    """Determine which attention backend to use and only import
    the selected backend module.
    """
    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                sliding_window, dtype, kv_cache_dtype,
                                block_size)
    if backend == _Backend.FLASH_ATTN:
        from aphrodite.attention.backends.flash_attn import \
            FlashAttentionBackend  # noqa: F401
        logger.info("Using FlashAttention backend.")
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import \
            XFormersBackend  # noqa: F401
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import \
            ROCmFlashAttentionBackend  # noqa: F401
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
                       "Please make sure --enforce-eager is set.")
        from aphrodite.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError("Invalid attention backend.")
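

# Illustrative usage sketch (not part of this module's public surface): the
# call mirrors the signature defined above; the concrete values (32 heads,
# head size 128, fp16, 16-token blocks) are placeholders, not requirements.
#
#     backend_cls = get_attn_backend(
#         num_heads=32,
#         head_size=128,
#         num_kv_heads=32,
#         sliding_window=None,
#         dtype=torch.float16,
#         kv_cache_dtype=None,
#         block_size=16,
#     )
#     # `backend_cls` is an AttentionBackend subclass; only the selected
#     # backend's module has been imported at this point.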


def which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    """Return which attention backend to use."""
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # Check the environment variable and override if specified.
    backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        backend_members = _Backend.__members__
        if backend_by_env_var.upper() not in backend_members:
            raise ValueError(
                f"Invalid attention backend '{backend_by_env_var}'. "
                f"Available backends: {', '.join(backend_members)}")
        selected_backend = _Backend[backend_by_env_var.upper()]

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info(f"Cannot use {selected_backend} backend on CPU.")
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if torch.cuda.get_device_capability()[0] != 9:
                # Not Instinct-series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info(f"{selected_backend} is not supported on AMD GPUs.")
        return _Backend.ROCM_FLASH

    # FlashAttention on NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if torch.cuda.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model; check that the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm_flash_attn  # noqa: F401

            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)
            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size "
                    f"{head_size}")
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm_flash_attn package is not found. "
                "`pip install vllm-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS

    return selected_backend
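
# A specific backend can be forced through the APHRODITE_ATTENTION_BACKEND
# environment variable; the value must match one of the `_Backend` member
# names (matched case-insensitively) and is read each time
# `which_attn_to_use` runs, so it must be set before the first (cached) call
# to `get_attn_backend`. A minimal sketch, with placeholder model parameters:
#
#     os.environ["APHRODITE_ATTENTION_BACKEND"] = "XFORMERS"
#     backend_cls = get_attn_backend(num_heads=32, head_size=128,
#                                    num_kv_heads=32, sliding_window=None,
#                                    dtype=torch.float16, kv_cache_dtype=None,
#                                    block_size=16)
#     # On a CUDA GPU this yields XFormersBackend; on CPU the selector still
#     # falls back to Torch SDPA regardless of the override.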