test_models.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
"""Compare the outputs of HF and Aphrodite when using greedy sampling.

This test only covers small models. Big models such as 7B should be tested
from test_big_models.py, because that suite can use a larger instance to run
tests.

Run `pytest tests/models/test_models.py`.
"""
  6. import pytest
  7. from .utils import check_outputs_equal
  8. MODELS = [
  9. "facebook/opt-125m",
  10. "gpt2",
  11. "bigcode/tiny_starcoder_py",
  12. "EleutherAI/pythia-70m",
  13. "bigscience/bloom-560m", # Testing alibi slopes.
  14. "microsoft/phi-2",
  15. "stabilityai/stablelm-3b-4e1t",
  16. # "allenai/OLMo-1B", # Broken
  17. "bigcode/starcoder2-3b",
  18. "google/gemma-1.1-2b-it",
  19. ]
  20. @pytest.mark.parametrize("model", MODELS)
  21. @pytest.mark.parametrize("dtype", ["float"])
  22. @pytest.mark.parametrize("max_tokens", [96])
  23. def test_models(
  24. hf_runner,
  25. aphrodite_runner,
  26. example_prompts,
  27. model: str,
  28. dtype: str,
  29. max_tokens: int,
  30. ) -> None:
  31. # To pass the small model tests, we need full precision.
  32. assert dtype == "float"
  33. with hf_runner(model, dtype=dtype) as hf_model:
  34. hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
  35. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  36. aphrodite_outputs = aphrodite_model.generate_greedy(
  37. example_prompts, max_tokens)
  38. check_outputs_equal(
  39. outputs_0_lst=hf_outputs,
  40. outputs_1_lst=aphrodite_outputs,
  41. name_0="hf",
  42. name_1="aphrodite",
  43. )
  44. @pytest.mark.parametrize("model", MODELS)
  45. @pytest.mark.parametrize("dtype", ["float"])
  46. def test_model_print(
  47. aphrodite_runner,
  48. model: str,
  49. dtype: str,
  50. ) -> None:
  51. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  52. # This test is for verifying whether the model's extra_repr
  53. # can be printed correctly.
  54. print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
  55. model_runner.model)