fp8.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from typing import Any, Dict, List, Optional, Tuple
  2. import torch
  3. from torch.nn import Module
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
  6. from aphrodite.quantization.base_config import \
  7. QuantizationConfig
  8. class FP8Config(QuantizationConfig):
  9. """Config class for FP8."""
  10. @classmethod
  11. def get_name(cls) -> str:
  12. return "fp8"
  13. @classmethod
  14. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  15. return [torch.bfloat16, torch.half]
  16. @classmethod
  17. def get_min_capability(cls) -> int:
  18. return 89
  19. @classmethod
  20. def get_config_filenames(cls) -> List[str]:
  21. return []
  22. @classmethod
  23. def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
  24. return cls()
  25. def get_linear_method(self) -> "Fp8LinearMethod":
  26. return Fp8LinearMethod(self)
  27. def get_scaled_act_names(self) -> List[str]:
  28. return []
  29. class Fp8LinearMethod(LinearMethodBase):
  30. """Linear method for FP8.
  31. We now support common FP16/BF16 model checkpoints ONLY. The weight
  32. scaling factor will be initialized after the model weights are loaded.
  33. Limitations:
  34. 1. Only support per-tensor quantization due to torch._scaled_mm support.
  35. 2. Only support float8_e4m3fn data type due to the limitation of
  36. torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
  37. Args:
  38. quant_config: The quantization config.
  39. """
  40. def __init__(self, quant_config: FP8Config):
  41. self.quant_config = quant_config
  42. def create_weights(
  43. self,
  44. layer: torch.nn.Module,
  45. input_size_per_partition: int,
  46. output_partition_sizes: List[int],
  47. input_size: int,
  48. output_size: int,
  49. params_dtype: torch.dtype,
  50. **extra_weight_attrs,
  51. ):
  52. output_size_per_partition = sum(output_partition_sizes)
  53. weight = Parameter(torch.empty(output_size_per_partition,
  54. input_size_per_partition,
  55. dtype=params_dtype),
  56. requires_grad=False)
  57. layer.register_parameter("weight", weight)
  58. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  59. set_weight_attrs(weight, extra_weight_attrs)
  60. w_scale = Parameter(
  61. torch.empty(1, dtype=torch.float32),
  62. requires_grad=False,
  63. )
  64. layer.register_parameter("weight_scaling_factor", w_scale)
  65. def process_weights_after_loading(self, layer: Module) -> None:
  66. # Although the linear_method is propagated to all layers,
  67. # only linear layers invoke "create_weights". So we check
  68. # whether "weight_scaling_facor" is registered to determine
  69. # whether the layer is a linear layer that requires quantization.
  70. if not hasattr(layer, "weight_scaling_factor"):
  71. return
  72. qweight, weight_scale = per_tensor_quantize(layer.weight)
  73. # torch._scaled_mm requires column-major in the second
  74. # input (weight), so we transpose the quantized weight.
  75. layer.weight = Parameter(qweight.t(), requires_grad=False)
  76. layer.weight_scaling_factor.data.copy_(weight_scale)
  77. def apply_weights(self,
  78. layer: torch.nn.Module,
  79. x: torch.Tensor,
  80. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  81. qinput, x_scale = per_tensor_quantize(x)
  82. output, _ = torch._scaled_mm(
  83. qinput,
  84. layer.weight,
  85. out_dtype=x.dtype,
  86. scale_a=x_scale,
  87. scale_b=layer.weight_scaling_factor,
  88. bias=bias,
  89. )
  90. return output
  91. def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
  92. """Quantize a tensor using per-tensor static scaling factor.
  93. Args:
  94. tensor: The input tensor.
  95. """
  96. finfo = torch.finfo(torch.float8_e4m3fn)
  97. # Calculate the scale as dtype max divided by absmax.
  98. # Since .abs() creates a new tensor, we use aminmax to get
  99. # the min and max first and then calculate the absmax.
  100. min_val, max_val = tensor.aminmax()
  101. amax = min_val.abs().max(max_val.abs())
  102. scale = finfo.max / amax.clamp(min=1e-12)
  103. # scale and clamp the tensor to bring it to
  104. # the representative range of float8 data type
  105. # (as default cast is unsaturated)
  106. qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
  107. # Return both float8 data and the inverse scale (as float),
  108. # as both required as inputs to torch._scaled_mm
  109. qweight = qweight.to(torch.float8_e4m3fn)
  110. scale = scale.float().reciprocal()
  111. return qweight, scale