1
0

test_granite.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
"""Compare the outputs of Hugging Face (HF) and Aphrodite for Granite models
using greedy sampling.

Run `pytest tests/models/test_granite.py`.
"""
import re

import pytest
import transformers

from ...utils import check_logprobs_close
  7. MODELS = [
  8. "ibm/PowerLM-3b",
  9. ]
  10. # GraniteForCausalLM will be in transformers >= 4.45
  11. @pytest.mark.skipif(transformers.__version__ < "4.45",
  12. reason="granite model test requires transformers >= 4.45")
  13. @pytest.mark.parametrize("model", MODELS)
  14. @pytest.mark.parametrize("dtype", ["bfloat16"])
  15. @pytest.mark.parametrize("max_tokens", [64])
  16. @pytest.mark.parametrize("num_logprobs", [5])
  17. def test_models(
  18. hf_runner,
  19. aphrodite_runner,
  20. example_prompts,
  21. model: str,
  22. dtype: str,
  23. max_tokens: int,
  24. num_logprobs: int,
  25. ) -> None:
  26. # TODO(sang): Sliding window should be tested separately.
  27. with hf_runner(model, dtype=dtype) as hf_model:
  28. hf_outputs = hf_model.generate_greedy_logprobs_limit(
  29. example_prompts, max_tokens, num_logprobs)
  30. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  31. aphrodite_outputs = aphrodite_model.generate_greedy_logprobs(
  32. example_prompts, max_tokens, num_logprobs)
  33. check_logprobs_close(
  34. outputs_0_lst=hf_outputs,
  35. outputs_1_lst=aphrodite_outputs,
  36. name_0="hf",
  37. name_1="aphrodite",
  38. )