1
0

fp8.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from contextlib import suppress
  2. from typing import Any, Dict, List, Optional, Tuple, Union
  3. import torch
  4. from loguru import logger
  5. from torch.nn import Module
  6. from torch.nn.parameter import Parameter
  7. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  8. from aphrodite.modeling.utils import set_weight_attrs
  9. from aphrodite.quantization.base_config import (QuantizationConfig,
  10. QuantizeMethodBase)
  11. from aphrodite.common.utils import print_warning_once
  12. HAS_QUANTS = False
  13. with suppress(ImportError):
  14. from aphrodite._quant_C import quant_ops as ops
  15. HAS_QUANTS = True
  16. ACTIVATION_SCHEMES = ["static", "dynamic"]
  17. def scaled_fp8_quant(
  18. input: torch.Tensor,
  19. scale: Optional[torch.Tensor] = None,
  20. batch_dim_padding: Optional[int] = None,
  21. ) -> Tuple[torch.Tensor, torch.Tensor]:
  22. """
  23. Quantize input tensor to FP8 and return quantized tensor and scale.
  24. This function supports both static and dynamic quantization: If you
  25. provide the scale, it will use static scaling and if you omit it,
  26. the scale will be determined dynamically. The function also allows
  27. optional padding of the output tensor for downstream kernels that
  28. will benefit from padding.
  29. Args:
  30. input: The input tensor to be quantized to FP8
  31. scale: Optional scaling factor for the FP8 quantization
  32. batch_dim_padding: If specified, pad the first dimension
  33. of the output to at least this value.
  34. Returns:
  35. Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
  36. scaling factor.
  37. """
  38. if batch_dim_padding:
  39. shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
  40. output = torch.empty(shape,
  41. device=input.device,
  42. dtype=torch.float8_e4m3fn)
  43. else:
  44. output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
  45. if scale is None:
  46. scale = torch.zeros(1, device=input.device, dtype=torch.float32)
  47. ops.dynamic_scaled_fp8_quant(output, input, scale)
  48. else:
  49. ops.static_scaled_fp8_quant(output, input, scale)
  50. return output, scale
  51. class Fp8Config(QuantizationConfig):
  52. """Config class for FP8."""
  53. def __init__(
  54. self,
  55. is_checkpoint_fp8_serialized: bool = False,
  56. activation_scheme: str = "dynamic",
  57. ) -> None:
  58. self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
  59. if is_checkpoint_fp8_serialized:
  60. logger.warning("Detected fp8 checkpoint. Please note that the "
  61. "format is experimental and subject to change.")
  62. if activation_scheme not in ACTIVATION_SCHEMES:
  63. raise ValueError(
  64. f"Unsupported activation scheme {activation_scheme}")
  65. self.activation_scheme = activation_scheme
  66. @classmethod
  67. def get_name(cls) -> str:
  68. return "fp8"
  69. @classmethod
  70. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  71. return [torch.bfloat16, torch.half]
  72. @classmethod
  73. def get_min_capability(cls) -> int:
  74. return 89
  75. @classmethod
  76. def get_config_filenames(cls) -> List[str]:
  77. return []
  78. @classmethod
  79. def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
  80. quant_method = cls.get_from_keys(config, ["quant_method"])
  81. is_checkpoint_fp8_serialized = ("fp8" in quant_method)
  82. activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
  83. return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
  84. activation_scheme=activation_scheme)
  85. def get_quant_method(
  86. self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
  87. from aphrodite.attention.layer import Attention # Avoid circular import
  88. if isinstance(layer, LinearBase):
  89. return Fp8LinearMethod(self)
  90. if isinstance(layer, Attention):
  91. return Fp8KVCacheMethod(self)
  92. return None
  93. def get_scaled_act_names(self) -> List[str]:
  94. return []
  95. class Fp8LinearMethod(LinearMethodBase):
  96. """Linear method for FP8.
  97. Supports loading FP8 checkpoints with static weight scale and
  98. dynamic/static activation scale.
  99. Also supports loading quantized FP16/BF16 model checkpoints with dynamic
  100. activation scaling. The weight scaling factor will be initialized after
  101. the model weights are loaded.
  102. Limitations:
  103. 1. Only support per-tensor quantization due to torch._scaled_mm support.
  104. 2. Only support float8_e4m3fn data type due to the limitation of
  105. torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
  106. Args:
  107. quant_config: The quantization config.
  108. """
  109. def __init__(self, quant_config: Fp8Config):
  110. if not HAS_QUANTS:
  111. raise ImportError("Could not find the quantization kernels.")
  112. self.quant_config = quant_config
  113. def _create_scale_param(
  114. self,
  115. scale_name: str,
  116. layer: torch.nn.Module,
  117. output_partition_sizes: List[int],
  118. **extra_weight_attrs,
  119. ) -> None:
  120. scale = Parameter(torch.empty(len(output_partition_sizes),
  121. dtype=torch.float32),
  122. requires_grad=False)
  123. layer.register_parameter(scale_name, scale)
  124. set_weight_attrs(
  125. scale, {
  126. **extra_weight_attrs,
  127. "fp8_scales_shard_indexer":
  128. self.scales_shard_indexer,
  129. })
  130. def create_weights(
  131. self,
  132. layer: torch.nn.Module,
  133. input_size_per_partition: int,
  134. output_partition_sizes: List[int],
  135. input_size: int,
  136. output_size: int,
  137. params_dtype: torch.dtype,
  138. **extra_weight_attrs,
  139. ):
  140. del input_size, output_size
  141. output_size_per_partition = sum(output_partition_sizes)
  142. layer.process_after_load = True
  143. layer.logical_widths = output_partition_sizes
  144. # WEIGHT
  145. weight_dtype = (torch.float8_e4m3fn
  146. if self.quant_config.is_checkpoint_fp8_serialized else
  147. params_dtype)
  148. weight = Parameter(torch.empty(output_size_per_partition,
  149. input_size_per_partition,
  150. dtype=weight_dtype),
  151. requires_grad=False)
  152. layer.register_parameter("weight", weight)
  153. set_weight_attrs(weight, {
  154. **extra_weight_attrs,
  155. "input_dim": 1,
  156. "output_dim": 0,
  157. })
  158. # If checkpoint is serialized fp8, load them.
  159. # Otherwise, wait until process_weights_after_loading.
  160. if self.quant_config.is_checkpoint_fp8_serialized:
  161. # WEIGHT SCALE
  162. self._create_scale_param(
  163. scale_name="weight_scale",
  164. layer=layer,
  165. output_partition_sizes=output_partition_sizes,
  166. **extra_weight_attrs)
  167. # ACTIVATION SCALE
  168. if self.quant_config.activation_scheme == "static":
  169. self._create_scale_param(
  170. scale_name="act_scale",
  171. layer=layer,
  172. output_partition_sizes=output_partition_sizes,
  173. **extra_weight_attrs)
  174. def scales_shard_indexer(
  175. self, param: torch.Tensor, loaded_weight: torch.Tensor,
  176. shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
  177. qkv_idxs = {"q": 0, "k": 1, "v": 2}
  178. if isinstance(shard_id, int):
  179. pass
  180. elif isinstance(shard_id, str):
  181. if shard_id not in qkv_idxs:
  182. raise ValueError(f"Unknown shard_id: {shard_id}")
  183. shard_id = qkv_idxs[shard_id]
  184. else:
  185. ValueError(f"Shard id must be int or str but got {type(shard_id)}")
  186. return param[shard_id], loaded_weight
  187. def process_weights_after_loading(self, layer: Module) -> None:
  188. if (not hasattr(layer, "process_after_load")
  189. or not layer.process_after_load):
  190. return
  191. # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
  192. if not self.quant_config.is_checkpoint_fp8_serialized:
  193. qweight, weight_scale = scaled_fp8_quant(layer.weight, scale=None)
  194. layer.weight = Parameter(qweight.t(), requires_grad=False)
  195. layer.weight_scale = Parameter(weight_scale, requires_grad=False)
  196. layer.logical_widths = None
  197. layer.act_scale = None
  198. return
  199. # If checkpoint is fp8, requantize the separately quantized logical
  200. # weights into a single fp8 weight with a single weight scale.
  201. else:
  202. # WEIGHT_SCALE / WEIGHT
  203. # Loop over logical weights, requantizing with single scale.
  204. max_w_scale = layer.weight_scale.max()
  205. start = 0
  206. for idx, logical_width in enumerate(layer.logical_widths):
  207. end = start + logical_width
  208. weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
  209. layer.weight_scale[idx])
  210. layer.weight[start:end, :] = per_tensor_quantize(
  211. weight_dq, layer.weight_scale.max())
  212. start = end
  213. layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
  214. # WEIGHT
  215. # Transpose weight for passing to torch._scaled_mm
  216. weight = layer.weight
  217. layer.weight = Parameter(weight.t(), requires_grad=False)
  218. # ACT_SCALE
  219. # Dynamic: set to None (required input to ops.scaled_fp8_quant).
  220. # Static: set to max of the act_scales (since they are equal).
  221. if self.quant_config.activation_scheme == "dynamic":
  222. layer.act_scale = None
  223. elif self.quant_config.activation_scheme == "static":
  224. if not all_close_1d(layer.act_scale):
  225. raise ValueError(
  226. "All the act_scales for the logical weights of a layer "
  227. f"must be equal. But got {layer.act_scale}")
  228. layer.act_scale = Parameter(layer.act_scale.max(),
  229. requires_grad=False)
  230. else:
  231. raise ValueError(
  232. f"Unknown scheme {self.quant_config.activation_scheme}")
  233. def apply(self,
  234. layer: torch.nn.Module,
  235. x: torch.Tensor,
  236. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  237. # ops.scaled_fp8_quant supports both dynamic and static quant.
  238. # If dynamic, layer.act_scale is None and x_scale computed from x.
  239. # If static, layer.act_scale is scalar and x_scale set to act_scale.
  240. qinput, x_scale = scaled_fp8_quant(x,
  241. layer.act_scale,
  242. batch_dim_padding=17)
  243. # Fused GEMM_DQ -- note we padded the input above because
  244. # torch._scaled_mm is more performant for matrices with
  245. # batch dimension > 16. Note that this could change
  246. # in the future.
  247. output, _ = torch._scaled_mm(
  248. qinput,
  249. layer.weight,
  250. out_dtype=x.dtype,
  251. scale_a=x_scale,
  252. scale_b=layer.weight_scale,
  253. bias=bias,
  254. )
  255. return torch.narrow(output, 0, 0, x.shape[0])
  256. class Fp8KVCacheMethod(QuantizeMethodBase):
  257. """Supports loading kv-cache scaling factors from FP8 checkpoints.
  258. """
  259. def __init__(self, quant_config: Fp8Config):
  260. self.quant_config = quant_config
  261. def create_weights(self, layer: torch.nn.Module):
  262. """Create "weight" (aka kv_scale) for an attention layer.
  263. Args:
  264. layer: The layer that is using the QuantizeMethodBase factory.
  265. """
  266. # Initialize the KV cache scale to 1.0 as the default value.
  267. # If the kv_scale appears in the checkpoint, it will be
  268. # overwritten when loading weights.
  269. layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
  270. def apply(self, layer: torch.nn.Module) -> torch.Tensor:
  271. raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
  272. def process_weights_after_loading(self, layer: Module) -> None:
  273. # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
  274. # regardless whether the kv-scale is available in the checkpoint.
  275. if layer.kv_cache_dtype != "auto":
  276. kv_scale = layer.kv_scale.to("cpu").tolist()
  277. if not isinstance(kv_scale, float):
  278. raise ValueError("Only support per-tensor scaling factor "
  279. "for fp8 KV cache")
  280. layer._kv_scale = kv_scale
  281. if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
  282. print_warning_once(
  283. "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
  284. "cause accuracy issues. Please make sure kv-cache scaling "
  285. "factor is available in the fp8 checkpoint.")
  286. del layer.kv_scale
  287. def all_close_1d(x: torch.Tensor) -> bool:
  288. assert len(x.shape) == 1
  289. return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
  290. def per_tensor_quantize(tensor: torch.Tensor,
  291. inv_scale: float) -> torch.Tensor:
  292. finfo = torch.finfo(torch.float8_e4m3fn)
  293. qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
  294. return qweight.to(torch.float8_e4m3fn)
  295. def per_tensor_dequantize(tensor: torch.Tensor,
  296. inv_scale: float) -> torch.Tensor:
  297. fake_qweight = tensor.to(torch.float16)
  298. dq_weight = fake_qweight * inv_scale
  299. return dq_weight