fp8.py
from typing import Any, Dict, List, Optional, Tuple, Union
from contextlib import suppress

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from loguru import logger

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.modeling.utils import set_weight_attrs

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

ACTIVATION_SCHEMES = ["static", "dynamic"]


def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, it will use static scaling, and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensor for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        batch_dim_padding: If specified, pad the first dimension
            of the output to at least this value.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
    """
    if batch_dim_padding:
        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
        output = torch.empty(shape,
                             device=input.device,
                             dtype=torch.float8_e4m3fn)
    else:
        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        ops.static_scaled_fp8_quant(output, input, scale)
    return output, scale
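

# Illustrative usage of scaled_fp8_quant (a sketch, not executed at import
# time; requires the quant kernels and a CUDA device). The tensor names
# below are hypothetical:
#
#   x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
#   q_dyn, s = scaled_fp8_quant(x)                        # dynamic scale
#   q_static, _ = scaled_fp8_quant(x, scale=s)            # static scale
#   q_pad, _ = scaled_fp8_quant(x, batch_dim_padding=17)  # pad batch dim to >= 17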


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)
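
    # Illustrative shape of the `config` dict parsed above (an assumption
    # based on the keys read here; real checkpoint configs may carry more
    # fields):
    #
    #   {"quant_method": "fp8", "activation_scheme": "static"}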

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(
            scale, {
                **extra_weight_attrs,
                "fp8_scales_shard_indexer": self.scales_shard_indexer,
            })

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If the checkpoint is serialized fp8, create the scale parameters
        # so they can be loaded. Otherwise, wait until
        # process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="act_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)
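
        # Example (illustrative): for a fused QKV projection with
        # output_partition_sizes = [q_size, k_size, v_size], the weight is
        # [q_size + k_size + v_size, input_size_per_partition] and each scale
        # parameter holds one float32 entry per logical shard, indexed later
        # by scales_shard_indexer ("q" -> 0, "k" -> 1, "v" -> 2).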

    def scales_shard_indexer(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        qkv_idxs = {"q": 0, "k": 1, "v": 2}

        if isinstance(shard_id, int):
            pass
        elif isinstance(shard_id, str):
            if shard_id not in qkv_idxs:
                raise ValueError(f"Unknown shard_id: {shard_id}")
            shard_id = qkv_idxs[shard_id]
        else:
            raise ValueError(
                f"Shard id must be int or str but got {type(shard_id)}")

        return param[shard_id], loaded_weight

    def process_weights_after_loading(self, layer: Module) -> None:
        if (not hasattr(layer, "process_after_load")
                or not layer.process_after_load):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = scaled_fp8_quant(layer.weight, scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.act_scale = None
            return

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        else:
            # WEIGHT_SCALE / WEIGHT
            # Loop over logical weights, requantizing with a single scale.
            max_w_scale = layer.weight_scale.max()
            start = 0
            for idx, logical_width in enumerate(layer.logical_widths):
                end = start + logical_width
                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
                                                  layer.weight_scale[idx])
                layer.weight[start:end, :] = per_tensor_quantize(
                    weight_dq, max_w_scale)
                start = end
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
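            # The requantization above amounts to rescaling the stored fp8
            # values by weight_scale[idx] / max_w_scale (up to fp8 rounding),
            # so every logical shard can share the single max_w_scale.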

            # WEIGHT
            # Transpose weight for passing to torch._scaled_mm
            weight = layer.weight
            layer.weight = Parameter(weight.t(), requires_grad=False)

            # ACT_SCALE
            # Dynamic: set to None (required input to ops.scaled_fp8_quant).
            # Static: set to max of the act_scales (since they are equal).
            if self.quant_config.activation_scheme == "dynamic":
                layer.act_scale = None
            elif self.quant_config.activation_scheme == "static":
                if not all_close_1d(layer.act_scale):
                    raise ValueError(
                        "All the act_scales for the logical weights of a "
                        f"layer must be equal. But got {layer.act_scale}")
                layer.act_scale = Parameter(layer.act_scale.max(),
                                            requires_grad=False)
            else:
                raise ValueError(
                    f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.act_scale is None and x_scale is computed from x.
        # If static, layer.act_scale is a scalar and x_scale is act_scale.
        qinput, x_scale = scaled_fp8_quant(x,
                                           layer.act_scale,
                                           batch_dim_padding=17)

        # Fused GEMM_DQ -- note we padded the input above because
        # torch._scaled_mm is more performant for matrices with a
        # batch dimension > 16. Note that this could change
        # in the future.
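        # Shapes (illustrative): qinput is [max(17, batch), K] in fp8 and
        # layer.weight was transposed to [K, N] in
        # process_weights_after_loading, so the GEMM produces
        # [max(17, batch), N] in x.dtype; the narrow at the end drops the
        # padding rows.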
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scale,
            bias=bias,
        )

        return torch.narrow(output, 0, 0, x.shape[0])


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: float) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(tensor: torch.Tensor,
                          inv_scale: float) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
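

# Illustrative round trip with the helpers above (a sketch, not executed at
# import time; tensor names are hypothetical):
#
#   w = torch.randn(16, 16, device="cuda")
#   scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
#   q = per_tensor_quantize(w, scale)           # fp8 weight
#   w_approx = per_tensor_dequantize(q, scale)  # close to w, in float16
#
# Note that despite its name, `inv_scale` acts as the dequantization
# multiplier: per_tensor_quantize divides by it and per_tensor_dequantize
# multiplies by it.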