selector.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import enum
  2. from functools import lru_cache
  3. from typing import Type
  4. import torch
  5. from loguru import logger
  6. from aphrodite.attention.backends.abstract import AttentionBackend
  7. from aphrodite.common.utils import is_cpu, is_hip
  8. class _Backend(enum.Enum):
  9. FLASH_ATTN = enum.auto()
  10. XFORMERS = enum.auto()
  11. ROCM_FLASH = enum.auto()
  12. TORCH_SDPA = enum.auto()
  13. @lru_cache(maxsize=None)
  14. def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
  15. backend = _which_attn_to_use(dtype)
  16. if backend == _Backend.FLASH_ATTN:
  17. logger.info("Using FlashAttention backend.")
  18. from aphrodite.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
  19. return FlashAttentionBackend
  20. elif backend == _Backend.XFORMERS:
  21. logger.info("Using XFormers backend.")
  22. from aphrodite.attention.backends.xformers import XFormersBackend # noqa: F501
  23. return XFormersBackend
  24. elif backend == _Backend.ROCM_FLASH:
  25. logger.info("Using ROCm FlashAttention backend.")
  26. from aphrodite.attention.backends.rocm_flash_attn import ( # noqa: F401
  27. ROCmFlashAttentionBackend)
  28. return ROCmFlashAttentionBackend
  29. elif backend == _Backend.TORCH_SDPA:
  30. logger.info("Using Torch SDPA backend.")
  31. from aphrodite.attention.backends.sdpa import TorchSDPABackend
  32. return TorchSDPABackend
  33. else:
  34. raise ValueError("Invalid attention backend.")
  35. def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
  36. """Returns which flash attention backend to use."""
  37. if is_cpu():
  38. return _Backend.TORCH_SDPA
  39. if is_hip():
  40. # AMD GPUs.
  41. if torch.cuda.get_device_capability()[0] != 9:
  42. # not Instinct series GPUs.
  43. logger.info("flash_atten is not supported on NAVI GPUs.")
  44. return _Backend.ROCM_FLASH
  45. # NVIDIA GPUs.
  46. if torch.cuda.get_device_capability()[0] < 8:
  47. # Volta and Turing NVIDIA GPUs.
  48. logger.info("Cannot use FlashAttention backend for Volta and Turing "
  49. "GPUs.")
  50. return _Backend.XFORMERS
  51. if dtype not in (torch.float16, torch.bfloat16):
  52. logger.info("Cannot use FlashAttention backend for dtype other than "
  53. "torch.float16 or torch.bfloat16.")
  54. return _Backend.XFORMERS
  55. try:
  56. import flash_attn # noqa: F401
  57. except ImportError:
  58. logger.info(
  59. "Cannot use FlashAttention backend because the flash_attn package "
  60. "is not found. Please install it for better performance.")
  61. return _Backend.XFORMERS
  62. return _Backend.FLASH_ATTN