# w8a8_utils.py

from typing import List, Optional, Tuple, Union

import torch

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import is_hip
from aphrodite.platforms import current_platform

# Input scaling factors are no longer optional in _scaled_mm starting
# from PyTorch 2.5, so allocate a dummy identity tensor to pass as input_scale.
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None


def cutlass_fp8_supported() -> bool:
    # CUTLASS is not supported on ROCm.
    if is_hip():
        return False
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return ops.cutlass_scaled_mm_supports_fp8(capability)
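
# For reference (illustrative note, not from the original source):
# get_device_capability() returns (major, minor), so e.g. (8, 9) on Ada packs
# to 89 and (9, 0) on Hopper packs to 90 before being passed to
# ops.cutlass_scaled_mm_supports_fp8.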


def per_tensor_dequantize(
        tensor: torch.Tensor,
        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


def convert_to_channelwise(
        weight_scale: torch.Tensor,
        logical_widths: List[int]) -> torch.Tensor:
    # Create channelwise buffer.
    weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                       dtype=torch.float32,
                                       device=weight_scale.device)

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
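
# Worked example (illustrative, not part of the original module): expanding
# two per-shard scales [0.5, 0.25] over logical widths [2, 3] yields a (5, 1)
# channelwise column of scales:
#   convert_to_channelwise(torch.tensor([0.5, 0.25]), [2, 3])
#   -> tensor([[0.5000], [0.5000], [0.2500], [0.2500], [0.2500]])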


def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the weight
    # scales are still set to the default: we initialize N weight scales
    # for N shards, but in the fused case only one weight scale is loaded
    # from disk. Skip requantization then, since the weight is already
    # quantized with that single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
        torch.float8_e4m3fn).min)

    # If the checkpoint is unfused, requantize with the single max scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :],
                                              weight_scale[idx])
            weight[start:end, :], _ = ops.scaled_fp8_quant(
                weight_dq, max_w_scale)
            start = end

    return max_w_scale, weight
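
# Worked example (illustrative, not from the original source): with two
# logical shards whose scales are [0.02, 0.05], max_w_scale is 0.05; the
# first shard is dequantized with 0.02 and requantized with 0.05 so the
# whole fused weight ends up sharing one per-tensor scale.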


def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
    use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    #   If dynamic, layer.input_scale is None and x_scale is computed from x.
    #   If static, layer.input_scale is a scalar and x_scale is input_scale.

    # cutlass_scaled_mm supports per-tensor/channel W and per-tensor/token A.
    if cutlass_fp8_supported:
        qinput, x_scale = ops.scaled_fp8_quant(
            input,
            input_scale,
            scale_ub=input_scale_ub,
            use_per_token_if_dynamic=use_per_token_if_dynamic)

        # Fused GEMM_DQ
        return ops.cutlass_scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)

    # torch._scaled_mm supports per-tensor weights + activations only,
    # so fall back to the naive path for per-channel or per-token scales.
    else:
        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with a batch dimension > 16.
        # This could change in the future.
        qinput, x_scale = ops.scaled_fp8_quant(
            input,
            input_scale,
            num_token_padding=17,
            use_per_token_if_dynamic=use_per_token_if_dynamic)

        per_tensor_weights = (weight_scale.numel() == 1)
        per_tensor_activations = (x_scale.numel() == 1)

        if per_tensor_weights and per_tensor_activations:
            # Fused GEMM_DQ
            output = torch._scaled_mm(qinput,
                                      weight,
                                      out_dtype=input.dtype,
                                      scale_a=x_scale,
                                      scale_b=weight_scale,
                                      bias=bias)
            # A fix for the discrepancy in _scaled_mm, which returns a tuple
            # for torch < 2.5 and a single tensor for torch >= 2.5.
            if type(output) is tuple and len(output) == 2:
                return torch.narrow(output[0], 0, 0, input.shape[0])
            return torch.narrow(output, 0, 0, input.shape[0])

        else:
            # Fallback for the channelwise case, where we use unfused DQ
            # due to limitations with _scaled_mm.

            # Symmetric quantized GEMM by definition computes the following:
            #   C = (s_x * X) (s_w * W) + bias
            # This is equivalent to dequantizing the weights and activations
            # before applying a GEMM.
            #
            # In order to compute quantized operands, a quantized kernel
            # will rewrite the above like so:
            #   C = s_w * s_x * (X * W) + bias
            #
            # For the _scaled_mm fallback case, we break this down, since it
            # does not support s_w being a vector.

            # Make sure the dummy identity tensor is on the same device
            # as the weight.
            global TORCH_DEVICE_IDENTITY
            if TORCH_DEVICE_IDENTITY.device != weight.device:
                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

            # GEMM
            # This computes C = (X * W).
            # Output in fp32 to allow subsequent ops to happen in-place.
            output = torch._scaled_mm(qinput,
                                      weight,
                                      scale_a=TORCH_DEVICE_IDENTITY,
                                      scale_b=TORCH_DEVICE_IDENTITY,
                                      out_dtype=torch.float32)
            # A fix for the discrepancy in _scaled_mm, which returns a tuple
            # for torch < 2.5 and a single tensor for torch >= 2.5.
            if type(output) is tuple and len(output) == 2:
                output = output[0]

            # Unpad (undo num_token_padding).
            output = torch.narrow(output, 0, 0, input.shape[0])
            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])

            # DQ
            # C = sw * sx * (X * W) + bias
            output = output * x_scale * weight_scale.t()
            if bias is not None:
                output = output + bias
            return output.to(dtype=input.dtype)
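
# Shape note for the unfused DQ path above (illustrative): with M tokens and
# N output channels, `output` is (M, N), a per-token `x_scale` is (M, 1) and a
# channelwise `weight_scale` is (N, 1), so `output * x_scale * weight_scale.t()`
# broadcasts to scale every row by its token scale and every column by its
# channel scale.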


def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
):
    # ops.scaled_int8_quant supports both dynamic and static quant.
    # * dynamic: layer.input_scale is None and x_scale is computed from x.
    # * static: layer.input_scale is a scalar and x_scale is input_scale.
    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

    return ops.cutlass_scaled_mm(x_q,
                                 weight,
                                 scale_a=x_scale,
                                 scale_b=weight_scale,
                                 out_dtype=input.dtype,
                                 bias=bias)
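
# Note (illustrative): this is the symmetric W8A8 analogue of the FP8 path
# above; cutlass_scaled_mm fuses the GEMM with dequantization, computing
# roughly C = x_scale * weight_scale * (X_q @ W_q) + bias in a single kernel.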


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128 as int8) represents zero in e4m3fn
    # but NaN in e4m3fnuz, so remap it to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bit representation, the e4m3fnuz value is half of the
    # e4m3fn value, so double the scaling factors to get the same
    # dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
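
# Bit-level example (illustrative): the byte 0x38 decodes to 1.0 in e4m3fn
# (exponent bias 7) but to 0.5 in e4m3fnuz (exponent bias 8), which is why
# the scales are doubled after the reinterpretation above.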