test_int8_quant.py

import pytest
import torch

from aphrodite._custom_ops import scaled_int8_quant
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


def opcheck_int8_quant_static(output, input, scale, azp=None):
    if azp is None:
        opcheck(torch.ops._C.static_scaled_int8_quant,
                (output, input, scale, None))
    else:
        opcheck(torch.ops._C.static_scaled_int8_quant,
                (output, input, scale, azp))


def opcheck_int8_quant_dynamic(output, input, symmetric=True):
    scale = torch.empty((input.numel() // input.shape[-1], 1),
                        device=input.device,
                        dtype=torch.float32)
    if symmetric:
        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
                (output, input, scale, None))
    else:
        azp = torch.empty((input.numel() // input.shape[-1], 1),
                          device=input.device,
                          dtype=torch.int32)
        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
                (output, input, scale, azp))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                   dtype: torch.dtype, seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

    # reference
    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
    # kernel
    ops_out, ops_scales, _ = scaled_int8_quant(x)

    torch.testing.assert_close(ops_scales, ref_scales)
    # big atol to account for rounding errors
    torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)

    opcheck_int8_quant_dynamic(ops_out, x)
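

# Note on the wrapper used above: scaled_int8_quant returns a 3-tuple; these
# tests only consume the quantized output and the per-token scales, discarding
# the third element (presumably the azp, which is unused for symmetric
# quantization).  A minimal usage sketch, assuming a CUDA device is available:
#
#     x = torch.rand(4, 16, dtype=torch.half, device="cuda")
#     q, scales, _ = scaled_int8_quant(x)         # dynamic, per-token scales
#     scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
#     q2, _, _ = scaled_int8_quant(x, scale)      # static, caller-provided scale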


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                       dtype: torch.dtype, seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") * 1000 - 300

    x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
    x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)

    # calculate scale and azp, and adjust the range
    scales = (x_token_max - x_token_min) / torch.tensor(255.0)
    azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
        torch.int32)

    torch_out = ((x / scales).round() + azps).clamp(
        int8_traits.min, int8_traits.max).to(torch.int8)
    assert torch_out.min() >= int8_traits.min
    assert torch_out.max() <= int8_traits.max

    ops_out = torch.empty_like(x, dtype=torch.int8)
    scales_out = torch.empty_like(scales, dtype=torch.float32)
    azp_out = torch.empty_like(azps, dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)

    if not torch.allclose(scales_out, scales):
        print(torch.argmax(torch.abs(scales_out - scales)))
    torch.testing.assert_close(scales_out, scales)
    # big atol to account for rounding errors
    torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
    # if AZP is off by 1, after rounding-to-even, the output may be off by 2
    torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)

    opcheck_int8_quant_dynamic(ops_out, x, False)
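

# Why the scale/azp formulas above work: asymmetric quantization maps the
# per-token range [x_min, x_max] affinely onto the int8 range [-128, 127],
# so scale = (x_max - x_min) / 255 and the zero point is chosen so that x_min
# lands on -128: azp = round(-128 - x_min / scale).  Dequantizing with
# (q - azp) * scale then recovers the input up to rounding error.  A minimal
# round-trip sketch in pure PyTorch (illustrative only, not the CUDA kernel
# under test):
def _sketch_azp_round_trip(x: torch.Tensor):
    int8_traits = torch.iinfo(torch.int8)
    x_f32 = x.to(torch.float32)
    x_max = x_f32.max(dim=-1, keepdim=True).values
    x_min = x_f32.min(dim=-1, keepdim=True).values
    scales = (x_max - x_min) / 255.0
    azps = torch.round(int8_traits.min - x_min / scales).to(torch.int32)
    q = ((x_f32 / scales).round() + azps).clamp(
        int8_traits.min, int8_traits.max).to(torch.int8)
    x_dequant = (q.to(torch.float32) - azps) * scales
    return q, x_dequant  # x_dequant matches x to within about scale / 2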


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                  dtype: torch.dtype, seed: int,
                                  scale: float) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")

    out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                         int8_traits.max).to(torch.int8)
    out2, _, _ = scaled_int8_quant(x, scale_arg)

    # big atol to account for rounding errors
    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)

    opcheck_int8_quant_static(out2, x, scale_arg)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE[2:])  # Reduce test time
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int,
                                      scale: float, azp: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") * 1000 - 300

    out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                             int8_traits.max).to(torch.int8)
    out2 = torch.empty_like(x, dtype=torch.int8)
    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
    azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")

    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)

    # big atol to account for rounding errors
    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)

    opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
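

# Note on the azp parametrization above: -255 lies outside the int8 range on
# purpose.  The kernel is assumed to add the zero point in a wider integer
# type before the final saturating cast to int8 (see the comment inside the
# saturating-cast test below), so an int32 azp outside [-128, 127] is still a
# valid input.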


@pytest.mark.parametrize("is_max", [True, False])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
    # Test that the saturating cast works correctly for values near i32 max/min
    from numpy import inf, nextafter

    int32_traits = torch.iinfo(torch.int32)
    val = float(int32_traits.max if is_max else int32_traits.min)

    x_vals = [[
        nextafter(val, inf), val + 1, val, val - 1,
        nextafter(val, -inf)
    ]]
    x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")

    # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
    # where cast<T> is a saturating cast to type T.
    # Scale is set to 1.0 so that the input values are the ones that are cast.
    # AZP is set to 0 to make sure the int8 saturating cast is tested as well.
    scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
    azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")

    int8_traits = torch.iinfo(torch.int8)
    val_i8 = int8_traits.max if is_max else int8_traits.min
    expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")

    out = torch.empty_like(expected)
    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
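

# For reference, the behaviour this test expects from the kernel can be
# emulated in pure PyTorch by clamping at each cast boundary.  This is an
# illustrative sketch only; it glosses over the exact rounding/truncation of
# the float-to-int32 conversion, which does not matter for the boundary
# values used above.
def _sketch_saturating_cast(x: torch.Tensor, scale: float, azp: int):
    int32_traits = torch.iinfo(torch.int32)
    int8_traits = torch.iinfo(torch.int8)
    # cast<int32>(x / scale), saturating instead of overflowing
    as_i32 = (x / scale).clamp(int32_traits.min, int32_traits.max)
    # cast<int8>(... + azp), again saturating
    return (as_i32 + azp).clamp(int8_traits.min,
                                int8_traits.max).to(torch.int8)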