test_e2e_correctness.py

  1. """E2E tests to verify the correctness of the encoder-decoder framework
  2. Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
  3. """
  4. from typing import List, Optional, Tuple
  5. import pytest
  6. from transformers import AutoModelForSeq2SeqLM
  7. from aphrodite.common.sequence import SampleLogprobs
  8. from aphrodite.common.utils import is_cpu
  9. from ..conftest import DecoderPromptType
  10. from ..models.utils import check_logprobs_close

def aphrodite_to_hf_output(
    aphrodite_output: Tuple[List[int], str, Optional[SampleLogprobs]],
    decoder_prompt_type: DecoderPromptType,
):
    """Sanitize aphrodite output to be comparable with HF output."""
    output_ids, output_str, out_logprobs = aphrodite_output

    # HF keeps special tokens in its decoded string, while aphrodite strips
    # them: re-append the EOS token, and prepend the BOS token that HF emits
    # when no decoder prompt is supplied.
    hf_output_str = output_str + "</s>"
    if decoder_prompt_type == DecoderPromptType.NONE:
        hf_output_str = "<s>" + hf_output_str

    return output_ids, hf_output_str, out_logprobs
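
# Example (illustrative token IDs): with no decoder prompt,
#   aphrodite_to_hf_output(([5, 2], "Hello", None), DecoderPromptType.NONE)
# returns ([5, 2], "<s>Hello</s>", None), matching HF's detokenization of
# the same IDs with special tokens retained.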

@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
    is_cpu(),
    reason="CPU backend is not currently supported with encoder/decoder models",
)
def test_encoder_decoder_e2e(
    hf_runner,
    aphrodite_runner,
    example_encoder_decoder_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    decoder_prompt_type: DecoderPromptType,
    enforce_eager: bool,
) -> None:
    """End-to-end (E2E) test for the encoder-decoder framework.

    This test exercises encoder-decoder functionality using the BART
    model. We compare the outputs of the Hugging Face and Aphrodite
    implementations to ensure that both produce consistent and correct
    results.
    """
    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]

    # Configuration settings for the HF baseline
    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0,
    }
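    # These kwargs are forwarded to HF's `generate()` by the hf_runner
    # fixture (assumed harness behavior): num_beams=1 with no sampling
    # filters pins the baseline to plain greedy decoding, matching the
    # greedy path used by aphrodite below.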

    with hf_runner(
        model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM
    ) as hf_model:
        hf_outputs = hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            test_case_prompts,
            max_tokens,
            num_logprobs,
            **hf_kwargs,
        )

    with aphrodite_runner(
        model, dtype=dtype, enforce_eager=enforce_eager
    ) as aphrodite_model:
        aphrodite_outputs = (
            aphrodite_model.generate_encoder_decoder_greedy_logprobs(
                test_case_prompts, max_tokens, num_logprobs
            )
        )

    # With no decoder prompt, HF emits the BOS token ("<s>") as its first
    # output token, which aphrodite does not; skip it in the comparison.
    hf_skip_tokens = 1 if decoder_prompt_type == DecoderPromptType.NONE else 0

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            aphrodite_to_hf_output(aphrodite_output, decoder_prompt_type)
            for aphrodite_output in aphrodite_outputs
        ],
        name_0="hf",
        name_1="aphrodite",
        num_outputs_0_skip_tokens=hf_skip_tokens,
    )