fp8.py

from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import print_warning_once
from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                 fused_moe)
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from aphrodite.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)

ACTIVATION_SCHEMES = ["static", "dynamic"]


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import \
            Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
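

# Illustrative sketch (not used by the classes above): how `from_config` reads
# a serialized-fp8 quantization config. The dict values below are hypothetical
# example values, not taken from any particular checkpoint.
def _example_fp8_config_from_dict() -> Fp8Config:
    example_quant_config = {
        "quant_method": "fp8",          # "fp8" => checkpoint is fp8-serialized
        "activation_scheme": "static",  # must be one of ACTIVATION_SCHEMES
    }
    # Equivalent to Fp8Config(is_checkpoint_fp8_serialized=True,
    #                         activation_scheme="static").
    return Fp8Config.from_config(example_quant_config)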


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = create_per_tensor_scale_param(output_partition_sizes,
                                                  **extra_weight_attrs)
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = create_per_tensor_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
                layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        else:
            # Dequant -> Quant with max scale.
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
            else:
                layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
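

# Reference sketch of the per-tensor scheme used by Fp8LinearMethod above: a
# plain-PyTorch stand-in for what ops.scaled_fp8_quant + apply_fp8_linear
# compute, for illustration only. The real path uses fused CUTLASS/torch
# _scaled_mm or Marlin kernels and keeps the operands in fp8.
def _reference_per_tensor_fp8_linear(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per tensor: map the absolute max onto the fp8 range.
    w_scale = (weight.abs().max().float() / fp8_max).clamp(min=1e-12)
    x_scale = (x.abs().max().float() / fp8_max).clamp(min=1e-12)
    q_weight = (weight / w_scale).clamp(-fp8_max, fp8_max).to(
        torch.float8_e4m3fn)
    q_x = (x / x_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Dequantize-and-matmul reference; the fused kernels instead fold the two
    # scales into the GEMM epilogue.
    out = (q_x.float() * x_scale) @ (q_weight.float() * w_scale).t()
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)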


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  2,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                      dtype=torch.float32),
                                           requires_grad=False)
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            layer.a13_scale = None
            layer.a2_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(layer.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                 requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, handle the fact that the MoE kernels require
        # a single activation scale and a single weight scale for w13 per
        # expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.a13_scale is None or layer.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.a13_scale)
                        or not all_close_1d(layer.a2_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                     requires_grad=False)
                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                    requires_grad=False)

            # Fp8 moe kernel needs a single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                                 requires_grad=False)
            return

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:

        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True,
                         use_fp8=True,
                         w1_scale=layer.w13_scale,
                         w2_scale=layer.w2_scale,
                         a1_scale=layer.a13_scale,
                         a2_scale=layer.a2_scale,
                         use_grouped_topk=use_grouped_topk,
                         num_expert_group=num_expert_group,
                         topk_group=topk_group)
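

# Illustrative sketch of the "dequant then requant with the max scale" step
# that Fp8MoEMethod.process_weights_after_loading performs on the merged w13
# weight: each logical shard (w1, w3) has its own per-tensor scale, while the
# fused MoE kernel wants a single scale per expert. Toy, plain-PyTorch helper;
# the real code uses per_tensor_dequantize and ops.scaled_fp8_quant.
def _example_requantize_w13_with_max_scale(
        w13: torch.Tensor,           # (2 * shard_size, hidden), fp8 weight
        shard_scales: torch.Tensor,  # (2,), per-shard fp32 scales
        shard_size: int):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    max_scale = shard_scales.max()
    requantized = torch.empty_like(w13)
    for shard_id in range(2):
        start = shard_id * shard_size
        # Dequantize with the shard's original scale...
        dq = w13[start:start + shard_size, :].float() * shard_scales[shard_id]
        # ...and requantize with the shared maximum scale.
        requantized[start:start + shard_size, :] = (
            dq / max_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return requantized, max_scale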


class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka k_scale and v_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
        # regardless of whether the kv-scale is available in the checkpoint.
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present.
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here.
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward().
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        del layer.k_scale
        del layer.v_scale
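

# Illustrative sketch (hypothetical helper, not called anywhere above): once
# process_weights_after_loading has resolved _k_scale/_v_scale to plain
# floats, an fp8_e4m3 KV cache conceptually stores `value / scale` cast to
# fp8 and multiplies the scale back in on read. The real cache-write and
# attention kernels do this in fused CUDA code.
def _example_quantize_kv_entry(value: torch.Tensor,
                               scale: float) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return (value / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)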