selector.py

import enum
import os
from functools import lru_cache
from typing import Optional, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import is_cpu, is_hip

APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()


@lru_cache(maxsize=None)
def get_attn_backend(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Determine which attention backend to use and only import
    the selected backend module.
    """
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from aphrodite.attention.backends.blocksparse_attn import \
            BlocksparseFlashAttentionBackend
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                sliding_window, dtype, kv_cache_dtype,
                                block_size)
    if backend == _Backend.FLASH_ATTN:
        from aphrodite.attention.backends.flash_attn import \
            FlashAttentionBackend  # noqa: F401
        return FlashAttentionBackend
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import \
            XFormersBackend  # noqa: F401
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import \
            ROCmFlashAttentionBackend  # noqa: F401
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
                       "Please make sure --enforce-eager is set.")
        from aphrodite.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError("Invalid attention backend.")
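
# Illustrative sketch, not part of the original module: a caller would
# typically resolve the backend class once per model configuration. The
# argument values below are hypothetical; because `get_attn_backend` is
# wrapped in `lru_cache`, repeated calls with identical arguments reuse the
# already-selected class instead of re-running the selection logic.
#
#     backend_cls = get_attn_backend(
#         num_heads=32,
#         head_size=128,
#         num_kv_heads=8,
#         sliding_window=None,
#         dtype=torch.float16,
#         kv_cache_dtype="auto",
#         block_size=16,
#     )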


def which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> _Backend:
    """Returns which flash attention backend to use."""
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # Check the environment variable and override if specified.
    backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        backend_members = _Backend.__members__
        if backend_by_env_var.upper() not in backend_members:
            raise ValueError(
                f"Invalid attention backend '{backend_by_env_var}'. "
                f"Available backends: {', '.join(backend_members)}")
        selected_backend = _Backend[backend_by_env_var.upper()]

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info(f"Cannot use {selected_backend} backend on CPU.")
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if torch.cuda.get_device_capability()[0] != 9:
                # Not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info(f"{selected_backend} is not supported on AMD GPUs.")
        return _Backend.ROCM_FLASH
    # FlashAttn on NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if torch.cuda.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model; check whether the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm_flash_attn  # noqa: F401

            from aphrodite.attention.backends.flash_attn import \
                FlashAttentionBackend  # noqa: F401

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size "
                    f"{head_size}")
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm_flash_attn package is not found. "
                "`pip install vllm-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS
    return selected_backend
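

# Illustrative sketch, not part of the original module: forcing a backend via
# the environment variable checked above. Accepted values are the
# (case-insensitive) member names of `_Backend`; note that the CPU/HIP and
# FlashAttention compatibility checks can still override the requested backend.
#
#     os.environ[APHRODITE_ATTENTION_BACKEND] = "XFORMERS"
#     backend = which_attn_to_use(32, 128, 8, None, torch.float16, None, 16)
#     # On a non-CPU, non-HIP CUDA build this returns _Backend.XFORMERS.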