test_basic_correctness.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """Compare the short outputs of HF and Aphrodite when using greedy sampling.
  2. Run `pytest tests/basic_correctness/test_basic_correctness.py`.
  3. """
  4. import os
  5. import pickle
  6. import re
  7. import weakref
  8. from unittest.mock import patch
  9. import pytest
  10. from aphrodite import LLM
  11. from aphrodite.common.utils import is_hip
  12. from aphrodite.worker.model_runner import ModelInputForGPUWithSamplingMetadata
  13. from ..models.utils import check_outputs_equal
  14. MODELS = [
  15. "facebook/opt-125m",
  16. "meta-llama/Llama-2-7b-hf",
  17. ]
  18. def test_aphrodite_gc_ed():
  19. """Verify aphrodite instance is GC'ed when it is deleted"""
  20. llm = LLM("facebook/opt-125m")
  21. weak_llm = weakref.ref(llm)
  22. del llm
  23. # If there's any circular reference to aphrodite, this fails
  24. # because llm instance is not GC'ed.
  25. assert weak_llm() is None
  26. @pytest.mark.parametrize("model", MODELS)
  27. @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
  28. @pytest.mark.parametrize("dtype", ["half"])
  29. @pytest.mark.parametrize("max_tokens", [5])
  30. @pytest.mark.parametrize("enforce_eager", [False, True])
  31. def test_models(
  32. hf_runner,
  33. aphrodite_runner,
  34. example_prompts,
  35. model: str,
  36. backend: str,
  37. dtype: str,
  38. max_tokens: int,
  39. enforce_eager: bool,
  40. ) -> None:
  41. if backend == "FLASHINFER" and is_hip():
  42. pytest.skip("Flashinfer does not support ROCm/HIP.")
  43. os.environ["APHRODITE_ATTENTION_BACKEND"] = backend
  44. with hf_runner(model, dtype=dtype) as hf_model:
  45. hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
  46. with aphrodite_runner(model,
  47. dtype=dtype,
  48. enforce_eager=enforce_eager,
  49. gpu_memory_utilization=0.7) as aphrodite_model:
  50. aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
  51. max_tokens)
  52. check_outputs_equal(
  53. outputs_0_lst=hf_outputs,
  54. outputs_1_lst=aphrodite_outputs,
  55. name_0="hf",
  56. name_1="aphrodite",
  57. )
  58. def test_model_with_failure(aphrodite_runner) -> None:
  59. try:
  60. with patch("aphrodite.modeling.models.opt.OPTForCausalLM.forward",
  61. side_effect=ValueError()):
  62. with pytest.raises(ValueError) as exc_info:
  63. aphrodite_runner("facebook/opt-125m",
  64. dtype="half",
  65. enforce_eager=False,
  66. gpu_memory_utilization=0.7)
  67. matches = re.search(r"input dumped to (.+).pkl",
  68. str(exc_info.value))
  69. assert matches is not None
  70. filename = f"{matches.group(1)}.pkl"
  71. with open(filename, "rb") as filep:
  72. inputs = pickle.load(filep)
  73. if any(key not in inputs for key in ("arg_1", "arg_2", "arg_3")):
  74. raise AssertionError("Missing keys in dumped inputs. Dumped keys: "
  75. f"{list(inputs.keys())}")
  76. assert isinstance(inputs["arg_1"],
  77. ModelInputForGPUWithSamplingMetadata)
  78. finally:
  79. os.remove(filename)