1
0

test_gptq_marlin_24.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. """Compare the outputs of a GPTQ model to a Marlin_24 model.
  2. Note: GPTQ and Marlin_24 do not have bitwise correctness.
  3. As a result, in this test, we just confirm that the top selected tokens of the
  4. Marlin/GPTQ models are in the top 3 selections of each other.
  5. Run `pytest tests/quantization/test_gptq_marlin_24.py`.
  6. """
  7. from dataclasses import dataclass
  8. import pytest
  9. from tests.models.utils import check_logprobs_close
  10. from tests.quantization.utils import is_quant_method_supported
  11. @dataclass
  12. class ModelPair:
  13. model_marlin: str
  14. model_gptq: str
  15. model_pairs = [
  16. # 4-bit, group_size == 128
  17. ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128",
  18. model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"),
  19. # 4-bit, group_size == channelwise
  20. ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise",
  21. model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"),
  22. # 8-bit, group_size == 128
  23. ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128",
  24. model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"),
  25. # 8-bit, group_size == channelwise
  26. ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise",
  27. model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"),
  28. ]
  29. @pytest.mark.flaky(reruns=2)
  30. @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
  31. reason="Marlin24 is not supported on this GPU type.")
  32. @pytest.mark.parametrize("model_pair", model_pairs)
  33. @pytest.mark.parametrize("dtype", ["half"])
  34. @pytest.mark.parametrize("max_tokens", [8])
  35. @pytest.mark.parametrize("num_logprobs", [5])
  36. def test_models(
  37. aphrodite_runner,
  38. example_prompts,
  39. model_pair: ModelPair,
  40. dtype: str,
  41. max_tokens: int,
  42. num_logprobs: int,
  43. ) -> None:
  44. with aphrodite_runner(model_pair.model_marlin,
  45. dtype=dtype,
  46. quantization="gptq_marlin_24") as marlin_24_model:
  47. marlin_24_outputs = marlin_24_model.generate_greedy_logprobs(
  48. example_prompts, max_tokens, num_logprobs)
  49. with aphrodite_runner(model_pair.model_gptq, dtype=dtype,
  50. quantization="gptq") as gptq_model:
  51. gptq_outputs = gptq_model.generate_greedy_logprobs(
  52. example_prompts, max_tokens, num_logprobs)
  53. check_logprobs_close(
  54. outputs_0_lst=gptq_outputs,
  55. outputs_1_lst=marlin_24_outputs,
  56. name_0="gptq",
  57. name_1="marlin_24",
  58. )