test_granite.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. """Compare the outputs of HF and Aphrodite for Granite models using greedy
  2. sampling.
  3. Run `pytest tests/models/test_granite.py`.
  4. """
import re

import pytest
import transformers

from ...utils import check_logprobs_close
  8. MODELS = [
  9. "ibm/PowerLM-3b",
  10. ]
  11. # GraniteForCausalLM will be in transformers >= 4.45
  12. @pytest.mark.skipif(transformers.__version__ < "4.45",
  13. reason="granite model test requires transformers >= 4.45")
  14. @pytest.mark.parametrize("model", MODELS)
  15. @pytest.mark.parametrize("dtype", ["bfloat16"])
  16. @pytest.mark.parametrize("max_tokens", [64])
  17. @pytest.mark.parametrize("num_logprobs", [5])
  18. def test_models(
  19. hf_runner,
  20. aphrodite_runner,
  21. example_prompts,
  22. model: str,
  23. dtype: str,
  24. max_tokens: int,
  25. num_logprobs: int,
  26. ) -> None:
  27. # TODO(sang): Sliding window should be tested separately.
  28. with hf_runner(model, dtype=dtype) as hf_model:
  29. hf_outputs = hf_model.generate_greedy_logprobs_limit(
  30. example_prompts, max_tokens, num_logprobs)
  31. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  32. aphrodite_outputs = aphrodite_model.generate_greedy_logprobs(
  33. example_prompts, max_tokens, num_logprobs)
  34. check_logprobs_close(
  35. outputs_0_lst=hf_outputs,
  36. outputs_1_lst=aphrodite_outputs,
  37. name_0="hf",
  38. name_1="aphrodite",
  39. )