__init__.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import os
  2. from typing import List, Optional, Type
  3. from aphrodite.platforms import current_platform
  4. from aphrodite.quantization.kernels.machete import MacheteLinearKernel
  5. from aphrodite.quantization.kernels.marlin import MarlinLinearKernel
  6. from aphrodite.quantization.kernels.MPLinearKernel import (MPLinearKernel,
  7. MPLinearLayerConfig)
  8. # in priority/performance order (when available)
  9. _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
  10. MacheteLinearKernel,
  11. MarlinLinearKernel,
  12. ]
  13. def choose_mp_linear_kernel(
  14. config: MPLinearLayerConfig,
  15. compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
  16. """
  17. Choose an MPLinearKernel that can implement the given config for the given
  18. compute capability. Attempts to choose the best kernel in terms of
  19. performance.
  20. Args:
  21. config (MPLinearLayerConfig): Description of the linear layer to be
  22. implemented.
  23. compute_capability (Optional[int], optional): The compute capability of
  24. the target device, if None uses `current_platform` to get the compute
  25. capability. Defaults to None.
  26. Raises:
  27. ValueError: If no kernel can implement the given config.
  28. Returns:
  29. Type[MPLinearKernel]: Chosen kernel.
  30. """
  31. if compute_capability is None:
  32. if current_platform is None:
  33. raise ValueError("Cannot determine compute capability")
  34. _cc = current_platform.get_device_capability()
  35. compute_capability = _cc[0] * 10 + _cc[1]
  36. failure_reasons = []
  37. for kernel in _POSSIBLE_KERNELS:
  38. if kernel.__name__ in os.environ.get("APHRODITE_DISABLED_KERNELS", "")\
  39. .split(","):
  40. failure_reasons.append(
  41. f' {kernel.__name__} disabled by environment variable')
  42. continue
  43. if kernel.get_min_capability() > compute_capability:
  44. failure_reasons.append(
  45. f"{kernel.__name__} requires capability "
  46. f"{kernel.get_min_capability()}, current compute capability "
  47. f"is {compute_capability}")
  48. continue
  49. can_implement, failure_reason = kernel.can_implement(config)
  50. if can_implement:
  51. return kernel
  52. else:
  53. failure_reasons.append(
  54. f' {kernel.__name__} cannot implement due to: {failure_reason}'
  55. )
  56. raise ValueError(
  57. "Failed to find a kernel that can implement the "\
  58. "WNA16 linear layer. Reasons: \n"
  59. + '\n'.join(failure_reasons))