fp8.py

from typing import Any, Dict, List, Optional, Union

import torch
from loguru import logger
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import print_warning_once
from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                 fused_moe)
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

ACTIVATION_SCHEMES = ["static", "dynamic"]
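
# Device-capability note: current_platform.get_device_capability() returns a
# (major, minor) tuple that is folded into a single integer below, e.g. (8, 9)
# on Ada becomes 89 and (9, 0) on Hopper becomes 90. Fp8Config's
# get_min_capability() uses the same encoding and requires at least 89.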
def cutlass_fp8_supported() -> bool:
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return ops.cutlass_scaled_mm_supports_fp8(capability)

class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import \
            Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only supports per-tensor quantization due to torch._scaled_mm support.
    2. Only supports the float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def _create_scale_param(
        self,
        scale_name: str,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        **extra_weight_attrs,
    ) -> None:
        scale = Parameter(torch.empty(len(output_partition_sizes),
                                      dtype=torch.float32),
                          requires_grad=False)
        scale[:] = torch.finfo(torch.float8_e4m3fn).min
        layer.register_parameter(scale_name, scale)
        set_weight_attrs(scale, {
            **extra_weight_attrs,
            "needs_scalar_to_array": True,
        })
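
    # NOTE: each scale created above is initialized to float8_e4m3fn's finfo
    # min as a sentinel. process_weights_after_loading() later checks whether
    # an entry still holds the sentinel to detect that the checkpoint stored a
    # single (fused) scale rather than one scale per logical shard.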

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.process_after_load = True
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, create the scale params here.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            self._create_scale_param(
                scale_name="weight_scale",
                layer=layer,
                output_partition_sizes=output_partition_sizes,
                **extra_weight_attrs)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                self._create_scale_param(
                    scale_name="input_scale",
                    layer=layer,
                    output_partition_sizes=output_partition_sizes,
                    **extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        if (not hasattr(layer, "process_after_load")
                or not layer.process_after_load):
            return

        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.logical_widths = None
            layer.input_scale = None
            return

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        else:
            # WEIGHT_SCALE / WEIGHT
            #   Loop over logical weights, requantizing with a single scale.
            max_w_scale = layer.weight_scale.max()

            # QKV / MLP is fused in the on-disk checkpoint if any of the
            # weight scales are still set to the default, since we initialize
            # N weight scales for N shards but only load 1 weight scale from
            # disk in that case. As a result, we skip dequant -> requant
            # since QKV is already quantized together.
            unfused_module_in_checkpoint = (
                layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)

            if unfused_module_in_checkpoint:
                start = 0
                for idx, logical_width in enumerate(layer.logical_widths):
                    end = start + logical_width
                    weight_dq = per_tensor_dequantize(
                        layer.weight[start:end, :], layer.weight_scale[idx])

                    layer.weight[start:end, :] = per_tensor_quantize(
                        weight_dq, max_w_scale)
                    start = end
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

            # WEIGHT
            #   Transpose weight for passing to torch._scaled_mm.
            weight = layer.weight
            layer.weight = Parameter(weight.t(), requires_grad=False)

            # INPUT ACTIVATION SCALE
            #   Dynamic: set to None (required input to ops.scaled_fp8_quant).
            #   Static: set to the max of the input_scales (they are equal).
            if self.quant_config.activation_scheme == "dynamic":
                layer.input_scale = None
            elif self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
            else:
                raise ValueError(
                    f"Unknown scheme {self.quant_config.activation_scheme}")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale comes from x.
        #   If static, layer.input_scale is a scalar and x_scale is that scale.
        if bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )
        else:
            qinput, x_scale = ops.scaled_fp8_quant(x,
                                                   layer.input_scale,
                                                   batch_dim_padding=17)

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        return torch.narrow(output, 0, 0, x.shape[0])
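
# NOTE (Fp8LinearMethod.apply): with batch_dim_padding=17, an input whose
# batch dimension is smaller is padded up to 17 rows before torch._scaled_mm
# (which is faster for batch dimensions > 16); the final
# torch.narrow(output, 0, 0, x.shape[0]) then drops the padding rows so the
# caller sees the original batch size.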

class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        layer.process_after_load = True

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  2,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")
            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                      dtype=torch.float32),
                                           requires_grad=False)
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            layer.a13_scale = None
            layer.a2_scale = None
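
    # NOTE: w13_scale is allocated as [num_experts, 2] (one scale for the w1
    # shard and one for the w3 shard of each expert); it is collapsed into a
    # single per-expert scale in process_weights_after_loading() below, since
    # the fused MoE kernel expects one w13 scale per expert.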

    def process_weights_after_loading(self, layer: Module) -> None:
        if (not hasattr(layer, "process_after_load")
                or not layer.process_after_load):
            return

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(layer.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize the
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                 requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the MoE kernels
        # require a single activation scale and a single w13 weight scale
        # per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.a13_scale is None or layer.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.a13_scale)
                        or not all_close_1d(layer.a2_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                     requires_grad=False)
                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                    requires_grad=False)

            # Fp8 moe kernel needs a single weight scale for w13 per expert.
            # We take the max, then dequant and requant each expert.
            assert layer.w13_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :] = per_tensor_quantize(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                                 requires_grad=False)
            return
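
    # Example of the w13 requantization above: if an expert's checkpoint
    # scales are (s1, s3) with s1 < s3, max_w13_scales picks s3; the w1 shard
    # is dequantized with s1 and requantized with s3, while the w3 shard is
    # requantized with its own scale (a no-op up to fp8 rounding), so both
    # shards end up sharing the single scale s3.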

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:
        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True,
                         use_fp8=True,
                         w1_scale=layer.w13_scale,
                         w2_scale=layer.w2_scale,
                         a1_scale=layer.a13_scale,
                         a2_scale=layer.a2_scale)

class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints."""

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka kv_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scale to 1.0 as the default value.
        # If the kv_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
        # regardless of whether the kv-scale is available in the checkpoint.
        if layer.kv_cache_dtype != "auto":
            kv_scale = layer.kv_scale.to("cpu").tolist()
            if not isinstance(kv_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")
            layer._kv_scale = kv_scale
            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure the "
                    "kv-cache scaling factor is available in the fp8 "
                    "checkpoint.")
        del layer.kv_scale

def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
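
# NOTE: despite the parameter name, "inv_scale" above is used as the
# per-tensor dequantization scale: quantization divides by it and
# dequantization multiplies by it, matching how weight_scale / input_scale
# are applied elsewhere in this file.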

def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
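
# Minimal smoke test for the per-tensor helpers above. This is an illustrative
# sketch, not part of the library's public API: it assumes the module's
# imports resolve and a torch build whose float8_e4m3fn casts work on CPU
# (torch >= 2.1). It exercises only the helper functions, not the CUDA GEMM
# paths.
if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(16, 32)

    # Pick a per-tensor scale so the largest magnitude maps to fp8's max
    # value, mirroring how a dynamic per-tensor scale is typically chosen.
    scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max

    q = per_tensor_quantize(w, scale)
    w_hat = per_tensor_dequantize(q, scale)

    # fp8 e4m3 has only 3 mantissa bits, so expect a coarse reconstruction.
    max_err = (w - w_hat.to(w.dtype)).abs().max()
    print(f"max abs reconstruction error: {max_err:.4f}")

    # all_close_1d flags per-shard scales that differ (as in the MoE path).
    print(all_close_1d(torch.tensor([1.0, 1.0, 1.0])))  # True
    print(all_close_1d(torch.tensor([1.0, 2.0, 1.0])))  # False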