test_marlin.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
"""Compare the outputs of a GPTQ model to a Marlin model.

Note: GPTQ and Marlin outputs are not bitwise identical.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are within the top `num_logprobs` selections of each other
(`num_logprobs` is parametrized as 5 below).

Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the
test up to 3 times (one run plus two reruns) to see if we pass.

Run `pytest tests/models/test_marlin.py`.
"""
  10. from dataclasses import dataclass
  11. import pytest
  12. from tests.quantization.utils import is_quant_method_supported
  13. from .utils import check_logprobs_close
  14. @dataclass
  15. class ModelPair:
  16. model_marlin: str
  17. model_gptq: str
  18. model_pairs = [
  19. ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128",
  20. model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"),
  21. ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin",
  22. model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
  23. ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
  24. model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq")
  25. ]
  26. @pytest.mark.flaky(reruns=2)
  27. @pytest.mark.skipif(not is_quant_method_supported("marlin"),
  28. reason="Marlin is not supported on this GPU type.")
  29. @pytest.mark.parametrize("model_pair", model_pairs)
  30. @pytest.mark.parametrize("dtype", ["half"])
  31. @pytest.mark.parametrize("max_tokens", [32])
  32. @pytest.mark.parametrize("num_logprobs", [5])
  33. def test_models(
  34. aphrodite_runner,
  35. example_prompts,
  36. model_pair: ModelPair,
  37. dtype: str,
  38. max_tokens: int,
  39. num_logprobs: int,
  40. ) -> None:
  41. with aphrodite_runner(model_pair.model_marlin,
  42. dtype=dtype,
  43. quantization="marlin") as marlin_model:
  44. marlin_outputs = marlin_model.generate_greedy_logprobs(
  45. example_prompts, max_tokens, num_logprobs)
  46. with aphrodite_runner(model_pair.model_gptq, dtype=dtype,
  47. quantization="gptq") as gptq_model:
  48. gptq_outputs = gptq_model.generate_greedy_logprobs(
  49. example_prompts, max_tokens, num_logprobs)
  50. check_logprobs_close(
  51. outputs_0_lst=gptq_outputs,
  52. outputs_1_lst=marlin_outputs,
  53. name_0="gptq",
  54. name_1="marlin",
  55. )