  1. """Tests whether FP8 computation is enabled correctly.
  2. Run `pytest tests/quantization/test_fp8.py --forked`.
  3. """

import pytest
import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform
from aphrodite.quantization.fp8 import Fp8KVCacheMethod, Fp8LinearMethod
from tests.quantization.utils import is_quant_method_supported

MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
def test_model_load_and_run(aphrodite_runner, model_id: str, force_marlin: bool,
                            monkeypatch) -> None:
    if force_marlin:
        monkeypatch.setenv("APHRODITE_TEST_FORCE_FP8_MARLIN", "1")

    with aphrodite_runner(model_id) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


KV_CACHE_MODELS = [
    # Deprecated AutoFP8 format using .kv_scale
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # AutoFP8 format using separate .k_scale and .v_scale
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
]
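
# Illustration only (hypothetical checkpoint key names, assuming a Llama-style
# module tree): the deprecated format ships a single fused scale per layer,
# e.g. `model.layers.0.self_attn.kv_scale`, while the newer AutoFP8 format
# ships separate `...self_attn.k_scale` and `...self_attn.v_scale` entries.
# Either way, the loaded values surface as `attn._k_scale` / `attn._v_scale`,
# which the test below checks.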


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(aphrodite_runner, model_id: str):
    with aphrodite_runner(model_id, kv_cache_dtype="fp8") as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501

        attn = model.model.layers[0].self_attn.attn
        assert isinstance(attn.quant_method, Fp8KVCacheMethod)

        # NOTE: it is valid for scales to be 1.0 (default value), but we know
        # these checkpoints have scales < 1.0
        assert 0.0 < attn._k_scale < 1.0
        assert 0.0 < attn._v_scale < 1.0

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_model(aphrodite_runner, kv_cache_dtype: str,
                         force_marlin: bool,
                         monkeypatch) -> None:
    if force_marlin:
        monkeypatch.setenv("APHRODITE_TEST_FORCE_FP8_MARLIN", "1")

    with aphrodite_runner("facebook/opt-125m",
                          quantization="fp8",
                          kv_cache_dtype=kv_cache_dtype) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501

        fc1 = model.model.decoder.layers[0].fc1
        assert isinstance(fc1.quant_method, Fp8LinearMethod)

        if kv_cache_dtype == "fp8":
            attn = model.model.decoder.layers[0].self_attn.attn
            assert isinstance(attn.quant_method, Fp8KVCacheMethod)
            assert attn._k_scale == 1.0
            assert attn._v_scale == 1.0

        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability >= 89 and not force_marlin:
            # For GPUs with hardware support, we keep weights in fp8
            assert fc1.weight.dtype == torch.float8_e4m3fn
        else:
            # For GPUs without hardware support, we pack the fp8 weights
            # for weight-only quantization using Marlin kernels
            assert fc1.weight.dtype == torch.int32
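

# A minimal sketch of the compute-capability arithmetic used above
# (illustrative numbers only, not tied to any specific CI machine): the
# (major, minor) tuple is folded into a two-digit integer, so (8, 9) becomes
# 89. GPUs at capability 89 or above have native FP8 support and keep
# float8_e4m3fn weights; older GPUs take the Marlin weight-only path and
# store packed int32 weights instead.
#
#   major, minor = 8, 9              # e.g. a compute-capability-8.9 GPU
#   assert major * 10 + minor >= 89  # -> native fp8 branch in the test above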


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                           max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization
    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
    ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
    # Reference dynamic quantization
    y = quantize_ref(x, inv_scale)
    torch.testing.assert_close(ref_y,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Static quantization
    y, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(ref_y,
                               per_tensor_dequantize(y, inv_scale, dtype))

    # Padding
    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert y.shape[0] == 17
    torch.testing.assert_close(
        ref_y,
        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
                              dtype))
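

# A minimal sketch (not exercised by the tests above) of the dynamic
# per-tensor round trip that `ops.scaled_fp8_quant` is compared against,
# written in plain PyTorch. `_fp8_round_trip_sketch` is a hypothetical helper
# added for illustration; it is not part of the test suite or the kernel API.
def _fp8_round_trip_sketch(x: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Dynamic per-tensor scale: map the tensor's absolute maximum onto the
    # fp8 e4m3 representable range (this plays the role of `inv_scale` in
    # the test above).
    inv_scale = x.abs().max().to(torch.float32) / finfo.max
    # Quantize: divide by the scale, clamp to the fp8 range, then cast.
    qweight = (x.to(torch.float32) / inv_scale).clamp(min=finfo.min,
                                                      max=finfo.max)
    qweight = qweight.to(torch.float8_e4m3fn)
    # Dequantize back for comparison, mirroring `per_tensor_dequantize`.
    return qweight.to(x.dtype) * inv_scale.to(x.dtype)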