fbgemm_fp8.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              UnquantizedLinearMethod)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from aphrodite.quantization.utils.w8a8_utils import (
    apply_fp8_linear, create_per_channel_scale_param)

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        self.ignore_list = ignore_list
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
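        # 80 = Ampere. Devices below capability 89 are still supported here
        # through the Marlin weight-only fallback selected in __init__.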
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []
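
    # Sketch of the checkpoint quantization config consumed by from_config;
    # the key names match the accessors below, the values are illustrative
    # only:
    #   {"quant_method": "fbgemm_fp8",
    #    "modules_to_not_convert": ["lm_head"],
    #    "activation_scale_ub": 1200.0}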
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
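        # The checkpoint's ignore list names the unfused projections (q_proj,
        # k_proj, ...), while the runtime loads them fused (qkv_proj), so
        # every shard of a fused layer must agree on whether it is skipped.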
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        if proj_name in _FUSED_LAYER_NAME_MAPPING:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = shard_prefix in self.ignore_list

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision.")
        else:
            is_skipped = prefix in self.ignore_list

        assert is_skipped is not None
        return is_skipped

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class FBGEMMFp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
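        # FP8 (e4m3) weight stored as [output_size_per_partition,
        # input_size_per_partition]; its scales are per output channel
        # (see WEIGHT SCALE below).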
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.float8_e4m3fn),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
            **extra_weight_attrs,
        })

        # WEIGHT SCALE
        weight_scale = create_per_channel_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE UPPER BOUND
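        # The checkpoint's "activation_scale_ub": an upper bound applied when
        # activation scales are computed dynamically at runtime (see apply()).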
        input_scale_ub = torch.nn.Parameter(torch.tensor(
            (self.quant_config.input_scale_ub), dtype=torch.float32),
                                            requires_grad=False)
        layer.input_scale_ub = input_scale_ub

    def process_weights_after_loading(self, layer: Module) -> None:
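        # Store the weight transposed so the FP8 GEMM in apply() receives the
        # layout it expects without a per-forward transpose.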
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        if self.quant_config.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
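        # Two paths: pre-FP8 GPUs run the Marlin weight-only kernel (the
        # activations stay in the original dtype), while FP8-capable GPUs run
        # the CUTLASS FP8 GEMM with dynamic per-token activation scales,
        # capped by input_scale_ub.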
        if self.quant_config.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                input_scale=None,
                                input_scale_ub=layer.input_scale_ub,
                                bias=bias,
                                cutlass_fp8_supported=True,
                                use_per_token_if_dynamic=True)
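

# Minimal usage sketch (illustrative only; the config dict and prefix below
# are assumptions, not values from a real checkpoint):
#
#   quant_config = FBGEMMFp8Config.from_config({
#       "modules_to_not_convert": ["lm_head"],
#       "activation_scale_ub": 1200.0,
#   })
#   quant_method = quant_config.get_quant_method(
#       linear_layer, prefix="model.layers.0.self_attn.qkv_proj")
#   # -> FBGEMMFp8LinearMethod for quantized layers, or
#   #    UnquantizedLinearMethod if every shard of the layer is ignored.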