test_bart.py

  1. """Compare the outputs of HF and Aphrodite for BART models using greedy
  2. sampling.
  3. Run `pytest tests/models/test_bart.py`.
  4. """
  5. from typing import List, Optional, Tuple
  6. from aphrodite.common.utils import is_cpu
  7. if not is_cpu():
  8. # CPU backend is not currently supported with encoder/decoder models
  9. # skip test definitions entirely to avoid importing GPU kernel libs
  10. # (xFormers, etc.)
  11. import pytest
  12. from transformers import AutoModelForSeq2SeqLM
  13. from aphrodite.common.sequence import SampleLogprobs
  14. from ..conftest import DecoderPromptType
  15. from .utils import check_logprobs_close
  16. MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
  17. def aphrodite_to_hf_output(
  18. aphrodite_output: Tuple[List[int], str, Optional[SampleLogprobs]],
  19. decoder_prompt_type: DecoderPromptType,
  20. ):
  21. """Sanitize aphrodite output to be comparable with hf output."""
  22. output_ids, output_str, out_logprobs = aphrodite_output
  23. hf_output_str = output_str + "</s>"
  24. if decoder_prompt_type == DecoderPromptType.NONE:
  25. hf_output_str = "<s>" + hf_output_str
  26. return output_ids, hf_output_str, out_logprobs
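
    # Illustrative example (hypothetical token IDs): with a None decoder
    # prompt, an Aphrodite output of ([2387, 50, 2], "The cat sat", None)
    # is sanitized to ([2387, 50, 2], "<s>The cat sat</s>", None), matching
    # HF's convention of keeping special tokens in the decoded string.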

    @pytest.mark.parametrize("model", MODELS)
    @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
    @pytest.mark.parametrize("max_tokens", [64])
    @pytest.mark.parametrize("num_logprobs", [5])
    @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
    def test_models(
        hf_runner,
        aphrodite_runner,
        example_encoder_decoder_prompts,
        model: str,
        dtype: str,
        max_tokens: int,
        num_logprobs: int,
        decoder_prompt_type: DecoderPromptType,
    ) -> None:
        '''
        Test the Aphrodite BART model for a variety of encoder/decoder
        input prompts, by validating it against HuggingFace (HF) BART.

        Arguments:

        * hf_runner: HuggingFace (HF) test model runner
        * aphrodite_runner: Aphrodite test model runner
        * example_encoder_decoder_prompts: test fixture which provides a
          dictionary of dummy prompts
        * model: the HF ID of the specific BART variant under test
        * dtype: the tensor datatype to employ
        * max_tokens
        * num_logprobs
        * decoder_prompt_type: key into the example_encoder_decoder_prompts
          dictionary; selects specific encoder/decoder prompt scenarios to
          test

        A note on using HF BART as a baseline for validating Aphrodite BART,
        specifically when the decoder prompt is None.

        The HF GenerationMixin's default behavior is to force the first
        decoded token to be <BOS> if the prompt does not already contain
        <BOS> (this is accomplished using a logit processor setting).

        So when we use HF BART as our baseline for comparison, note that
        when the user provides a request with a None decoder prompt
        (i.e. a singleton encoder prompt, or else an explicit encoder/
        decoder prompt with the decoder sub-prompt set to None), HF and
        Aphrodite handle this in different ways:

        * HF will (1) tokenize the None prompt as an empty token-list,
          (2) append <decoder-start-token> to the beginning, yielding
          [<decoder-start-token>], (3) pass this token list to the model, and
          then (4) after computing logits during prefill, override the model
          logits & force <BOS> to be the first generated token.

        * Aphrodite will (1) tokenize the None prompt as [<BOS>], (2) append
          <decoder-start-token> to the beginning, yielding
          [<decoder-start-token>, <BOS>], (3) pass these tokens to the model &
          proceed with generation.

        The net effect is that compared to Aphrodite, the list of HF
        *decoded* tokens will contain one more initial <BOS> than the
        Aphrodite generated tokens, because Aphrodite's <BOS> token is
        injected into the prompt rather than into the generated output.
        This is in spite of the fact that overall, the complete sequences
        (prompt + decoded tokens) produced by Aphrodite will match HF.

        So when we use HF decoded token output to validate Aphrodite's decoded
        token output, the testing process must account for the difference in
        decoded token sequences between Aphrodite and HF specifically in the
        decoder-prompt-is-None case.

        One option is to disable the logit processor feature that forces the
        <BOS> token to be decoded (forced_bos_token_id = None), eliminating
        the problem entirely. However, this is not "normal" BART usage.

        The other option is - only in the decoder-prompt-is-None case - to
        discard the first decoded token from the HF output before comparing it
        to Aphrodite's.

        To that end, when testing the scenario where the decoder prompt is None
        (and only in that one scenario), this test skips the first HF decoded
        token during the process of validating the Aphrodite decoded output.
        '''
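
        # Schematic view of the decoder-prompt-is-None difference described
        # above (t_1, t_2, ... stand for ordinary generated tokens):
        #
        #   HF decoded tokens:        [<BOS>, t_1, t_2, ...]  (<BOS> generated)
        #   Aphrodite decoded tokens: [t_1, t_2, ...]         (<BOS> in prompt)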

        test_case_prompts = example_encoder_decoder_prompts[
            decoder_prompt_type]

        # Configuration settings for HF baseline
        hf_kwargs = {
            "top_k": None,
            "num_beams": 1,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "length_penalty": 1.0,
            "early_stopping": False,
            "no_repeat_ngram_size": None,
            "min_length": 0
        }
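
        # These kwargs pin HF generate() to plain greedy decoding: num_beams=1
        # disables beam search, and the penalty/filter settings neutralize any
        # non-trivial defaults a checkpoint's generation_config (e.g. that of
        # the bart-large-cnn summarization fine-tune) might otherwise apply.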

        with hf_runner(model, dtype=dtype,
                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
            hf_outputs = (
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    test_case_prompts,
                    max_tokens,
                    num_logprobs,
                    **hf_kwargs,
                ))

        # Note: currently encoder/decoder models are only compatible with
        # enforce_eager=True. Normally this is not a problem because
        # Aphrodite will default to enforce_eager=True for encoder/decoder
        # models if enforce_eager is left unspecified. However, the
        # AphroditeRunner test fixture (which wraps around the LLM class)
        # defaults to enforce_eager=False (a behavior which a number of
        # already-existing decoder-only unit tests expect), so when testing
        # an encoder/decoder model we must explicitly specify
        # enforce_eager=True in the AphroditeRunner constructor.
        with aphrodite_runner(model, dtype=dtype,
                              enforce_eager=True) as aphrodite_model:
            aphrodite_outputs = (
                aphrodite_model.generate_encoder_decoder_greedy_logprobs(
                    test_case_prompts, max_tokens, num_logprobs))

        hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                          else 0)

        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                aphrodite_to_hf_output(aphrodite_output, decoder_prompt_type)
                for aphrodite_output in aphrodite_outputs
            ],
            name_0="hf",
            name_1="aphrodite",
            num_outputs_0_skip_tokens=hf_skip_tokens,
        )
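
# Example local invocation (illustrative -k expression; parametrization IDs
# may differ across pytest versions):
#   pytest tests/models/test_bart.py -k "bart-base and NONE"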