fp8.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from typing import Any, Dict, List, Optional, Tuple, Union
  2. import torch
  3. from loguru import logger
  4. from torch.nn import Module
  5. from torch.nn.parameter import Parameter
  6. from aphrodite import _custom_ops as ops
  7. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  8. from aphrodite.quantization.base_config import (QuantizationConfig,
  9. QuantizeMethodBase)
  10. from aphrodite.modeling.utils import set_weight_attrs
  11. from aphrodite.common.utils import print_warning_once
  12. ACTIVATION_SCHEMES = ["static", "dynamic"]
  13. def cutlass_fp8_supported() -> bool:
  14. capability = torch.cuda.get_device_capability()
  15. capability = capability[0] * 10 + capability[1]
  16. return ops.cutlass_scaled_mm_supports_fp8(capability)
  17. class Fp8Config(QuantizationConfig):
  18. """Config class for FP8."""
  19. def __init__(
  20. self,
  21. is_checkpoint_fp8_serialized: bool = False,
  22. activation_scheme: str = "dynamic",
  23. ) -> None:
  24. self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
  25. if is_checkpoint_fp8_serialized:
  26. logger.warning("Detected fp8 checkpoint. Please note that the "
  27. "format is experimental and subject to change.")
  28. if activation_scheme not in ACTIVATION_SCHEMES:
  29. raise ValueError(
  30. f"Unsupported activation scheme {activation_scheme}")
  31. self.activation_scheme = activation_scheme
  32. @classmethod
  33. def get_name(cls) -> str:
  34. return "fp8"
  35. @classmethod
  36. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  37. return [torch.bfloat16, torch.half]
  38. @classmethod
  39. def get_min_capability(cls) -> int:
  40. return 89
  41. @classmethod
  42. def get_config_filenames(cls) -> List[str]:
  43. return []
  44. @classmethod
  45. def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
  46. quant_method = cls.get_from_keys(config, ["quant_method"])
  47. is_checkpoint_fp8_serialized = ("fp8" in quant_method)
  48. activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
  49. return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
  50. activation_scheme=activation_scheme)
  51. def get_quant_method(
  52. self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
  53. from aphrodite.attention.layer import Attention # Avoid circular import
  54. if isinstance(layer, LinearBase):
  55. return Fp8LinearMethod(self)
  56. if isinstance(layer, Attention):
  57. return Fp8KVCacheMethod(self)
  58. return None
  59. def get_scaled_act_names(self) -> List[str]:
  60. return []
  61. class Fp8LinearMethod(LinearMethodBase):
  62. """Linear method for FP8.
  63. Supports loading FP8 checkpoints with static weight scale and
  64. dynamic/static activation scale.
  65. Also supports loading quantized FP16/BF16 model checkpoints with dynamic
  66. activation scaling. The weight scaling factor will be initialized after
  67. the model weights are loaded.
  68. Limitations:
  69. 1. Only support per-tensor quantization due to torch._scaled_mm support.
  70. 2. Only support float8_e4m3fn data type due to the limitation of
  71. torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
  72. Args:
  73. quant_config: The quantization config.
  74. """
  75. def __init__(self, quant_config: Fp8Config):
  76. self.quant_config = quant_config
  77. self.cutlass_fp8_supported = cutlass_fp8_supported()
  78. def _create_scale_param(
  79. self,
  80. scale_name: str,
  81. layer: torch.nn.Module,
  82. output_partition_sizes: List[int],
  83. **extra_weight_attrs,
  84. ) -> None:
  85. scale = Parameter(torch.empty(len(output_partition_sizes),
  86. dtype=torch.float32),
  87. requires_grad=False)
  88. layer.register_parameter(scale_name, scale)
  89. set_weight_attrs(
  90. scale, {
  91. **extra_weight_attrs,
  92. "fp8_scales_shard_indexer":
  93. self.scales_shard_indexer,
  94. })
  95. def create_weights(
  96. self,
  97. layer: torch.nn.Module,
  98. input_size_per_partition: int,
  99. output_partition_sizes: List[int],
  100. input_size: int,
  101. output_size: int,
  102. params_dtype: torch.dtype,
  103. **extra_weight_attrs,
  104. ):
  105. del input_size, output_size
  106. output_size_per_partition = sum(output_partition_sizes)
  107. layer.process_after_load = True
  108. layer.logical_widths = output_partition_sizes
  109. # WEIGHT
  110. weight_dtype = (torch.float8_e4m3fn
  111. if self.quant_config.is_checkpoint_fp8_serialized else
  112. params_dtype)
  113. weight = Parameter(torch.empty(output_size_per_partition,
  114. input_size_per_partition,
  115. dtype=weight_dtype),
  116. requires_grad=False)
  117. layer.register_parameter("weight", weight)
  118. set_weight_attrs(weight, {
  119. **extra_weight_attrs,
  120. "input_dim": 1,
  121. "output_dim": 0,
  122. })
  123. # If checkpoint is serialized fp8, load them.
  124. # Otherwise, wait until process_weights_after_loading.
  125. if self.quant_config.is_checkpoint_fp8_serialized:
  126. # WEIGHT SCALE
  127. self._create_scale_param(
  128. scale_name="weight_scale",
  129. layer=layer,
  130. output_partition_sizes=output_partition_sizes,
  131. **extra_weight_attrs)
  132. # INPUT ACTIVATION SCALE
  133. if self.quant_config.activation_scheme == "static":
  134. self._create_scale_param(
  135. scale_name="input_scale",
  136. layer=layer,
  137. output_partition_sizes=output_partition_sizes,
  138. **extra_weight_attrs)
  139. def scales_shard_indexer(
  140. self, param: torch.Tensor, loaded_weight: torch.Tensor,
  141. shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
  142. qkv_idxs = {"q": 0, "k": 1, "v": 2}
  143. if isinstance(shard_id, int):
  144. pass
  145. elif isinstance(shard_id, str):
  146. if shard_id not in qkv_idxs:
  147. raise ValueError(f"Unknown shard_id: {shard_id}")
  148. shard_id = qkv_idxs[shard_id]
  149. else:
  150. ValueError(f"Shard id must be int or str but got {type(shard_id)}")
  151. return param[shard_id], loaded_weight
  152. def process_weights_after_loading(self, layer: Module) -> None:
  153. if (not hasattr(layer, "process_after_load")
  154. or not layer.process_after_load):
  155. return
  156. # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
  157. if not self.quant_config.is_checkpoint_fp8_serialized:
  158. qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
  159. scale=None)
  160. layer.weight = Parameter(qweight.t(), requires_grad=False)
  161. layer.weight_scale = Parameter(weight_scale, requires_grad=False)
  162. layer.logical_widths = None
  163. layer.input_scale = None
  164. return
  165. # If checkpoint is fp8, requantize the separately quantized logical
  166. # weights into a single fp8 weight with a single weight scale.
  167. else:
  168. # WEIGHT_SCALE / WEIGHT
  169. # Loop over logical weights, requantizing with single scale.
  170. max_w_scale = layer.weight_scale.max()
  171. start = 0
  172. for idx, logical_width in enumerate(layer.logical_widths):
  173. end = start + logical_width
  174. weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
  175. layer.weight_scale[idx])
  176. layer.weight[start:end, :] = per_tensor_quantize(
  177. weight_dq, layer.weight_scale.max())
  178. start = end
  179. layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
  180. # WEIGHT
  181. # Transpose weight for passing to torch._scaled_mm
  182. weight = layer.weight
  183. layer.weight = Parameter(weight.t(), requires_grad=False)
  184. # INPUT ACTIVATION SCALE
  185. # Dynamic: set to None (required input to ops.scaled_fp8_quant).
  186. # Static: set to max of the input_scales (since they are equal).
  187. if self.quant_config.activation_scheme == "dynamic":
  188. layer.input_scale = None
  189. elif self.quant_config.activation_scheme == "static":
  190. if not all_close_1d(layer.input_scale):
  191. raise ValueError(
  192. "All the input_scales for the logical weights of a "
  193. f"layer must be equal. But got {layer.input_scale}")
  194. layer.input_scale = Parameter(layer.input_scale.max(),
  195. requires_grad=False)
  196. else:
  197. raise ValueError(
  198. f"Unknown scheme {self.quant_config.activation_scheme}")
  199. def apply(self,
  200. layer: torch.nn.Module,
  201. x: torch.Tensor,
  202. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  203. # ops.scaled_fp8_quant supports both dynamic and static quant.
  204. # If dynamic, layer.input_scale is None and x_scale computed from x.
  205. # If static, layer.input_scale is scalar and x_scale is input_scale.
  206. if bias is None and self.cutlass_fp8_supported:
  207. qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
  208. # Fused GEMM_DQ
  209. output = ops.cutlass_scaled_mm(
  210. qinput,
  211. layer.weight,
  212. out_dtype=x.dtype,
  213. scale_a=x_scale,
  214. scale_b=layer.weight_scale,
  215. )
  216. else:
  217. qinput, x_scale = ops.scaled_fp8_quant(x,
  218. layer.input_scale,
  219. batch_dim_padding=17)
  220. # Fused GEMM_DQ -- note we padded the input above because
  221. # torch._scaled_mm is more performant for matrices with
  222. # batch dimension > 16. Note that this could change
  223. # in the future.
  224. output, _ = torch._scaled_mm(
  225. qinput,
  226. layer.weight,
  227. out_dtype=x.dtype,
  228. scale_a=x_scale,
  229. scale_b=layer.weight_scale,
  230. bias=bias,
  231. )
  232. return torch.narrow(output, 0, 0, x.shape[0])
  233. class Fp8KVCacheMethod(QuantizeMethodBase):
  234. """Supports loading kv-cache scaling factors from FP8 checkpoints.
  235. """
  236. def __init__(self, quant_config: Fp8Config):
  237. self.quant_config = quant_config
  238. def create_weights(self, layer: torch.nn.Module):
  239. """Create "weight" (aka kv_scale) for an attention layer.
  240. Args:
  241. layer: The layer that is using the QuantizeMethodBase factory.
  242. """
  243. # Initialize the KV cache scale to 1.0 as the default value.
  244. # If the kv_scale appears in the checkpoint, it will be
  245. # overwritten when loading weights.
  246. layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
  247. def apply(self, layer: torch.nn.Module) -> torch.Tensor:
  248. raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
  249. def process_weights_after_loading(self, layer: Module) -> None:
  250. # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
  251. # regardless whether the kv-scale is available in the checkpoint.
  252. if layer.kv_cache_dtype != "auto":
  253. kv_scale = layer.kv_scale.to("cpu").tolist()
  254. if not isinstance(kv_scale, float):
  255. raise ValueError("Only support per-tensor scaling factor "
  256. "for fp8 KV cache")
  257. layer._kv_scale = kv_scale
  258. if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
  259. print_warning_once(
  260. "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
  261. "cause accuracy issues. Please make sure kv-cache scaling "
  262. "factor is available in the fp8 checkpoint.")
  263. del layer.kv_scale
  264. def all_close_1d(x: torch.Tensor) -> bool:
  265. assert len(x.shape) == 1
  266. return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
  267. def per_tensor_quantize(tensor: torch.Tensor,
  268. inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
  269. finfo = torch.finfo(torch.float8_e4m3fn)
  270. qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
  271. return qweight.to(torch.float8_e4m3fn)
  272. def per_tensor_dequantize(
  273. tensor: torch.Tensor, inv_scale: Union[float,
  274. torch.Tensor]) -> torch.Tensor:
  275. fake_qweight = tensor.to(torch.float16)
  276. dq_weight = fake_qweight * inv_scale
  277. return dq_weight