1
0

w8a8_utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. from typing import List, Optional, Tuple, Union
  2. import torch
  3. from torch.nn import Parameter
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.modeling.utils import set_weight_attrs
  6. from aphrodite.platforms import current_platform
  7. def cutlass_fp8_supported() -> bool:
  8. capability = current_platform.get_device_capability()
  9. capability = capability[0] * 10 + capability[1]
  10. return ops.cutlass_scaled_mm_supports_fp8(capability)
  11. def per_tensor_dequantize(
  12. tensor: torch.Tensor, inv_scale: Union[float,
  13. torch.Tensor]) -> torch.Tensor:
  14. fake_qweight = tensor.to(torch.float16)
  15. dq_weight = fake_qweight * inv_scale
  16. return dq_weight
  17. def all_close_1d(x: torch.Tensor) -> bool:
  18. assert len(x.shape) == 1
  19. return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
  20. def create_per_tensor_scale_param(
  21. output_partition_sizes: List[int],
  22. **extra_weight_attrs,
  23. ) -> Parameter:
  24. scale = Parameter(torch.empty(len(output_partition_sizes),
  25. dtype=torch.float32),
  26. requires_grad=False)
  27. scale[:] = torch.finfo(torch.float32).min
  28. set_weight_attrs(scale, {
  29. "needs_scalar_to_array": True,
  30. **extra_weight_attrs
  31. })
  32. return scale
  33. def create_per_channel_scale_param(output_partition_sizes: List[int],
  34. **extra_weight_attrs) -> Parameter:
  35. scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
  36. dtype=torch.float32),
  37. requires_grad=False)
  38. scale[:] = torch.finfo(torch.float32).min
  39. set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
  40. return scale
  41. def convert_to_channelwise(
  42. weight_scale: torch.Tensor,
  43. logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
  44. # Create channelwise buffer
  45. weight_scale_channel = torch.empty((sum(logical_widths), 1),
  46. dtype=torch.float32,
  47. device=weight_scale.device)
  48. # Expand each scale to match the size of each logical matrix.
  49. start = 0
  50. for idx, logical_width in enumerate(logical_widths):
  51. end = start + logical_width
  52. weight_scale_channel[start:end, :] = weight_scale[idx]
  53. start = end
  54. return weight_scale_channel
  55. def requantize_with_max_scale(
  56. weight: torch.Tensor, weight_scale: torch.Tensor,
  57. logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
  58. # Max scale to be used for requanitzation.
  59. max_w_scale = weight_scale.max()
  60. # QKV / MLP is fused in the on disk checkpoint if any of the
  61. # weight scales are still set to the default since we initialize
  62. # N weight scales for N shards but we only load 1 weight scale
  63. # from disk in this case. Skip requantization in this case (since)
  64. # we already are quantized with the single scale.
  65. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
  66. unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
  67. torch.float8_e4m3fn).min)
  68. # If unfused checkpoint, need requanize with the single scale.
  69. if unfused_module_in_checkpoint:
  70. start = 0
  71. for idx, logical_width in enumerate(logical_widths):
  72. end = start + logical_width
  73. weight_dq = per_tensor_dequantize(weight[start:end, :],
  74. weight_scale[idx])
  75. weight[start:end, :], _ = ops.scaled_fp8_quant(
  76. weight_dq, max_w_scale)
  77. start = end
  78. return max_w_scale, weight
  79. def apply_fp8_linear(
  80. input: torch.Tensor,
  81. weight: torch.Tensor,
  82. weight_scale: torch.Tensor,
  83. input_scale: Optional[torch.Tensor] = None,
  84. input_scale_ub: Optional[torch.Tensor] = None,
  85. bias: Optional[torch.Tensor] = None,
  86. cutlass_fp8_supported: bool = True,
  87. use_per_token_if_dynamic: bool = False,
  88. ) -> torch.Tensor:
  89. # ops.scaled_fp8_quant supports both dynamic and static quant.
  90. # If dynamic, layer.input_scale is None and x_scale computed from x.
  91. # If static, layer.input_scale is scalar and x_scale is input_scale.
  92. # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
  93. if cutlass_fp8_supported:
  94. qinput, x_scale = ops.scaled_fp8_quant(
  95. input,
  96. input_scale,
  97. scale_ub=input_scale_ub,
  98. use_per_token_if_dynamic=use_per_token_if_dynamic)
  99. # Fused GEMM_DQ
  100. return ops.cutlass_scaled_mm(qinput,
  101. weight,
  102. out_dtype=input.dtype,
  103. scale_a=x_scale,
  104. scale_b=weight_scale,
  105. bias=bias)
  106. # torch.scaled_mm supports per tensor weights + activations only
  107. # so fallback to naive if per channel or per token
  108. else:
  109. # Note: we pad the input because torch._scaled_mm is more performant
  110. # for matrices with batch dimension > 16.
  111. # This could change in the future.
  112. qinput, x_scale = ops.scaled_fp8_quant(
  113. input,
  114. input_scale,
  115. num_token_padding=17,
  116. use_per_token_if_dynamic=use_per_token_if_dynamic)
  117. per_tensor_weights = (weight_scale.numel() == 1)
  118. per_tensor_activations = (x_scale.numel() == 1)
  119. if per_tensor_weights and per_tensor_activations:
  120. # Fused GEMM_DQ
  121. output, _ = torch._scaled_mm(qinput,
  122. weight,
  123. out_dtype=input.dtype,
  124. scale_a=x_scale,
  125. scale_b=weight_scale,
  126. bias=bias)
  127. return torch.narrow(output, 0, 0, input.shape[0])
  128. else:
  129. # Fallback for channelwise case, where we use unfused DQ
  130. # due to limitations with scaled_mm
  131. # Symmetric quantized GEMM by definition computes the following:
  132. # C = (s_x * X) (s_w * W) + bias
  133. # This is equivalent to dequantizing the weights and activations
  134. # before applying a GEMM.
  135. #
  136. # In order to compute quantized operands, a quantized kernel
  137. # will rewrite the above like so:
  138. # C = s_w * s_x * (X * W) + bias
  139. #
  140. # For the scaled_mm fallback case, we break this down, since it
  141. # does not support s_w being a vector.
  142. # GEMM
  143. # This computes C = (X * W).
  144. # Output in fp32 to allow subsequent ops to happen in-place
  145. output, _ = torch._scaled_mm(qinput,
  146. weight,
  147. out_dtype=torch.float32)
  148. # Unpad (undo num_token_padding)
  149. output = torch.narrow(output, 0, 0, input.shape[0])
  150. x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
  151. # DQ
  152. # C = sw * sx * (X * W) + bias
  153. output = output * x_scale * weight_scale.t()
  154. if bias is not None:
  155. output = output + bias
  156. return output.to(dtype=input.dtype)
  157. def apply_int8_linear(
  158. input: torch.Tensor,
  159. weight: torch.Tensor,
  160. weight_scale: torch.Tensor,
  161. input_scale: Optional[torch.Tensor] = None,
  162. bias: Optional[torch.Tensor] = None,
  163. ):
  164. # ops.scaled_int8_quant supports both dynamic and static quant.
  165. # * dynamic, layer.input_scale is None and x_scale computed from x.
  166. # * static, layer.input_scale is scalar and x_scale is input_scale.
  167. x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
  168. return ops.cutlass_scaled_mm(x_q,
  169. weight,
  170. scale_a=x_scale,
  171. scale_b=weight_scale,
  172. out_dtype=input.dtype,
  173. bias=bias)