test_fp8.py

# flake8: noqa
"""Tests fp8 models against ground truth generation
Note: these tests will only pass on L4 GPU.
"""
import os
from typing import Optional

import pytest

from tests.kernels.utils import override_backend_env_variable
from tests.quantization.utils import is_quant_method_supported
from ...utils import check_logprobs_close

os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize(
    "kv_cache_dtype,base_model,test_model,scale_path",
    [
        # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
        ("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct",
         "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None),
        # Test FP16 checkpoint w. fp8_e5m2 kv-cache.
        ("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct",
         "meta-llama/Meta-Llama-3-8B-Instruct", None),
        # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
        ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
         "meta-llama/Llama-2-7b-chat-hf",
         "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
    ])
# Due to low-precision numerical divergence, we only check the logprobs of
# 4 tokens.
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
# NOTE: Increasing the tensor parallel size here will fail CI because we
# currently cannot reset the distributed env properly; use a value > 1 only
# for local testing.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# run with the async postprocessor.
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models(
    aphrodite_runner,
    example_prompts,
    kv_cache_dtype: str,
    base_model: str,
    test_model: str,
    scale_path: Optional[str],
    max_tokens: int,
    enforce_eager: bool,
    backend: str,
    tensor_parallel_size: int,
    disable_async_output_proc: bool,
    monkeypatch,
) -> None:
    """
    Only checks that the log probs match, to cover discrepancies in
    numerically sensitive kernels.
    """
    override_backend_env_variable(monkeypatch, backend)

    MAX_MODEL_LEN = 1024
    NUM_LOG_PROBS = 8
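    # Baseline: run the reference checkpoint with the default ("auto")
    # KV-cache dtype and collect greedy logprobs as ground truth.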
    with aphrodite_runner(
            base_model,
            max_model_len=MAX_MODEL_LEN,
            tensor_parallel_size=tensor_parallel_size,
            enforce_eager=enforce_eager,
            kv_cache_dtype="auto",
            disable_async_output_proc=disable_async_output_proc,
    ) as aphrodite_model:
        baseline_outputs = aphrodite_model.generate_greedy_logprobs(
            example_prompts, max_tokens, NUM_LOG_PROBS)
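    # Forward the external KV-cache scaling factors only when the
    # parametrization provides a scale file.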
    extra_kwargs = {}
    if scale_path is not None:
        extra_kwargs["quantization_param_path"] = scale_path
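    # Test run: same prompts and decoding settings, but with the fp8 KV-cache
    # dtype under test.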
    with aphrodite_runner(
            test_model,
            max_model_len=MAX_MODEL_LEN,
            tensor_parallel_size=tensor_parallel_size,
            enforce_eager=enforce_eager,
            kv_cache_dtype=kv_cache_dtype,
            disable_async_output_proc=disable_async_output_proc,
            **extra_kwargs,
    ) as aphrodite_model:
        test_outputs = aphrodite_model.generate_greedy_logprobs(
            example_prompts, max_tokens, NUM_LOG_PROBS)
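    # Compare the two runs: a mismatched token is tolerated as long as it
    # appears in the other run's reported top logprobs.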
    check_logprobs_close(
        outputs_0_lst=baseline_outputs,
        outputs_1_lst=test_outputs,
        name_0="fp16_kv_cache",
        name_1="fp8_kv_cache",
    )