# fbgemm_fp8.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              UnquantizedLinearMethod)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.fp8 import cutlass_fp8_supported
from aphrodite.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from aphrodite.quantization.utils.quant_utils import is_layer_skipped
from aphrodite.quantization.utils.w8a8_utils import (
    apply_fp8_linear, create_per_channel_scale_param)


class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        self.ignore_list = ignore_list if ignore_list else []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignore_list):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
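
# Illustrative only: `from_config` above consumes the HF-style
# `quantization_config` block of an FBGEMM-FP8 checkpoint. The keys shown are
# the ones read via `get_from_keys`; the concrete values below are made up
# for demonstration, not taken from any real checkpoint.
#
#   example_quant_config = {
#       "quant_method": "fbgemm_fp8",
#       "activation_scale_ub": 1200.0,
#       "modules_to_not_convert": ["lm_head"],
#   }
#   cfg = FBGEMMFp8Config.from_config(example_quant_config)
#   # -> cfg.ignore_list == ["lm_head"], cfg.input_scale_ub == 1200.0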


class FBGEMMFp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.float8_e4m3fn),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
            **extra_weight_attrs,
        })

        # WEIGHT SCALE
        weight_scale = create_per_channel_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE UPPER BOUND
        input_scale_ub = torch.nn.Parameter(torch.tensor(
            self.quant_config.input_scale_ub, dtype=torch.float32),
                                            requires_grad=False)
        layer.input_scale_ub = input_scale_ub
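        # At this point the layer is assumed to hold: `weight` as an
        # (output_size_per_partition, input_size_per_partition)
        # float8_e4m3fn tensor, `weight_scale` as one fp32 scale per output
        # channel, and `input_scale_ub` as a scalar fp32 upper bound that is
        # later applied to the dynamic activation scale.
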
    def process_weights_after_loading(self, layer: Module) -> None:
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        if self.quant_config.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.quant_config.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=None,
            input_scale_ub=layer.input_scale_ub,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=True)
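

# ---------------------------------------------------------------------------
# Hedged reference sketch (not part of the method above, and not the fused
# CUTLASS / torch._scaled_mm path): a plain-PyTorch emulation of the math the
# non-Marlin branch of `apply` is assumed to perform -- dynamic per-token
# activation quantization with the scale upper-bounded by `input_scale_ub`,
# against an fp8 weight stored transposed as (in_features, out_features) with
# per-output-channel scales. Useful only for reading and debugging; numerics
# will differ from the real kernel.
# ---------------------------------------------------------------------------
def _reference_fbgemm_fp8_linear(x: torch.Tensor,
                                 weight_t_fp8: torch.Tensor,
                                 weight_scale: torch.Tensor,
                                 input_scale_ub: torch.Tensor,
                                 bias: Optional[torch.Tensor] = None
                                 ) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Per-token dynamic scale, clamped to the configured upper bound.
    x_absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    x_scale = torch.minimum(x_absmax, input_scale_ub) / fp8_max
    x_q = (x / x_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Dequantize both operands and matmul in fp32 for the reference.
    w = weight_t_fp8.to(torch.float32) * weight_scale.to(torch.float32).t()
    out = (x_q.to(torch.float32) * x_scale) @ w
    out = out.to(x.dtype)
    return out if bias is None else out + bias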