quant_utils.py

from typing import Optional, Tuple, Union

import torch

from aphrodite.common.utils import is_hip

# Using the default value (240.0) from pytorch will cause accuracy
# issues on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn


def as_float32_tensor(x: Union[float, torch.Tensor]) -> torch.Tensor:
    return torch.as_tensor(x, dtype=torch.float32, device='cuda')


def ref_dynamic_per_token_quant(x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                scale_ub: Optional[torch.Tensor] = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:

    assert quant_dtype in [torch.int8, FP8_DTYPE]
    if scale_ub is not None:
        assert quant_dtype == FP8_DTYPE

    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
        else torch.finfo(quant_dtype)
    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
    qtype_max = as_float32_tensor(qtype_traits_max)
    s_1 = as_float32_tensor(1.0)
    s_512 = as_float32_tensor(512.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    # Compute per-token scales from each token's absolute maximum.
    x_token_max, _ = x.abs().max(dim=-1)
    x_token_max = as_float32_tensor(x_token_max)
    if scale_ub is not None:
        x_token_max = x_token_max.clamp(max=scale_ub)
    scales = (x_token_max / qtype_max)[:, None]

    # Quant
    if quant_dtype == torch.int8:
        iscales = as_float32_tensor(s_1 / scales)
        torch_out = as_float32_tensor(x) * iscales
        torch_out = torch_out.round()
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)
    else:
        assert quant_dtype == FP8_DTYPE
        min_scaling_factor = s_1 / (qtype_max * s_512)
        scales = scales.clamp(min=min_scaling_factor)
        torch_out = as_float32_tensor(x) / scales
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)

    return torch_out, scales
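

# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of how the per-token reference above might be exercised:
# quantize a random activation tensor, dequantize it with the returned scales,
# and check the round-trip error. The helper name `_example_int8_roundtrip` is
# hypothetical, and a CUDA/ROCm device is assumed since as_float32_tensor
# allocates on 'cuda'.
def _example_int8_roundtrip(num_tokens: int = 4, hidden_size: int = 16) -> None:
    x = torch.randn(num_tokens, hidden_size,
                    dtype=torch.float32, device='cuda')
    q, scales = ref_dynamic_per_token_quant(x, torch.int8)
    # Dequantize: the per-token scale broadcasts over the hidden dimension.
    x_dequant = q.to(torch.float32) * scales
    # Rounding error per element is at most half a quantization step, so the
    # largest per-token scale is a comfortable absolute tolerance.
    assert torch.allclose(x, x_dequant, atol=scales.max().item())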

# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel.
def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:

    fp8_traits = torch.finfo(FP8_DTYPE)
    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
    fp8_max = as_float32_tensor(fp8_traits_max)
    one = as_float32_tensor(1.0)

    # For fp8, in order to match the cuda kernel output, we have to do exactly
    # the same operations as in the corresponding fp8 kernel to prevent
    # rounding errors.

    x_max = as_float32_tensor(x.abs().max())
    ref_scale = x_max / fp8_max
    ref_iscale = one / ref_scale
    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return ref_out, ref_scale.view((1, ))
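

# --- Illustrative smoke test (not part of the original file) ----------------
# A minimal, hedged sketch showing both reference implementations side by
# side: per-token quantization returns one scale per row, per-tensor
# quantization returns a single scale. A CUDA/ROCm device is assumed, and the
# block is guarded so importing this module stays side-effect free.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 32, dtype=torch.float32, device='cuda')

    # Per-token fp8 quantization: scales has shape [num_tokens, 1].
    fp8_out, fp8_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE)
    print("per-token fp8:", fp8_out.shape, fp8_scales.shape)

    # Per-tensor fp8 quantization: a single scale for the whole tensor.
    tensor_out, tensor_scale = ref_dynamic_per_tensor_fp8_quant(x)
    print("per-tensor fp8:", tensor_out.shape, tensor_scale.shape)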