fp8.py 4.4 KB

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

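
# Illustrative note (not part of the original module): from_config above only
# reads the "activation_scheme" key from a checkpoint's quantization config,
# e.g. Fp8Config.from_config({"activation_scheme": "dynamic"}). The value 89
# returned by get_min_capability() restricts this scheme to GPUs of compute
# capability 8.9 (Ada Lovelace) or newer, matching the hardware requirement
# of torch._scaled_mm.
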

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    We now support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor will be initialized after the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qinput, x_scale = ops.scaled_fp8_quant(x)
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output
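
For reference, the sketch below illustrates what the dynamic per-tensor scheme
computes, using only stock PyTorch with float8_e4m3fn support;
naive_scaled_fp8_quant is a hypothetical stand-in for the ops.scaled_fp8_quant
kernel used above, not its actual CUDA implementation.

    import torch

    def naive_scaled_fp8_quant(x: torch.Tensor):
        # Hypothetical pure-PyTorch stand-in for ops.scaled_fp8_quant, for
        # illustration only. Dynamic per-tensor scheme: pick one scale so that
        # abs(x).max() maps to the largest representable float8_e4m3fn value,
        # then clamp and cast. The dequantization scale is kept in float32.
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
        scale = amax / finfo.max
        qx = (x / scale).clamp(min=finfo.min, max=finfo.max)
        return qx.to(torch.float8_e4m3fn), scale

apply() then multiplies the two FP8 operands with torch._scaled_mm, passing the
activation and weight scales as scale_a/scale_b so the product is rescaled back
to x.dtype. Note that _scaled_mm is a private PyTorch API; the unpacking
"output, _ = ..." above relies on the variant that returns an (output, amax)
tuple, and its signature has changed across PyTorch releases.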