test_fp8_quant.py
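"""Tests for the FP8 quantization ops in aphrodite._custom_ops.

Dynamic per-token and per-tensor quantization are checked against the Python
reference implementations from tests.kernels.quant_utils, plus a regression
test for very large activation tensors.
"""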

import pytest
import torch

import aphrodite._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
SCALE_UBS = [True, False]
SEEDS = [0]
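

# For orientation only: a minimal sketch of what the dynamic FP8 quantization
# paths compute. This is an assumption about the imported reference helpers,
# not a copy of them; the `_sketch_*` names below are illustrative and are not
# used by the tests.
def _sketch_dynamic_per_tensor_fp8_quant(x: torch.Tensor):
    fp8_max = torch.finfo(FP8_DTYPE).max
    # One float32 scale for the whole tensor, from its absolute maximum.
    scale = x.abs().max().to(torch.float32).clamp(min=1e-12) / fp8_max
    q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max)
    return q.to(FP8_DTYPE), scale


def _sketch_dynamic_per_token_fp8_quant(x: torch.Tensor, scale_ub=None):
    fp8_max = torch.finfo(FP8_DTYPE).max
    # One scale per token (row), from each row's absolute maximum; scale_ub,
    # when given, is assumed to cap that maximum before the scale is formed.
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    if scale_ub is not None:
        amax = torch.minimum(amax, scale_ub)
    scale = amax.clamp(min=1e-12) / fp8_max
    q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max)
    return q.to(FP8_DTYPE), scale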


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                     dtype: torch.dtype, scale_ub: bool,
                                     seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") + 1e-6  # avoid nans

    scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
        if scale_ub else None

    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
    ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                               scale_ub=scale_ub,
                                               use_per_token_if_dynamic=True)

    torch.testing.assert_close(ref_scales, ops_scales)
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

    ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
    ops_out, ops_scale = ops.scaled_fp8_quant(x)

    torch.testing.assert_close(ref_scale, ops_scale)
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))


# Regression test for a case with large activations where an int32 index
# cannot represent the number of elements.
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
    hidden_size = 1152  # Smallest hidden_size to reproduce the error
    dtype = torch.bfloat16

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
    ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
    ops_out, _ = ops.scaled_fp8_quant(x, scale)

    # Minimize memory footprint in this test by freeing x and upconverting
    # the outputs in place. (torch.allclose does not support fp8.)
    del x
    ref_out = ref_out.to(dtype=dtype)
    ops_out = ops_out.to(dtype=dtype)

    torch.testing.assert_close(ref_out, ops_out)