# test_compatibility.py
import pytest

from aphrodite import SamplingParams

from .conftest import get_output_from_llm_generator
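
# get_output_from_llm_generator comes from this suite's conftest. For
# orientation only, a minimal sketch of the shape it presumably has (an
# assumption, not the actual conftest code): it consumes the test_llm_generator
# fixture, runs generation for the given prompts, and returns the decoded
# texts and token ids.
#
#     def get_output_from_llm_generator(llm_generator, prompts, sampling_params):
#         tokens, token_ids = [], []
#         for llm in llm_generator():
#             outputs = llm.generate(prompts, sampling_params)
#             tokens = [output.outputs[0].text for output in outputs]
#             token_ids = [output.outputs[0].token_ids for output in outputs]
#             del llm
#         return tokens, token_ids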


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "enable_chunked_prefill": True,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
    """Verify that speculative decoding with chunked prefill fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding and chunked prefill"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
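
# For context, a rough sketch of the incompatible combination the test above
# exercises, written against the public aphrodite.LLM entry point. This is an
# assumed illustration rather than part of the suite; the engine is expected to
# reject the combination with the ValueError matched above.
#
#     from aphrodite import LLM
#     llm = LLM(model="JackFram/llama-68m",
#               speculative_model="JackFram/llama-68m",
#               num_speculative_tokens=5,
#               use_v2_block_manager=True,
#               enable_chunked_prefill=True)  # expected to raise ValueError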


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Speculative max model len > overridden max model len should raise.
            "max_model_len": 128,
            "speculative_max_model_len": 129,
        },
        {
            # Speculative max model len > draft max model len should raise.
            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
            "speculative_max_model_len": 2048 + 1,
        },
        {
            # Speculative max model len > target max model len should raise.
            # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
            "speculative_max_model_len": 4096 + 1,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
    """Verify that speculative decoding validates speculative_max_model_len.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError, match="cannot be larger than"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
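
# The three cases above violate the same constraint: speculative_max_model_len
# may not exceed the smallest applicable sequence-length limit. A rough sketch
# of that check (assumed logic, not the actual aphrodite config validation):
#
#     effective_limit = min(draft_max_model_len, target_max_model_len)
#     if max_model_len is not None:
#         effective_limit = min(effective_limit, max_model_len)
#     if speculative_max_model_len > effective_limit:
#         raise ValueError("speculative_max_model_len cannot be larger than ...")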


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
    """Verify that speculative decoding with block manager v1 fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding requires usage of the V2"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
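
# These compatibility checks can be run on their own with pytest; the path is
# illustrative and depends on where this file lives in the repository:
#
#     pytest tests/spec_decode/e2e/test_compatibility.py -v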