# test_lm_head.py
  1. """Tests whether gptq models with quantized lm_head can be loaded.
  2. Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
  3. """
  4. from typing import Tuple
  5. import pytest
  6. import torch
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  8. UnquantizedEmbeddingMethod)
  9. from aphrodite.quantization.gptq import GPTQLinearMethod
  10. from aphrodite.quantization.gptq_marlin import GPTQMarlinLinearMethod
  11. from aphrodite.quantization.marlin import MarlinLinearMethod
  12. PROMPT = "On the surface of Mars, we found"
  13. MODELS_QUANT = [(
  14. "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
  15. True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
  16. ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
  17. @pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
  18. def test_lm_head(
  19. aphrodite_runner,
  20. model_lm_head_quant: Tuple[str, bool],
  21. ) -> None:
  22. model, lm_head_quantized = model_lm_head_quant
  23. aphrodite_model = aphrodite_runner(model, dtype=torch.float16,
  24. max_model_len=2048)
  25. lm_head_layer = (
  26. aphrodite_model.model.llm_engine.model_executor.driver_worker.
  27. model_runner.model.lm_head)
  28. if lm_head_quantized:
  29. assert isinstance(
  30. lm_head_layer.linear_method,
  31. (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
  32. else:
  33. assert isinstance(lm_head_layer.linear_method,
  34. UnquantizedEmbeddingMethod)
  35. print(
  36. aphrodite_model.generate_greedy(prompts=["Hello my name is"],
  37. max_tokens=10)[0][1])
  38. del aphrodite_model