1
0

test_beam_search.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. """Compare the outputs of HF and Aphrodite Engine when using beam search.
  2. Run `pytest tests/samplers/test_beam_search.py --forked`.
  3. """
  4. import pytest
  5. MAX_TOKENS = [128]
  6. BEAM_WIDTHS = [4]
  7. MODELS = ["EleutherAI/pythia-70m-deduped"]
  8. @pytest.mark.parametrize("model", MODELS)
  9. @pytest.mark.parametrize("dtype", ["half"])
  10. @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
  11. @pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
  12. def test_beam_search_single_input(
  13. hf_runner,
  14. aphrodite_runner,
  15. example_prompts,
  16. model: str,
  17. dtype: str,
  18. max_tokens: int,
  19. beam_width: int,
  20. ) -> None:
  21. hf_model = hf_runner(model, dtype=dtype)
  22. hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
  23. max_tokens)
  24. del hf_model
  25. aphrodite_model = aphrodite_runner(model, dtype=dtype)
  26. aphrodite_outputs = aphrodite_model.generate_beam_search(
  27. example_prompts, beam_width, max_tokens)
  28. del aphrodite_model
  29. for i in range(len(example_prompts)):
  30. hf_output_ids, _ = hf_outputs[i]
  31. aphrodite_output_ids, _ = aphrodite_outputs[i]
  32. assert len(hf_output_ids) == len(aphrodite_output_ids)
  33. for j in range(len(hf_output_ids)):
  34. assert hf_output_ids[j] == aphrodite_output_ids[j], (
  35. f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
  36. f"Aphrodite Engine: {aphrodite_output_ids}")