test_big_models.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
"""Compare the outputs of HF and Aphrodite when using greedy sampling.
This tests bigger models and uses half precision (float32 on CPU, where
float16 support is not ready yet).
Run `pytest tests/models/test_big_models.py`.
"""
  5. import pytest
  6. from aphrodite.platforms import current_platform
  7. from ...utils import check_outputs_equal
  8. MODELS = [
  9. "meta-llama/Llama-2-7b-hf",
  10. # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py
  11. # "Deci/DeciLM-7b", # Broken
  12. # "tiiuae/falcon-7b", # Broken
  13. "EleutherAI/gpt-j-6b",
  14. # "mosaicml/mpt-7b", # Broken
  15. # "Qwen/Qwen1.5-0.5B" # Broken,
  16. ]
  17. if not current_platform.is_cpu():
  18. # MiniCPM requires fused_moe which is not supported by CPU
  19. MODELS.append("openbmb/MiniCPM3-4B")
  20. #TODO: remove this after CPU float16 support ready
  21. target_dtype = "float" if current_platform.is_cpu() else "half"
  22. @pytest.mark.parametrize("model", MODELS)
  23. @pytest.mark.parametrize("dtype", [target_dtype])
  24. @pytest.mark.parametrize("max_tokens", [32])
  25. def test_models(
  26. hf_runner,
  27. aphrodite_runner,
  28. example_prompts,
  29. model: str,
  30. dtype: str,
  31. max_tokens: int,
  32. ) -> None:
  33. with hf_runner(model, dtype=dtype) as hf_model:
  34. hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
  35. with aphrodite_runner(model, dtype=dtype,
  36. enforce_eager=True) as aphrodite_model:
  37. aphrodite_outputs = aphrodite_model.generate_greedy(
  38. example_prompts, max_tokens)
  39. check_outputs_equal(
  40. outputs_0_lst=hf_outputs,
  41. outputs_1_lst=aphrodite_outputs,
  42. name_0="hf",
  43. name_1="aphrodite",
  44. )
  45. @pytest.mark.parametrize("model", MODELS)
  46. @pytest.mark.parametrize("dtype", [target_dtype])
  47. def test_model_print(
  48. aphrodite_runner,
  49. model: str,
  50. dtype: str,
  51. ) -> None:
  52. with aphrodite_runner(
  53. model, dtype=dtype, enforce_eager=True) as aphrodite_model:
  54. # This test is for verifying whether the model's extra_repr
  55. # can be printed correctly.
  56. print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
  57. model_runner.model)