  1. """Compare the short outputs of HF and Aphrodite when using greedy sampling.
  2. APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this
  3. test.
  4. Run `APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
  5. pytest tests/basic_correctness/test_preemption.py`.
  6. """
import pytest
from prometheus_client import REGISTRY

from aphrodite import SamplingParams
from aphrodite.executor.ray_gpu_executor import APHRODITE_USE_RAY_SPMD_WORKER
from aphrodite.processing.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                            ENABLE_ARTIFICIAL_PREEMPT)

from ..models.utils import check_outputs_equal

MODELS = [
    "facebook/opt-125m",
]

assert ENABLE_ARTIFICIAL_PREEMPT is True, (
    "Use an env var APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
    "`APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
    "tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
    # When the SPMD worker is used, set worker_use_ray=True
    # to test that the delta input optimization works with preemption.
    return APHRODITE_USE_RAY_SPMD_WORKER


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_recompute(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    worker_use_ray: bool,
) -> None:
    """Ensure that chunked prefill works with preemption."""
    max_num_seqs = min(chunked_prefill_token_size, 256)
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_seqs=max_num_seqs,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
                                                            max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
        assert hf_output_str == aphrodite_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: "
            f"{aphrodite_output_str!r}")
        assert hf_output_ids == aphrodite_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
    caplog_aphrodite,
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    worker_use_ray: bool,
) -> None:
    """By default, recompute preemption is enabled."""
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            disable_log_stats=False,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
                                                            max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)
        total_preemption = (
            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )

    assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
            "is not enough KV cache space." in caplog_aphrodite.text)
    # Ensure the count bucket of request-level histogram metrics matches
    # the number of requests as a simple sanity check to ensure metrics are
    # generated.
    preemption_metrics = None
    for m in REGISTRY.collect():
        if m.name == "aphrodite:num_preemptions":
            preemption_metrics = m
    assert preemption_metrics is not None
    total_recorded_preemption = 0
    for sample in preemption_metrics.samples:
        total_recorded_preemption += sample.value
    assert total_preemption == total_recorded_preemption
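

# The tests above exercise PreemptionMode.RECOMPUTE (the default), in which a
# preempted sequence's KV cache is dropped and recomputed once the sequence is
# rescheduled. The beam-search tests below exercise PreemptionMode.SWAP, in
# which KV blocks are moved out to CPU swap space (sized via `swap_space`) and
# swapped back in later.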
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
    caplog_aphrodite,
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
    worker_use_ray: bool,
) -> None:
    """Using beam search enables swapping."""
    example_prompts = example_prompts[:1]
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            swap_space=10,
            disable_log_stats=False,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_beam_search(
            example_prompts, beam_width, max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)
        total_preemption = (
            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

    for i in range(len(example_prompts)):
        hf_output_ids, _ = hf_outputs[i]
        aphrodite_output_ids, _ = aphrodite_outputs[i]
        assert len(hf_output_ids) == len(aphrodite_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == aphrodite_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"Aphrodite: {aphrodite_output_ids}")

    assert ("is preempted by PreemptionMode.SWAP mode because there "
            "is not enough KV cache space." in caplog_aphrodite.text)
    # Ensure the count bucket of request-level histogram metrics matches
    # the number of requests as a simple sanity check to ensure metrics are
    # generated.
    preemption_metrics = None
    for m in REGISTRY.collect():
        if m.name == "aphrodite:num_preemptions":
            preemption_metrics = m
    assert preemption_metrics is not None
    total_recorded_preemption = 0
    for sample in preemption_metrics.samples:
        total_recorded_preemption += sample.value
    assert total_preemption == total_recorded_preemption


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
    worker_use_ray: bool,
) -> None:
    """Verify that an infeasible swap request will be ignored."""
    BLOCK_SIZE = 16
    prefill_blocks = 2
    decode_blocks = max_tokens // BLOCK_SIZE
    example_prompts = example_prompts[:1]

    with aphrodite_runner(
            model,
            dtype=dtype,
            swap_space=10,
            block_size=BLOCK_SIZE,
            # Since beam search has more than 1 sequence, prefill +
            # decode blocks are not enough to finish.
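            # (A rough sketch of the arithmetic, assuming the parametrized
            # values above: max_tokens=96 and BLOCK_SIZE=16 give
            # decode_blocks=6, so the override allows 2 + 6 = 8 GPU blocks,
            # i.e. 128 token slots. That is enough for a single sequence but
            # not for beam_width=4 sequences decoding together.)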
            num_gpu_blocks_override=prefill_blocks + decode_blocks,
            max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        sampling_params = SamplingParams(n=beam_width,
                                         use_beam_search=True,
                                         temperature=0.0,
                                         max_tokens=max_tokens,
                                         ignore_eos=True)
        req_outputs = aphrodite_model.model.generate(
            example_prompts,
            sampling_params=sampling_params,
        )
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    # Verify that the request is ignored and does not hang.
    assert req_outputs[0].outputs[0].finish_reason == "length"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    worker_use_ray: bool,
) -> None:
    """Verify that an infeasible preemption request will be ignored."""
    BLOCK_SIZE = 16
    prefill_blocks = 2
    decode_blocks = max_tokens // BLOCK_SIZE

    with aphrodite_runner(
            model,
            dtype=dtype,
            block_size=BLOCK_SIZE,
            # Not enough GPU blocks to complete a single sequence.
            # Preemption should happen, and the sequence should be
            # ignored instead of hanging forever.
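            # (Again a sketch assuming the parametrized values above:
            # decode_blocks // 2 == 3, so the override allows 2 + 3 = 5 GPU
            # blocks, i.e. 80 token slots, while max_tokens=96 with
            # ignore_eos=True needs at least 96 slots for the decoded tokens
            # alone.)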
            num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
            max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         ignore_eos=True)
        req_outputs = aphrodite_model.model.generate(
            example_prompts,
            sampling_params=sampling_params,
        )
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    # Verify that the requests are ignored and do not hang.
    for req_output in req_outputs:
        outputs = req_output.outputs
        assert len(outputs) == 1
        assert outputs[0].finish_reason == "length"