fp8.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn import Module
  5. from torch.nn.parameter import Parameter
  6. import aphrodite.common.envs as envs
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.common.utils import is_hip, print_warning_once
  9. from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
  10. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  11. UnquantizedLinearMethod)
  12. from aphrodite.modeling.parameter import (ModelWeightParameter,
  13. PerTensorScaleParameter)
  14. from aphrodite.modeling.utils import set_weight_attrs
  15. from aphrodite.platforms import current_platform
  16. from aphrodite.quantization.base_config import (QuantizationConfig,
  17. QuantizeMethodBase)
  18. from aphrodite.quantization.kv_cache import BaseKVCacheMethod
  19. from aphrodite.quantization.utils.marlin_utils_fp8 import (
  20. apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
  21. from aphrodite.quantization.utils.quant_utils import is_layer_skipped
  22. from aphrodite.quantization.utils.w8a8_utils import (
  23. all_close_1d, apply_fp8_linear, convert_to_channelwise,
  24. cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
  25. requantize_with_max_scale)
# Activation-scaling schemes accepted by Fp8Config; any other value makes
# Fp8Config.__init__ raise ValueError.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Captured once at import time. When truthy, Fp8LinearMethod selects the
# Marlin weight-only path even on GPUs with capability >= 89 (ROCm still
# disables Marlin unconditionally).
APHRODITE_TEST_FORCE_FP8_MARLIN = envs.APHRODITE_TEST_FORCE_FP8_MARLIN
  28. class Fp8Config(QuantizationConfig):
  29. """Config class for FP8."""
  30. def __init__(
  31. self,
  32. is_checkpoint_fp8_serialized: bool = False,
  33. activation_scheme: str = "dynamic",
  34. ignored_layers: Optional[List[str]] = None,
  35. ) -> None:
  36. self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
  37. if is_checkpoint_fp8_serialized:
  38. logger.warning("Detected fp8 checkpoint. Please note that the "
  39. "format is experimental and subject to change.")
  40. if activation_scheme not in ACTIVATION_SCHEMES:
  41. raise ValueError(
  42. f"Unsupported activation scheme {activation_scheme}")
  43. self.activation_scheme = activation_scheme
  44. self.ignored_layers = ignored_layers or []
  45. @classmethod
  46. def get_name(cls) -> str:
  47. return "fp8"
  48. @classmethod
  49. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  50. return [torch.bfloat16, torch.half]
  51. @classmethod
  52. def get_min_capability(cls) -> int:
  53. return 80
  54. @classmethod
  55. def get_config_filenames(cls) -> List[str]:
  56. return []
  57. @classmethod
  58. def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
  59. quant_method = cls.get_from_keys(config, ["quant_method"])
  60. is_checkpoint_fp8_serialized = ("fp8" in quant_method)
  61. activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
  62. ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
  63. return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
  64. activation_scheme=activation_scheme,
  65. ignored_layers=ignored_layers)
  66. def get_quant_method(self, layer: torch.nn.Module,
  67. prefix: str) -> Optional["QuantizeMethodBase"]:
  68. from aphrodite.attention.layer import (
  69. Attention) # Avoid circular import
  70. if isinstance(layer, LinearBase):
  71. if is_layer_skipped(prefix, self.ignored_layers):
  72. return UnquantizedLinearMethod()
  73. return Fp8LinearMethod(self)
  74. elif isinstance(layer, FusedMoE):
  75. return Fp8MoEMethod(self)
  76. elif isinstance(layer, Attention):
  77. return Fp8KVCacheMethod(self)
  78. return None
  79. def get_scaled_act_names(self) -> List[str]:
  80. return []
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with a static weight scale and a
    dynamic or static activation scale.

    Also supports loading FP16/BF16 checkpoints with dynamic activation
    scaling. In that case the weights are quantized, and the weight scale
    initialized, only after the model weights are loaded
    (see ``process_weights_after_loading``).

    Two execution paths are selected in ``__init__``:
      * w8a8: native FP8 matmul via ``apply_fp8_linear``. Weights end up with
        a single per-tensor scale.
      * Marlin (w8a16): weight-only FP8 used on devices with capability < 89
        or when APHRODITE_TEST_FORCE_FP8_MARLIN is set. Weight scales are
        expanded to channelwise, and activations are not quantized. Marlin
        is never used on ROCm.

    Limitations:
    1. The w8a8 path supports only per-tensor quantization, because of
       torch._scaled_mm.
    2. Only the float8_e4m3fn data type is supported, because of the
       limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config) -> None:
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        capability = current_platform.get_device_capability()
        # (major, minor) -> single integer, e.g. (8, 9) -> 89.
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89 or APHRODITE_TEST_FORCE_FP8_MARLIN
        # Disable marlin for rocm
        if is_hip():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        The weight has shape (sum(output_partition_sizes),
        input_size_per_partition). Its dtype is float8_e4m3fn for serialized
        FP8 checkpoints and ``params_dtype`` otherwise.

        For serialized FP8 checkpoints, ``weight_scale`` holds one float32
        entry per logical output shard. ``input_scale`` is registered with
        the same shape when the activation scheme is "static"; otherwise it
        is registered as ``None``.

        ``input_size`` and ``output_size`` are unused.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Per-shard output widths of a (possibly fused) module. Used later to
        # expand/requantize per-shard scales.
        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)

            # NOTE(review): filled with the most negative float32, presumably
            # as a sentinel so a shard the loader never writes cannot win a
            # later max() over scales. Confirm intent.
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)

                # Same sentinel fill as weight_scale above.
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded weights into the layout the chosen kernel needs.

        * Non-FP8 checkpoint: quantize the whole weight to FP8 with a single
          scale from ``ops.scaled_fp8_quant`` (called with ``scale=None``).
          For Marlin, expand that scale to channelwise. ``input_scale`` is
          set to ``None``.
        * FP8 checkpoint: there is one weight scale per logical shard.
          - Marlin: expand the scales to channelwise.
          - w8a8: normalize to e4m3fnuz on ROCm, then requantize all shards
            with the max scale so a single per-tensor scale suffices.
          Static input scales are collapsed to their max.

        In every case the stored weight is transposed (``.t()``). On the
        Marlin path the layer is then passed to
        ``prepare_fp8_layer_for_marlin`` and ``input_scale`` is removed.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if is_hip():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                # Kernels take a single activation scale: use the largest.
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` (plus ``bias`` if given).

        Dispatches to the Marlin weight-only kernel or to
        ``apply_fp8_linear``, depending on ``self.use_marlin``.
        """

        if self.use_marlin:
            # NOTE(review): layer.workspace is not created in this class;
            # presumably set by prepare_fp8_layer_for_marlin. Confirm.
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with a static weight scale and a
    dynamic or static activation scale.

    Also supports loading FP16/BF16 checkpoints with dynamic activation
    scaling. In that case the experts are quantized, and their weight scales
    initialized, only after the model weights are loaded.

    Unlike ``Fp8LinearMethod``, a "static" activation scheme on a checkpoint
    that is not FP8-serialized raises ``ValueError`` in ``create_weights``.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config) -> None:
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs) -> None:
        """Register expert weights, weight scales and (static) input scales.

        Shapes:
          * w13_weight: (num_experts, 2 * intermediate_size, hidden_size).
            w1 and w3 are stacked along dim 1.
          * w2_weight: (num_experts, hidden_size, intermediate_size).
          * w13_weight_scale: (num_experts, 2), one scale each for w1 and w3.
          * w2_weight_scale: (num_experts,).
          * w13_input_scale / w2_input_scale: (num_experts,) for the "static"
            scheme; otherwise both are set to ``None``.

        Weights use float8_e4m3fn for FP8-serialized checkpoints and
        ``params_dtype`` otherwise. Scales start at 1.0.

        Raises:
            ValueError: if the activation scheme is "static" but the
                checkpoint is not FP8-serialized.
        """

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })
            set_weight_attrs(w2_weight_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, {
                "is_fp8_scale": True,
                **extra_weight_attrs
            })

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring the expert weights/scales into the form the fused kernel uses.

        * Non-FP8 checkpoint: quantize each expert's w13 and w2 separately
          (float8_e4m3fnuz on ROCm, float8_e4m3fn otherwise). Each expert
          gets one w13 scale and one w2 scale.
        * FP8 checkpoint:
          - Static input scales are collapsed to their max across experts.
          - On ROCm, weights and scales are normalized to e4m3fnuz.
          - Each expert's w1/w3 shards are requantized with their max scale,
            so w13_weight_scale becomes shape (num_experts,).

        Raises:
            ValueError: if the scheme is "static" but an input scale is None.
        """

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz \
                        if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            # w2_weight_scale keeps the Parameter created in create_weights
            # and is overwritten entry by entry below. No scale is passed to
            # scaled_fp8_quant, so it presumably derives one from the data.
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. ")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)
            # If rocm, normalize the weights and scales to e4m3fnuz
            if is_hip():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                           requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            # NOTE(review): assumes intermediate_size_per_partition (set
            # outside this class) equals the per-shard row count of w13,
            # i.e. half of w13_weight.shape[1]. TODO confirm.
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                # shard_id 0 -> w1 rows, 1 -> w3 rows of the stacked w13.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    # Writes the requantized rows back into the same slice.
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
            return

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool,
              use_grouped_topk: bool,
              topk_group: Optional[int] = None,
              num_expert_group: Optional[int] = None) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the fused FP8 MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``. The
        fused kernel gets the w13/w2 weight scales and the (possibly None)
        input scales prepared by ``process_weights_after_loading``.
        """

        # Imported lazily, as in the original code.
        from aphrodite.modeling.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group)

        # NOTE(review): inplace=True presumably lets fused_experts overwrite
        # x's storage with the output. Callers should not rely on x after
        # this call. Confirm against fused_experts.
        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)
  442. class Fp8KVCacheMethod(BaseKVCacheMethod):
  443. """
  444. Supports loading kv-cache scaling factors from FP8 checkpoints.
  445. """
  446. def __init__(self, quant_config: Fp8Config):
  447. super().__init__(quant_config)