test_configs.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. """Tests whether Marlin models can be loaded from the autogptq config.
  2. Run `pytest tests/quantization/test_configs.py --forked`.
  3. """
  4. from dataclasses import dataclass
  5. from typing import Tuple
  6. import pytest
  7. from aphrodite.common.config import ModelConfig
  8. @dataclass
  9. class ModelPair:
  10. model_marlin: str
  11. model_gptq: str
  12. # Model Id // Quantization Arg // Expected Type
  13. MODEL_ARG_EXPTYPES = [
  14. # AUTOGPTQ
  15. # compat: autogptq <=0.7.1 is_marlin_format: bool
  16. # Model Serialized in Marlin Format should always use Marlin kernel.
  17. ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
  18. ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
  19. ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
  20. ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
  21. # Model Serialized in Exllama Format.
  22. ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
  23. ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
  24. ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
  25. ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
  26. # compat: autogptq >=0.8.0 use checkpoint_format: str
  27. # Model Serialized in Marlin Format should always use Marlin kernel.
  28. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
  29. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
  30. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
  31. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
  32. # Model Serialized in Exllama Format.
  33. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
  34. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
  35. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
  36. ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
  37. # AUTOAWQ
  38. ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"),
  39. ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
  40. ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "awq_marlin"),
  41. ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
  42. ]
  43. @pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
  44. def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
  45. model_path, quantization_arg, expected_type = model_arg_exptype
  46. try:
  47. model_config = ModelConfig(model_path,
  48. model_path,
  49. tokenizer_mode="auto",
  50. trust_remote_code=False,
  51. seed=0,
  52. dtype="float16",
  53. revision=None,
  54. quantization=quantization_arg)
  55. found_quantization_type = model_config.quantization
  56. except ValueError:
  57. found_quantization_type = "ERROR"
  58. assert found_quantization_type == expected_type, (
  59. f"Expected quant_type == {expected_type} for {model_path}, "
  60. f"but found {found_quantization_type} "
  61. f"for no --quantization {quantization_arg} case")