test_big_models.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. """Compare the outputs of HF and Aphrodite when using greedy sampling.
  2. This tests bigger models and use half precision.
  3. Run `pytest tests/models/test_big_models.py`.
  4. """
  5. import pytest
  6. import torch
  7. from .utils import check_outputs_equal
  8. MODELS = [
  9. "meta-llama/Llama-2-7b-hf",
  10. # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py
  11. # "Deci/DeciLM-7b", # Broken
  12. # "tiiuae/falcon-7b", # Broken
  13. "EleutherAI/gpt-j-6b",
  14. # "mosaicml/mpt-7b", # Broken
  15. # "Qwen/Qwen1.5-0.5B" # Broken,
  16. ]
  17. #TODO: remove this after CPU float16 support ready
  18. target_dtype = "float"
  19. if torch.cuda.is_available():
  20. target_dtype = "half"
  21. @pytest.mark.parametrize("model", MODELS)
  22. @pytest.mark.parametrize("dtype", [target_dtype])
  23. @pytest.mark.parametrize("max_tokens", [32])
  24. def test_models(
  25. hf_runner,
  26. aphrodite_runner,
  27. example_prompts,
  28. model: str,
  29. dtype: str,
  30. max_tokens: int,
  31. ) -> None:
  32. with hf_runner(model, dtype=dtype) as hf_model:
  33. hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
  34. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  35. aphrodite_outputs = aphrodite_model.generate_greedy(
  36. example_prompts, max_tokens)
  37. check_outputs_equal(
  38. outputs_0_lst=hf_outputs,
  39. outputs_1_lst=aphrodite_outputs,
  40. name_0="hf",
  41. name_1="aphrodite",
  42. )
  43. @pytest.mark.parametrize("model", MODELS)
  44. @pytest.mark.parametrize("dtype", [target_dtype])
  45. def test_model_print(
  46. aphrodite_runner,
  47. model: str,
  48. dtype: str,
  49. ) -> None:
  50. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  51. # This test is for verifying whether the model's extra_repr
  52. # can be printed correctly.
  53. print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
  54. model_runner.model)