fp8.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True
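
# quant_ops is the compiled Aphrodite CUDA extension. As used below,
# ops.scaled_fp8_quant quantizes a tensor to FP8 and returns both the
# quantized tensor and its per-tensor scaling factor. If the extension is
# not built, HAS_QUANTS stays False and Fp8LinearMethod raises on construction.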


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
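        # FP8 (float8_e4m3fn) GEMMs via torch._scaled_mm require CUDA compute
        # capability 8.9 or newer (Ada Lovelace / Hopper class GPUs).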
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        return cls()

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    We now support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor will be initialized after the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)
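
    # The weight above is allocated in the checkpoint dtype (FP16/BF16); it is
    # only quantized to FP8 after the checkpoint has been loaded, in
    # process_weights_after_loading below. "weight_scaling_factor" is a single
    # per-tensor scale, matching the per-tensor-only limitation noted in the
    # class docstring.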
    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)
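
    # apply() dynamically quantizes the activations on every call and computes
    #     y ~ (x_fp8 * x_scale) @ (W_fp8.t() * w_scale) + bias
    # in a single FP8 GEMM via torch._scaled_mm, producing the output in the
    # original activation dtype (out_dtype=x.dtype).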
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qinput, x_scale = ops.scaled_fp8_quant(x)
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output
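

# A minimal pure-PyTorch sketch of the per-tensor FP8 quantization that the
# scaled_fp8_quant kernel is assumed to perform (the usual amax / finfo.max
# recipe). This is illustrative only: the actual CUDA kernel may differ in
# details such as scale clamping, and _reference_scaled_fp8_quant is a
# hypothetical name, not part of this module's API.
def _reference_scaled_fp8_quant(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale for the whole tensor: map the largest magnitude to the FP8 max
    # (448 for e4m3fn); clamp avoids division by zero for all-zero tensors.
    scale = (x.abs().max().to(torch.float32) / finfo.max).clamp(min=1e-12)
    qx = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Returning the scale lets the matmul dequantize: x ~ qx.float() * scale.
    return qx, scale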