modelopt.py

from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.parameter import (ModelWeightParameter,
                                          PerTensorScaleParameter)
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.kv_cache import BaseKVCacheMethod
from aphrodite.quantization.utils.w8a8_utils import (apply_fp8_linear,
                                                     cutlass_fp8_supported,
                                                     requantize_with_max_scale)

ACTIVATION_SCHEMES = ["static"]


class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt fp8 checkpoint. Please note that"
                " the format is experimental and could change."
            )

    @classmethod
    def get_name(cls) -> str:
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        is_checkpoint_fp8_serialized = "FP8" in quant_method
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "ModelOpt currently only supports static FP8 "
                "quantization in Aphrodite. Please check the "
                "`hf_quant_config.json` file for your model's "
                "quant configuration."
            )
        return cls(is_checkpoint_fp8_serialized)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import (
            Attention)  # Avoid circular import
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__(quant_config)


class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static quantization.

    Supports loading FP8 checkpoints with static weight scales and
    activation scales. Support for dynamic scales may be added in the
    future.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm
       support.
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            # Initialized to float32 min as a "not yet loaded" sentinel;
            # the checkpoint's per-shard scales overwrite these values.
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            weight_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", weight_scale)
            # INPUT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        max_w_scale, weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths
        )
        layer.weight = Parameter(weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        layer.input_scale = Parameter(
            layer.input_scale.max(), requires_grad=False
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
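

# A minimal usage sketch (not part of the original module): it shows how
# ModelOptFp8Config.from_config consumes the "quantization" block found in a
# ModelOpt-produced hf_quant_config.json. The dict literal below is an assumed
# example payload, not copied from a real checkpoint.
#
#     sample_hf_quant_config = {
#         "quantization": {"quant_algo": "FP8"},
#     }
#     config = ModelOptFp8Config.from_config(sample_hf_quant_config)
#     assert config.is_checkpoint_fp8_serialized
#     # A quant_algo without "FP8" in it would raise the ValueError above.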