1
0

selector.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import enum
  2. import os
  3. from functools import lru_cache
  4. from typing import Optional, Type
  5. import torch
  6. from loguru import logger
  7. from aphrodite.attention.backends.abstract import AttentionBackend
  8. from aphrodite.common.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
  9. from aphrodite.platforms import current_platform
  10. APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"
  11. class _Backend(enum.Enum):
  12. FLASH_ATTN = enum.auto()
  13. XFORMERS = enum.auto()
  14. ROCM_FLASH = enum.auto()
  15. TORCH_SDPA = enum.auto()
  16. OPENVINO = enum.auto()
  17. FLASHINFER = enum.auto()
  18. PALLAS = enum.auto()
  19. IPEX = enum.auto()
  20. @lru_cache(maxsize=None)
  21. def get_attn_backend(
  22. num_heads: int,
  23. head_size: int,
  24. num_kv_heads: int,
  25. sliding_window: Optional[int],
  26. dtype: torch.dtype,
  27. kv_cache_dtype: Optional[str],
  28. block_size: int,
  29. is_blocksparse: bool = False,
  30. ) -> Type[AttentionBackend]:
  31. if is_blocksparse:
  32. logger.info("Using BlocksparseFlashAttention backend.")
  33. from aphrodite.attention.backends.blocksparse_attn import (
  34. BlocksparseFlashAttentionBackend)
  35. return BlocksparseFlashAttentionBackend
  36. """Determine which attention backend to use and only import
  37. the selected backend module.
  38. """
  39. backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
  40. sliding_window, dtype, kv_cache_dtype,
  41. block_size)
  42. if backend == _Backend.FLASH_ATTN:
  43. from aphrodite.attention.backends.flash_attn import ( # noqa: F401
  44. FlashAttentionBackend)
  45. return FlashAttentionBackend
  46. if backend == _Backend.XFORMERS:
  47. logger.info("Using XFormers backend.")
  48. from aphrodite.attention.backends.xformers import ( # noqa: F401
  49. XFormersBackend)
  50. return XFormersBackend
  51. elif backend == _Backend.ROCM_FLASH:
  52. logger.info("Using ROCmFlashAttention backend.")
  53. from aphrodite.attention.backends.rocm_flash_attn import ( # noqa: F401
  54. ROCmFlashAttentionBackend)
  55. return ROCmFlashAttentionBackend
  56. elif backend == _Backend.TORCH_SDPA:
  57. assert is_cpu(), RuntimeError(
  58. "Torch SDPA backend is only used for CPU devices.")
  59. logger.info("Using Torch SDPA backend.")
  60. from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
  61. return TorchSDPABackend
  62. elif backend == _Backend.OPENVINO:
  63. logger.info("Using OpenVINO attention backend.")
  64. from aphrodite.attention.backends.openvino import (
  65. OpenVINOAttentionBackend)
  66. return OpenVINOAttentionBackend
  67. elif backend == _Backend.IPEX:
  68. assert is_xpu(), RuntimeError(
  69. "IPEX attention backend is only used for the XPU device.")
  70. logger.info("Using IPEX attention backend.")
  71. from aphrodite.attention.backends.ipex_attn import IpexAttnBackend
  72. return IpexAttnBackend
  73. elif backend == _Backend.FLASHINFER:
  74. logger.info("Using Flashinfer backend.")
  75. from aphrodite.attention.backends.flashinfer import FlashInferBackend
  76. return FlashInferBackend
  77. elif backend == _Backend.PALLAS:
  78. logger.info("Using Pallas backend.")
  79. from aphrodite.attention.backends.pallas import PallasAttentionBackend
  80. return PallasAttentionBackend
  81. else:
  82. raise ValueError("Invalid attention backend.")
  83. def which_attn_to_use(
  84. num_heads: int,
  85. head_size: int,
  86. num_kv_heads: int,
  87. sliding_window: Optional[int],
  88. dtype: torch.dtype,
  89. kv_cache_dtype: Optional[str],
  90. block_size: int,
  91. ) -> _Backend:
  92. """Returns which flash attention backend to use."""
  93. # Default case.
  94. selected_backend = _Backend.FLASH_ATTN
  95. # Check the environment variable and override if specified
  96. backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
  97. if backend_by_env_var is not None:
  98. backend_members = _Backend.__members__
  99. if backend_by_env_var.upper() not in backend_members:
  100. raise ValueError(
  101. f"Invalid attention backend '{backend_by_env_var}'. "
  102. f"Available backends: {', '.join(backend_members)} ")
  103. selected_backend = _Backend[backend_by_env_var.upper()]
  104. if is_cpu():
  105. if selected_backend != _Backend.TORCH_SDPA:
  106. logger.info(f"Cannot use {selected_backend} backend on CPU.")
  107. return _Backend.TORCH_SDPA
  108. if is_openvino():
  109. if selected_backend != _Backend.OPENVINO:
  110. logger.info(f"Cannot use {selected_backend} backend on OpenVINO.")
  111. return _Backend.OPENVINO
  112. if is_xpu():
  113. if selected_backend != _Backend.IPEX:
  114. logger.info(f"Cannot use {selected_backend} backend on XPU.")
  115. return _Backend.IPEX
  116. if is_tpu():
  117. if selected_backend != _Backend.PALLAS:
  118. logger.info(f"Cannot use {selected_backend} backend on TPU.")
  119. return _Backend.PALLAS
  120. if is_hip():
  121. # AMD GPUs.
  122. selected_backend = (_Backend.ROCM_FLASH if selected_backend
  123. == _Backend.FLASH_ATTN else selected_backend)
  124. if selected_backend == _Backend.ROCM_FLASH:
  125. if current_platform.get_device_capability()[0] != 9:
  126. # not Instinct series GPUs.
  127. logger.info("flash_attn is not supported on NAVI GPUs.")
  128. else:
  129. logger.info(f"{selected_backend} is not supported in AMD GPUs.")
  130. return _Backend.ROCM_FLASH
  131. # FlashAttn in NVIDIA GPUs.
  132. if selected_backend == _Backend.FLASH_ATTN:
  133. if current_platform.get_device_capability()[0] < 8:
  134. # Volta and Turing NVIDIA GPUs.
  135. logger.info(
  136. "Cannot use FlashAttention-2 backend for Volta and Turing "
  137. "GPUs.")
  138. selected_backend = _Backend.XFORMERS
  139. elif dtype not in (torch.float16, torch.bfloat16):
  140. logger.info(
  141. "Cannot use FlashAttention-2 backend for dtype other than "
  142. "torch.float16 or torch.bfloat16.")
  143. selected_backend = _Backend.XFORMERS
  144. elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
  145. logger.info(
  146. "Cannot use FlashAttention-2 backend for FP8 KV cache.")
  147. selected_backend = _Backend.XFORMERS
  148. elif block_size % 16 != 0:
  149. logger.info(
  150. "Cannot use FlashAttention-2 backend for block size not "
  151. "divisible by 16.")
  152. selected_backend = _Backend.XFORMERS
  153. elif sliding_window is not None:
  154. logger.info(
  155. "Cannot use FlashAttention-2 backend due to sliding window.")
  156. selected_backend = _Backend.XFORMERS
  157. # FlashAttn is valid for the model, checking if the package is installed.
  158. if selected_backend == _Backend.FLASH_ATTN:
  159. try:
  160. import aphrodite_flash_attn # noqa: F401
  161. from aphrodite.attention.backends.flash_attn import ( # noqa: F401
  162. FlashAttentionBackend)
  163. supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
  164. if head_size not in supported_sizes:
  165. logger.info(
  166. "Cannot use FlashAttention-2 backend for head size "
  167. f"{head_size}")
  168. selected_backend = _Backend.XFORMERS
  169. except ImportError:
  170. logger.info(
  171. "Cannot use FlashAttention-2 backend because the "
  172. "aphrodite_flash_attn package is not found. "
  173. "`pip install aphrodite-flash-attn` for better performance.")
  174. selected_backend = _Backend.XFORMERS
  175. return selected_backend