selector.py

import enum
import os
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator, Optional, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
                                    is_openvino, is_xpu)
from aphrodite.platforms import current_platform

APHRODITE_ATTENTION_BACKEND = os.getenv("APHRODITE_ATTENTION_BACKEND", None)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()


def backend_name_to_enum(backend_name: str) -> _Backend:
    assert backend_name is not None

    backend_members = _Backend.__members__
    if backend_name not in backend_members:
        raise ValueError(f"Invalid attention backend '{backend_name}'. "
                         f"Available backends: {', '.join(backend_members)} "
                         "(case-sensitive).")

    return _Backend[backend_name]
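

# Illustrative usage (not part of the module): names are matched
# case-sensitively against the _Backend members, so a valid name maps to its
# enum value and anything else raises ValueError.
#
#     backend_name_to_enum("XFORMERS")   # -> _Backend.XFORMERS
#     backend_name_to_enum("xformers")   # -> ValueError (case-sensitive)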


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Get the backend override specified by the Aphrodite attention
    backend environment variable, if one is specified.

    Returns:

    * _Backend enum value if an override is specified
    * None otherwise
    '''
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    return (None
            if backend_name is None else backend_name_to_enum(backend_name))
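

# Example (illustrative, assuming STR_BACKEND_ENV_VAR names the
# APHRODITE_ATTENTION_BACKEND variable): with the variable exported in the
# shell, e.g. `APHRODITE_ATTENTION_BACKEND=FLASHINFER`, this helper returns the
# corresponding enum member; with the variable unset it returns None.
#
#     get_env_variable_attn_backend()    # -> _Backend.FLASHINFER, or None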


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# APHRODITE ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    '''
    return forced_attn_backend
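

# Sketch of the override flow (illustrative): a globally forced backend takes
# precedence over the environment variable, and passing None restores
# auto-selection.
#
#     global_force_attn_backend(_Backend.FLASHINFER)
#     assert get_global_forced_attn_backend() is _Backend.FLASHINFER
#     global_force_attn_backend(None)    # back to automatic selection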


@lru_cache(maxsize=None)
def get_attn_backend(
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from aphrodite.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(head_size, sliding_window, dtype,
                                kv_cache_dtype, block_size, is_attention_free)
    if backend == _Backend.FLASH_ATTN:
        from aphrodite.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        assert is_cpu(), RuntimeError(
            "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.OPENVINO:
        logger.info("Using OpenVINO Attention backend.")
        from aphrodite.attention.backends.openvino import (
            OpenVINOAttentionBackend)
        return OpenVINOAttentionBackend
    elif backend == _Backend.IPEX:
        assert is_xpu(), RuntimeError(
            "IPEX attention backend is only used for the XPU device.")
        logger.info("Using IPEX attention backend.")
        from aphrodite.attention.backends.ipex_attn import IpexAttnBackend
        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        from aphrodite.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    elif backend == _Backend.PALLAS:
        logger.info("Using Pallas backend.")
        from aphrodite.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend
    elif backend == _Backend.NO_ATTENTION:
        from aphrodite.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend
    else:
        raise ValueError("Invalid attention backend.")
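

# Minimal sketch of a call site (illustrative values): the selector is memoized
# via lru_cache, so repeated calls with the same arguments resolve the backend
# class only once.
#
#     attn_backend_cls = get_attn_backend(
#         head_size=128,
#         sliding_window=None,
#         dtype=torch.float16,
#         kv_cache_dtype="auto",
#         block_size=16,
#         is_attention_free=False,
#     )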


def which_attn_to_use(
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
) -> _Backend:
    """Returns which attention backend to use."""
    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE APHRODITE_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = APHRODITE_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info(f"Cannot use {selected_backend} backend on CPU.")
        return _Backend.TORCH_SDPA

    if is_openvino():
        if selected_backend != _Backend.OPENVINO:
            logger.info(f"Cannot use {selected_backend} backend on OpenVINO.")
        return _Backend.OPENVINO

    if is_xpu():
        if selected_backend != _Backend.IPEX:
            logger.info(f"Cannot use {selected_backend} backend on XPU.")
        return _Backend.IPEX

    if current_platform.is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info(f"Cannot use {selected_backend} backend on TPU.")
        return _Backend.PALLAS

    if is_hip():
        # AMD GPUs.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if current_platform.get_device_capability()[0] != 9:
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info(f"{selected_backend} is not supported on AMD GPUs.")
        return _Backend.ROCM_FLASH

    # FlashAttn on NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if current_platform.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model; check whether the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import aphrodite_flash_attn  # noqa: F401

            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size "
                    f"{head_size}")
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "aphrodite_flash_attn package is not found. "
                "`pip install aphrodite-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS

    return selected_backend


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Globally force an Aphrodite attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
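

# Example usage (illustrative): temporarily force a backend for a scoped block,
# restoring whatever override (or None, i.e. auto-selection) was previously in
# effect once the block exits, even on error.
#
#     with global_force_attn_backend_context_manager(_Backend.XFORMERS):
#         ...  # backends selected inside this block resolve to XFormers
#     # prior override (or automatic selection) is restored here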