selector.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import enum
  2. import os
  3. from contextlib import contextmanager
  4. from functools import lru_cache
  5. from typing import Generator, Optional, Type
  6. import torch
  7. from loguru import logger
  8. from aphrodite.attention.backends.abstract import AttentionBackend
  9. from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
  10. is_openvino, is_tpu, is_xpu)
  11. from aphrodite.platforms import current_platform
  12. APHRODITE_ATTENTION_BACKEND = os.getenv("APHRODITE_ATTENTION_BACKEND", None)
  13. class _Backend(enum.Enum):
  14. FLASH_ATTN = enum.auto()
  15. XFORMERS = enum.auto()
  16. ROCM_FLASH = enum.auto()
  17. TORCH_SDPA = enum.auto()
  18. OPENVINO = enum.auto()
  19. FLASHINFER = enum.auto()
  20. PALLAS = enum.auto()
  21. IPEX = enum.auto()
  22. def backend_name_to_enum(backend_name: str) -> _Backend:
  23. assert backend_name is not None
  24. backend_members = _Backend.__members__
  25. if backend_name not in backend_members:
  26. raise ValueError(f"Invalid attention backend '{backend_name}'. "
  27. f"Available backends: {', '.join(backend_members)} "
  28. "(case-sensitive).")
  29. return _Backend[backend_name]
  30. def get_env_variable_attn_backend() -> Optional[_Backend]:
  31. '''
  32. Get the backend override specified by the Aphrodite attention
  33. backend environment variable, if one is specified.
  34. Returns:
  35. * _Backend enum value if an override is specified
  36. * None otherwise
  37. '''
  38. backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
  39. return (None
  40. if backend_name is None else backend_name_to_enum(backend_name))
  41. # Global state allows a particular choice of backend
  42. # to be forced, overriding the logic which auto-selects
  43. # a backend based on system & workload configuration
  44. # (default behavior if this variable is None)
  45. #
  46. # THIS SELECTION TAKES PRECEDENCE OVER THE
  47. # APHRODITE ATTENTION BACKEND ENVIRONMENT VARIABLE
  48. forced_attn_backend: Optional[_Backend] = None
  49. def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
  50. '''
  51. Force all attention operations to use a specified backend.
  52. Passing `None` for the argument re-enables automatic
  53. backend selection.,
  54. Arguments:
  55. * attn_backend: backend selection (None to revert to auto)
  56. '''
  57. global forced_attn_backend
  58. forced_attn_backend = attn_backend
  59. def get_global_forced_attn_backend() -> Optional[_Backend]:
  60. '''
  61. Get the currently-forced choice of attention backend,
  62. or None if auto-selection is currently enabled.
  63. '''
  64. return forced_attn_backend
  65. @lru_cache(maxsize=None)
  66. def get_attn_backend(
  67. num_heads: int,
  68. head_size: int,
  69. num_kv_heads: int,
  70. sliding_window: Optional[int],
  71. dtype: torch.dtype,
  72. kv_cache_dtype: Optional[str],
  73. block_size: int,
  74. is_blocksparse: bool = False,
  75. ) -> Type[AttentionBackend]:
  76. """Selects which attention backend to use and lazily imports it."""
  77. if is_blocksparse:
  78. logger.info("Using BlocksparseFlashAttention backend.")
  79. from aphrodite.attention.backends.blocksparse_attn import (
  80. BlocksparseFlashAttentionBackend)
  81. return BlocksparseFlashAttentionBackend
  82. backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
  83. sliding_window, dtype, kv_cache_dtype,
  84. block_size)
  85. if backend == _Backend.FLASH_ATTN:
  86. from aphrodite.attention.backends.flash_attn import ( # noqa: F401
  87. FlashAttentionBackend)
  88. return FlashAttentionBackend
  89. if backend == _Backend.XFORMERS:
  90. logger.info("Using XFormers backend.")
  91. from aphrodite.attention.backends.xformers import ( # noqa: F401
  92. XFormersBackend)
  93. return XFormersBackend
  94. elif backend == _Backend.ROCM_FLASH:
  95. logger.info("Using ROCmFlashAttention backend.")
  96. from aphrodite.attention.backends.rocm_flash_attn import ( # noqa: F401
  97. ROCmFlashAttentionBackend)
  98. return ROCmFlashAttentionBackend
  99. elif backend == _Backend.TORCH_SDPA:
  100. assert is_cpu(), RuntimeError(
  101. "Torch SDPA backend is only used for the CPU device.")
  102. logger.info("Using Torch SDPA backend.")
  103. from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
  104. return TorchSDPABackend
  105. elif backend == _Backend.OPENVINO:
  106. logger.info("Using OpenVINO Attention backend.")
  107. from aphrodite.attention.backends.openvino import (
  108. OpenVINOAttentionBackend)
  109. return OpenVINOAttentionBackend
  110. elif backend == _Backend.IPEX:
  111. assert is_xpu(), RuntimeError(
  112. "IPEX attention backend is only used for the XPU device.")
  113. logger.info("Using IPEX attention backend.")
  114. from aphrodite.attention.backends.ipex_attn import IpexAttnBackend
  115. return IpexAttnBackend
  116. elif backend == _Backend.FLASHINFER:
  117. logger.info("Using Flashinfer backend.")
  118. from aphrodite.attention.backends.flashinfer import FlashInferBackend
  119. return FlashInferBackend
  120. elif backend == _Backend.PALLAS:
  121. logger.info("Using Pallas backend.")
  122. from aphrodite.attention.backends.pallas import PallasAttentionBackend
  123. return PallasAttentionBackend
  124. else:
  125. raise ValueError("Invalid attention backend.")
  126. def which_attn_to_use(
  127. num_heads: int,
  128. head_size: int,
  129. num_kv_heads: int,
  130. sliding_window: Optional[int],
  131. dtype: torch.dtype,
  132. kv_cache_dtype: Optional[str],
  133. block_size: int,
  134. ) -> _Backend:
  135. """Returns which flash attention backend to use."""
  136. # Default case.
  137. selected_backend = _Backend.FLASH_ATTN
  138. # Check whether a particular choice of backend was
  139. # previously forced.
  140. #
  141. # THIS SELECTION OVERRIDES THE APHRODITE_ATTENTION_BACKEND
  142. # ENVIRONMENT VARIABLE.
  143. backend_by_global_setting: Optional[_Backend] = (
  144. get_global_forced_attn_backend())
  145. if backend_by_global_setting is not None:
  146. selected_backend = backend_by_global_setting
  147. else:
  148. # Check the environment variable and override if specified
  149. backend_by_env_var: Optional[str] = APHRODITE_ATTENTION_BACKEND
  150. if backend_by_env_var is not None:
  151. selected_backend = backend_name_to_enum(backend_by_env_var)
  152. if is_cpu():
  153. if selected_backend != _Backend.TORCH_SDPA:
  154. logger.info(f"Cannot use {selected_backend} backend on CPU.")
  155. return _Backend.TORCH_SDPA
  156. if is_openvino():
  157. if selected_backend != _Backend.OPENVINO:
  158. logger.info(f"Cannot use {selected_backend} backend on OpenVINO.")
  159. return _Backend.OPENVINO
  160. if is_xpu():
  161. if selected_backend != _Backend.IPEX:
  162. logger.info(f"Cannot use {selected_backend} backend on XPU.")
  163. return _Backend.IPEX
  164. if is_tpu():
  165. if selected_backend != _Backend.PALLAS:
  166. logger.info(f"Cannot use {selected_backend} backend on TPU.")
  167. return _Backend.PALLAS
  168. if is_hip():
  169. # AMD GPUs.
  170. selected_backend = (_Backend.ROCM_FLASH if selected_backend
  171. == _Backend.FLASH_ATTN else selected_backend)
  172. if selected_backend == _Backend.ROCM_FLASH:
  173. if current_platform.get_device_capability()[0] != 9:
  174. # not Instinct series GPUs.
  175. logger.info("flash_attn is not supported on NAVI GPUs.")
  176. else:
  177. logger.info(f"{selected_backend} is not supported in AMD GPUs.")
  178. return _Backend.ROCM_FLASH
  179. # FlashAttn in NVIDIA GPUs.
  180. if selected_backend == _Backend.FLASH_ATTN:
  181. if current_platform.get_device_capability()[0] < 8:
  182. # Volta and Turing NVIDIA GPUs.
  183. logger.info(
  184. "Cannot use FlashAttention-2 backend for Volta and Turing "
  185. "GPUs.")
  186. selected_backend = _Backend.XFORMERS
  187. elif dtype not in (torch.float16, torch.bfloat16):
  188. logger.info(
  189. "Cannot use FlashAttention-2 backend for dtype other than "
  190. "torch.float16 or torch.bfloat16.")
  191. selected_backend = _Backend.XFORMERS
  192. elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
  193. logger.info(
  194. "Cannot use FlashAttention-2 backend for FP8 KV cache.")
  195. selected_backend = _Backend.XFORMERS
  196. elif block_size % 16 != 0:
  197. logger.info(
  198. "Cannot use FlashAttention-2 backend for block size not "
  199. "divisible by 16.")
  200. selected_backend = _Backend.XFORMERS
  201. elif sliding_window is not None:
  202. logger.info(
  203. "Cannot use FlashAttention-2 backend due to sliding window.")
  204. selected_backend = _Backend.XFORMERS
  205. # FlashAttn is valid for the model, checking if the package is installed.
  206. if selected_backend == _Backend.FLASH_ATTN:
  207. try:
  208. import aphrodite_flash_attn # noqa: F401
  209. from aphrodite.attention.backends.flash_attn import ( # noqa: F401
  210. FlashAttentionBackend)
  211. supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
  212. if head_size not in supported_sizes:
  213. logger.info(
  214. "Cannot use FlashAttention-2 backend for head size "
  215. f"{head_size}")
  216. selected_backend = _Backend.XFORMERS
  217. except ImportError:
  218. logger.info(
  219. "Cannot use FlashAttention-2 backend because the "
  220. "aphrodite_flash_attn package is not found. "
  221. "`pip install aphrodite-flash-attn` for better performance.")
  222. selected_backend = _Backend.XFORMERS
  223. return selected_backend
  224. @contextmanager
  225. def global_force_attn_backend_context_manager(
  226. attn_backend: _Backend) -> Generator[None, None, None]:
  227. '''
  228. Globally force a Aphrodite attention backend override within a
  229. context manager, reverting the global attention backend
  230. override to its prior state upon exiting the context
  231. manager.
  232. Arguments:
  233. * attn_backend: attention backend to force
  234. Returns:
  235. * Generator
  236. '''
  237. # Save the current state of the global backend override (if any)
  238. original_value = get_global_forced_attn_backend()
  239. # Globally force the new backend override
  240. global_force_attn_backend(attn_backend)
  241. # Yield control back to the enclosed code block
  242. try:
  243. yield
  244. finally:
  245. # Revert the original global backend override, if any
  246. global_force_attn_backend(original_value)