selector.py

import enum
import os
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator, Optional, Type

import torch
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.attention.backends.abstract import AttentionBackend
from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
                                    is_openvino, is_xpu)
from aphrodite.platforms import current_platform

APHRODITE_ATTENTION_BACKEND = envs.APHRODITE_ATTENTION_BACKEND


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()


def backend_name_to_enum(backend_name: str) -> _Backend:
    assert backend_name is not None

    backend_members = _Backend.__members__
    if backend_name not in backend_members:
        raise ValueError(f"Invalid attention backend '{backend_name}'. "
                         f"Available backends: {', '.join(backend_members)} "
                         "(case-sensitive).")

    return _Backend[backend_name]


def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Get the backend override specified by the Aphrodite attention
    backend environment variable, if one is specified.

    Returns:

    * _Backend enum value if an override is specified
    * None otherwise
    '''
    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
    return (None
            if backend_name is None else backend_name_to_enum(backend_name))
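

# Example (illustrative sketch, not part of the module; assumes
# STR_BACKEND_ENV_VAR resolves to the APHRODITE_ATTENTION_BACKEND variable):
# with e.g. `export APHRODITE_ATTENTION_BACKEND=XFORMERS` set, the override
# can be read back as an enum value:
#
#     >>> get_env_variable_attn_backend()
#     <_Backend.XFORMERS: 2>
#
# If the variable is unset, the call returns None and auto-selection applies.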


# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# APHRODITE_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Force all attention operations to use a specified backend.

    Passing `None` for the argument re-enables automatic
    backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Get the currently-forced choice of attention backend,
    or None if auto-selection is currently enabled.
    '''
    return forced_attn_backend
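

# Example (illustrative sketch, not part of the module): forcing a backend
# globally takes precedence over the environment variable, and passing None
# restores automatic selection:
#
#     >>> global_force_attn_backend(_Backend.FLASHINFER)
#     >>> get_global_forced_attn_backend()
#     <_Backend.FLASHINFER: 6>
#     >>> global_force_attn_backend(None)   # back to automatic selection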


@lru_cache(maxsize=None)
def get_attn_backend(
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from aphrodite.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend

    backend = which_attn_to_use(head_size, sliding_window, dtype,
                                kv_cache_dtype, block_size, is_attention_free)
    if backend == _Backend.FLASH_ATTN:
        from aphrodite.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from aphrodite.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        assert is_cpu(), RuntimeError(
            "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.OPENVINO:
        logger.info("Using OpenVINO Attention backend.")
        from aphrodite.attention.backends.openvino import (
            OpenVINOAttentionBackend)
        return OpenVINOAttentionBackend
    elif backend == _Backend.IPEX:
        assert is_xpu(), RuntimeError(
            "IPEX attention backend is only used for the XPU device.")
        logger.info("Using IPEX attention backend.")
        from aphrodite.attention.backends.ipex_attn import IpexAttnBackend
        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        from aphrodite.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    elif backend == _Backend.PALLAS:
        logger.info("Using Pallas backend.")
        from aphrodite.attention.backends.pallas import PallasAttentionBackend
        return PallasAttentionBackend
    elif backend == _Backend.NO_ATTENTION:
        from aphrodite.attention.backends.placeholder_attn import (
            PlaceholderAttentionBackend)
        return PlaceholderAttentionBackend
    else:
        raise ValueError("Invalid attention backend.")
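

# Example (illustrative sketch, not part of the module): on a CUDA machine
# with the FlashAttention prerequisites met, a typical call resolves to the
# FlashAttention backend class; the head_size/block_size values below are
# made up for illustration.
#
#     >>> backend_cls = get_attn_backend(
#     ...     head_size=128, sliding_window=None, dtype=torch.float16,
#     ...     kv_cache_dtype="auto", block_size=16, is_attention_free=False)
#     >>> backend_cls.__name__
#     'FlashAttentionBackend'
#
# Note that results are cached by @lru_cache, so repeated calls with the
# same arguments return the same class without re-running selection.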


def which_attn_to_use(
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
) -> _Backend:
    """Returns which flash attention backend to use."""

    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # Check whether a particular choice of backend was
    # previously forced.
    #
    # THIS SELECTION OVERRIDES THE APHRODITE_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    if backend_by_global_setting is not None:
        selected_backend = backend_by_global_setting
    else:
        # Check the environment variable and override if specified
        backend_by_env_var: Optional[str] = APHRODITE_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info(f"Cannot use {selected_backend} backend on CPU.")
        return _Backend.TORCH_SDPA

    if is_openvino():
        if selected_backend != _Backend.OPENVINO:
            logger.info(f"Cannot use {selected_backend} backend on OpenVINO.")
        return _Backend.OPENVINO

    if is_xpu():
        if selected_backend != _Backend.IPEX:
            logger.info(f"Cannot use {selected_backend} backend on XPU.")
        return _Backend.IPEX

    if current_platform.is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info(f"Cannot use {selected_backend} backend on TPU.")
        return _Backend.PALLAS

    if is_hip():
        # AMD GPUs.
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if current_platform.get_device_capability()[0] != 9:
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info(f"{selected_backend} is not supported in AMD GPUs.")
        return _Backend.ROCM_FLASH

    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if current_platform.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting the environment "
                "variable APHRODITE_ATTENTION_BACKEND=FLASHINFER")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model, checking if the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import aphrodite_flash_attn  # noqa: F401

            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size "
                    f"{head_size}")
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "aphrodite_flash_attn package is not found. "
                "`pip install aphrodite-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS

    return selected_backend


@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Globally force an Aphrodite attention backend override within a
    context manager, reverting the global attention backend
    override to its prior state upon exiting the context
    manager.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''

    # Save the current state of the global backend override (if any)
    original_value = get_global_forced_attn_backend()

    # Globally force the new backend override
    global_force_attn_backend(attn_backend)

    # Yield control back to the enclosed code block
    try:
        yield
    finally:
        # Revert the original global backend override, if any
        global_force_attn_backend(original_value)
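

# Example (illustrative sketch, not part of the module; assumes no backend
# was forced beforehand): temporarily force the XFORMERS backend for a block
# of code, then revert to the prior override (or auto-selection). The
# cache_clear() call is shown because get_attn_backend() is memoized with
# @lru_cache and would otherwise ignore a new override for arguments it has
# already seen.
#
#     >>> with global_force_attn_backend_context_manager(_Backend.XFORMERS):
#     ...     get_attn_backend.cache_clear()
#     ...     backend_cls = get_attn_backend(
#     ...         head_size=64, sliding_window=None, dtype=torch.float16,
#     ...         kv_cache_dtype="auto", block_size=16,
#     ...         is_attention_free=False)
#     >>> get_global_forced_attn_backend() is None
#     True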