# w8a8_utils.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import is_hip
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform

# scaled_mm in pytorch on rocm has a bug that requires always
# providing a scaling factor for the result. This value is created
# as a global to avoid multiple tensor allocations, and can be
# removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
def cutlass_fp8_supported() -> bool:
    # cutlass is not supported on Rocm
    if is_hip():
        return False

    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]

    return ops.cutlass_scaled_mm_supports_fp8(capability)
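
# Note on cutlass_fp8_supported: get_device_capability() returns a
# (major, minor) tuple, which is folded into a single integer before the
# CUTLASS check. For example, a device reporting capability (8, 9) is
# encoded as 89.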

def per_tensor_dequantize(
        tensor: torch.Tensor,
        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

def create_per_tensor_scale_param(
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> Parameter:
    scale = Parameter(torch.empty(len(output_partition_sizes),
                                  dtype=torch.float32),
                      requires_grad=False)
    scale[:] = torch.finfo(torch.float32).min
    set_weight_attrs(scale, {
        "needs_scalar_to_array": True,
        **extra_weight_attrs
    })
    return scale


def create_per_channel_scale_param(output_partition_sizes: List[int],
                                   **extra_weight_attrs) -> Parameter:
    scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
                                  dtype=torch.float32),
                      requires_grad=False)
    scale[:] = torch.finfo(torch.float32).min
    set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
    return scale
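
# A minimal sketch of how the two helpers above differ, assuming a
# hypothetical fused projection split into three logical shards. The
# shard widths below are made up purely for illustration.
def _example_scale_param_shapes():
    output_partition_sizes = [1024, 256, 256]  # hypothetical shard widths
    per_tensor = create_per_tensor_scale_param(output_partition_sizes)
    per_channel = create_per_channel_scale_param(output_partition_sizes)
    # One scale slot per logical shard vs. one scale per output channel.
    assert per_tensor.shape == (3, )
    assert per_channel.shape == (1536, 1)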

def convert_to_channelwise(
        weight_scale: torch.Tensor,
        logical_widths: List[int]) -> torch.Tensor:
    # Create channelwise buffer
    weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                       dtype=torch.float32,
                                       device=weight_scale.device)

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
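
# For example, with weight_scale == tensor([0.5, 0.25]) and
# logical_widths == [2, 3], convert_to_channelwise returns the (5, 1)
# tensor [[0.5], [0.5], [0.25], [0.25], [0.25]], i.e. each per-tensor
# scale repeated across the rows of its logical shard.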

def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the
    # weight scales are still set to the default: we initialize
    # N weight scales for N shards, but only load 1 weight scale
    # from disk in that case. Skip requantization for such checkpoints,
    # since the weight is already quantized with a single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
        torch.float8_e4m3fn).min)

    # If unfused checkpoint, we need to requantize with the single max scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :],
                                              weight_scale[idx])
            weight[start:end, :], _ = ops.scaled_fp8_quant(
                weight_dq, max_w_scale)
            start = end

    return max_w_scale, weight
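
# Sketch of the intent: for a fused layer with logical_widths [w_q, w_k, w_v]
# and per-shard scales [s_q, s_k, s_v], each shard is dequantized with its
# own scale and re-quantized with max(s_q, s_k, s_v), so the whole fused
# weight can later be used with one per-tensor scale.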

def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
    use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    #   If dynamic, layer.input_scale is None and x_scale is computed from x.
    #   If static, layer.input_scale is a scalar and x_scale is input_scale.

    # cutlass_scaled_mm supports per-tensor/channel W and per-tensor/token A.
    if cutlass_fp8_supported:
        qinput, x_scale = ops.scaled_fp8_quant(
            input,
            input_scale,
            scale_ub=input_scale_ub,
            use_per_token_if_dynamic=use_per_token_if_dynamic)

        # Fused GEMM_DQ
        return ops.cutlass_scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)

    # torch._scaled_mm supports per-tensor weights + activations only,
    # so fall back to naive dequantization if per-channel or per-token.
    else:
        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with a batch dimension > 16.
        # This could change in the future.
        qinput, x_scale = ops.scaled_fp8_quant(
            input,
            input_scale,
            num_token_padding=17,
            use_per_token_if_dynamic=use_per_token_if_dynamic)

        per_tensor_weights = (weight_scale.numel() == 1)
        per_tensor_activations = (x_scale.numel() == 1)

        if per_tensor_weights and per_tensor_activations:
            # Fused GEMM_DQ
            output = torch._scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                scale_result=TORCH_SCALED_MM_SCALE_RESULT,
                bias=bias)
            # Since torch 2.5, scaled_mm returns a single value rather than
            # a tuple. This branch should be removed when vllm-nvidia also
            # moves to 2.5.
            if is_hip():
                return torch.narrow(output, 0, 0, input.shape[0])
            return torch.narrow(output[0], 0, 0, input.shape[0])

        else:
            # Fallback for the channelwise case, where we use unfused DQ
            # due to limitations with scaled_mm.

            # Symmetric quantized GEMM by definition computes the following:
            #   C = (s_x * X) (s_w * W) + bias
            # This is equivalent to dequantizing the weights and activations
            # before applying a GEMM.
            #
            # In order to compute quantized operands, a quantized kernel
            # will rewrite the above like so:
            #   C = s_w * s_x * (X * W) + bias
            #
            # For the scaled_mm fallback case, we break this down, since it
            # does not support s_w being a vector.

            # GEMM
            # This computes C = (X * W).
            # Output is in fp32 to allow subsequent ops to happen in-place.
            output, _ = torch._scaled_mm(qinput,
                                         weight,
                                         out_dtype=torch.float32)
            # Unpad (undo num_token_padding)
            output = torch.narrow(output, 0, 0, input.shape[0])
            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])

            # DQ
            # C = sw * sx * (X * W) + bias
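            # Note on the broadcast below: with per-token quantization
            # x_scale has shape (num_tokens, 1), and a per-channel
            # weight_scale has shape (out_features, 1), so weight_scale.t()
            # broadcasts along the columns while x_scale broadcasts along
            # the rows of the (num_tokens, out_features) output.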
            output = output * x_scale * weight_scale.t()
            if bias is not None:
                output = output + bias
            return output.to(dtype=input.dtype)
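
# A minimal usage sketch for apply_fp8_linear, assuming a CUDA device with
# FP8 support and dynamic per-tensor activation quantization. The shapes
# are hypothetical, and the weight is stored transposed (K x N), which is
# the layout cutlass_scaled_mm / torch._scaled_mm expect here.
def _example_apply_fp8_linear() -> torch.Tensor:
    x = torch.randn(4, 256, dtype=torch.float16, device="cuda")
    w = torch.randn(512, 256, dtype=torch.float16, device="cuda")
    # Quantize the weight once with a dynamic per-tensor scale.
    qweight, w_scale = ops.scaled_fp8_quant(w)
    return apply_fp8_linear(x,
                            qweight.t(),
                            w_scale,
                            cutlass_fp8_supported=cutlass_fp8_supported())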

def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
):
    # ops.scaled_int8_quant supports both dynamic and static quant.
    # * dynamic, layer.input_scale is None and x_scale is computed from x.
    # * static, layer.input_scale is a scalar and x_scale is input_scale.
    x_q, x_scale = ops.scaled_int8_quant(input, input_scale)

    return ops.cutlass_scaled_mm(x_q,
                                 weight,
                                 scale_a=x_scale,
                                 scale_b=weight_scale,
                                 out_dtype=input.dtype,
                                 bias=bias)
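
# As with the FP8 path above, passing input_scale=None selects dynamic
# activation quantization inside ops.scaled_int8_quant, while a preloaded
# scalar input_scale selects static quantization; the weight_scale may be
# per-tensor or per-channel since cutlass_scaled_mm supports both.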

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The bit pattern 10000000 (-128 as int8) represents zero in e4m3fn
    # but NaN in e4m3fnuz, so we reset those elements to 0.
    # https://onnx.ai/onnx/technical/float8.html
    weight_as_int8 = weight.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)

    # For the same bit representation, the e4m3fnuz value is half of the
    # e4m3fn value, so we must double the scaling factor to get the same
    # dequantized value.
    # https://onnx.ai/onnx/technical/float8.html
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
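
# Worked example of the factor of 2: the byte 0x38 decodes to 1.0 in
# e4m3fn (exponent bias 7) but to 0.5 in e4m3fnuz (exponent bias 8), so
# doubling the scale keeps scale * value, and therefore the dequantized
# weight, unchanged.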