test_gptq_marlin.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
"""Compares the outputs of gptq vs gptq_marlin.

Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 5 selections of each other.

Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_gptq_marlin.py`.
"""
  10. import os
  11. import pytest
  12. from aphrodite.modeling.layers.rotary_embedding import _ROPE_DICT
  13. from tests.quantization.utils import is_quant_method_supported
  14. from .utils import check_logprobs_close
  15. os.environ["TOKENIZERS_PARALLELISM"] = "true"
  16. MAX_MODEL_LEN = 1024
  17. MODELS = [
  18. # act_order==False, group_size=channelwise
  19. ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
  20. # act_order==False, group_size=128
  21. ("TheBloke/Llama-2-7B-GPTQ", "main"),
  22. # act_order==True, group_size=128
  23. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
  24. # act_order==True, group_size=64
  25. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
  26. # act_order==True, group_size=32
  27. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
  28. # 8-bit, act_order==True, group_size=channelwise
  29. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
  30. # 8-bit, act_order==True, group_size=128
  31. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
  32. # 8-bit, act_order==True, group_size=32
  33. ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
  34. # 4-bit, act_order==True, group_size=128
  35. ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main")
  36. ]
  37. @pytest.mark.flaky(reruns=3)
  38. @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
  39. reason="gptq_marlin is not supported on this GPU type.")
  40. @pytest.mark.parametrize("model", MODELS)
  41. @pytest.mark.parametrize("dtype", ["half", "bfloat16"])
  42. @pytest.mark.parametrize("max_tokens", [32])
  43. @pytest.mark.parametrize("num_logprobs", [5])
  44. def test_models(
  45. aphrodite_runner,
  46. example_prompts,
  47. model,
  48. dtype: str,
  49. max_tokens: int,
  50. num_logprobs: int,
  51. ) -> None:
  52. model_name, revision = model
  53. # Run marlin.
  54. with aphrodite_runner(model_name=model_name,
  55. revision=revision,
  56. dtype=dtype,
  57. quantization="marlin",
  58. max_model_len=MAX_MODEL_LEN,
  59. tensor_parallel_size=1) as gptq_marlin_model:
  60. gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
  61. example_prompts[:-1], max_tokens, num_logprobs)
  62. _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error
  63. # Run gptq.
  64. # The naive gptq kernel doesn't support bf16 yet.
  65. # Here we always compare fp16/bf16 gpt marlin kernel
  66. # to fp16 gptq kernel.
  67. with aphrodite_runner(model_name=model_name,
  68. revision=revision,
  69. dtype="half",
  70. quantization="gptq",
  71. max_model_len=MAX_MODEL_LEN,
  72. tensor_parallel_size=1) as gptq_model:
  73. gptq_outputs = gptq_model.generate_greedy_logprobs(
  74. example_prompts[:-1], max_tokens, num_logprobs)
  75. check_logprobs_close(
  76. outputs_0_lst=gptq_outputs,
  77. outputs_1_lst=gptq_marlin_outputs,
  78. name_0="gptq",
  79. name_1="gptq_marlin",
  80. )