  1. """Compare the short outputs of HF and Aphrodite when using greedy sampling.
  2. APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this
  3. test.
  4. Run `APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
  5. pytest tests/basic_correctness/test_preemption.py`.
  6. """
import pytest
from prometheus_client import REGISTRY

import aphrodite.common.envs as envs
from aphrodite import SamplingParams
from aphrodite.processing.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                            ENABLE_ARTIFICIAL_PREEMPT)

from ..models.utils import check_outputs_equal

MODELS = [
    "facebook/opt-125m",
]


@pytest.fixture(scope="module", autouse=True)
def check_settings():
    assert ENABLE_ARTIFICIAL_PREEMPT is True, (
        "Set the env var APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, e.g. "
        "`APHRODITE_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
        "tests/basic_correctness/test_preemption.py`")


@pytest.fixture
def worker_use_ray() -> bool:
    # When the SPMD worker is used, use worker_use_ray=True
    # to test that the delta input optimization works with preemption.
    return envs.APHRODITE_USE_RAY_SPMD_WORKER
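

# Illustrative sketch (not used by the tests below, which inline the same
# logic): total preemptions recorded by the `aphrodite:num_preemptions`
# Prometheus counter. prometheus_client exposes one sample per label set,
# so the overall total is the sum over `metric.samples`.
def _total_recorded_preemptions() -> float:
    for metric in REGISTRY.collect():
        if metric.name == "aphrodite:num_preemptions":
            return sum(sample.value for sample in metric.samples)
    raise AssertionError("aphrodite:num_preemptions metric not found")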


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_recompute(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    worker_use_ray: bool,
) -> None:
    """Ensure that chunked prefill works with preemption."""
    max_num_seqs = min(chunked_prefill_token_size, 256)
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_seqs=max_num_seqs,
            worker_use_ray=worker_use_ray,
            disable_log_stats=False,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
        assert hf_output_str == aphrodite_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: "
            f"{aphrodite_output_str!r}")
        assert hf_output_ids == aphrodite_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
    caplog_aphrodite,
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    worker_use_ray: bool,
) -> None:
    """By default, recompute preemption is enabled."""
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            disable_log_stats=False,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)
        total_preemption = (
            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )

    assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
            "is not enough KV cache space." in caplog_aphrodite.text)

    # As a simple sanity check that metrics are generated, ensure that the
    # total preemption count recorded by the `aphrodite:num_preemptions`
    # Prometheus counter matches the scheduler's cumulative preemption count.
    preemption_metrics = None
    for m in REGISTRY.collect():
        if m.name == "aphrodite:num_preemptions":
            preemption_metrics = m
    assert preemption_metrics is not None
    total_recorded_preemption = 0
    for sample in preemption_metrics.samples:
        total_recorded_preemption += sample.value
    assert total_preemption == total_recorded_preemption


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
    caplog_aphrodite,
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
    worker_use_ray: bool,
) -> None:
    """Using beam search enables swapping."""
    example_prompts = example_prompts[:1]
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)

    with aphrodite_runner(
            model,
            dtype=dtype,
            swap_space=10,
            disable_log_stats=False,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_beam_search(
            example_prompts, beam_width, max_tokens)
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)
        total_preemption = (
            aphrodite_model.model.llm_engine.scheduler[0].num_cumulative_preemption)

    for i in range(len(example_prompts)):
        hf_output_ids, _ = hf_outputs[i]
        aphrodite_output_ids, _ = aphrodite_outputs[i]
        assert len(hf_output_ids) == len(aphrodite_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == aphrodite_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"Aphrodite: {aphrodite_output_ids}")

    assert ("is preempted by PreemptionMode.SWAP mode because there "
            "is not enough KV cache space." in caplog_aphrodite.text)

    # As a simple sanity check that metrics are generated, ensure that the
    # total preemption count recorded by the `aphrodite:num_preemptions`
    # Prometheus counter matches the scheduler's cumulative preemption count.
    preemption_metrics = None
    for m in REGISTRY.collect():
        if m.name == "aphrodite:num_preemptions":
            preemption_metrics = m
    assert preemption_metrics is not None
    total_recorded_preemption = 0
    for sample in preemption_metrics.samples:
        total_recorded_preemption += sample.value
    assert total_preemption == total_recorded_preemption


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
    worker_use_ray: bool,
) -> None:
    """Verify that an infeasible swap request is ignored."""
    BLOCK_SIZE = 16
    prefill_blocks = 2
    decode_blocks = max_tokens // BLOCK_SIZE
    example_prompts = example_prompts[:1]

    with aphrodite_runner(
            model,
            dtype=dtype,
            swap_space=10,
            block_size=BLOCK_SIZE,
            # Since beam search keeps more than one sequence alive, prefill +
            # decode blocks are not enough to finish.
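            # (With max_tokens=96 and BLOCK_SIZE=16, this is 2 + 6 = 8 GPU
            # blocks, i.e. 128 token slots: roughly enough for a single
            # sequence but not for a beam of width 4.)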
            num_gpu_blocks_override=prefill_blocks + decode_blocks,
            max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        sampling_params = SamplingParams(n=beam_width,
                                         use_beam_search=True,
                                         temperature=0.0,
                                         max_tokens=max_tokens,
                                         ignore_eos=True)
        req_outputs = aphrodite_model.model.generate(
            example_prompts,
            sampling_params=sampling_params,
        )
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    # Verify that the request is ignored and does not hang.
    assert req_outputs[0].outputs[0].finish_reason == "length"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    worker_use_ray: bool,
) -> None:
    """Verify that an infeasible preemption request is ignored."""
    BLOCK_SIZE = 16
    prefill_blocks = 2
    decode_blocks = max_tokens // BLOCK_SIZE

    with aphrodite_runner(
            model,
            dtype=dtype,
            block_size=BLOCK_SIZE,
            # Not enough GPU blocks to complete a single sequence:
            # preemption should happen, and the sequence should be
            # ignored instead of hanging forever.
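            # (With max_tokens=96 and BLOCK_SIZE=16, this is 2 + 3 = 5 GPU
            # blocks, i.e. 80 token slots, while each request wants its
            # prompt plus up to 96 generated tokens.)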
            num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
            max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
            worker_use_ray=worker_use_ray,
    ) as aphrodite_model:
        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         ignore_eos=True)
        req_outputs = aphrodite_model.model.generate(
            example_prompts,
            sampling_params=sampling_params,
        )
        assert (
            aphrodite_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
            < ARTIFICIAL_PREEMPTION_MAX_CNT)

    # Verify that the requests are ignored and do not hang.
    for req_output in req_outputs:
        outputs = req_output.outputs
        assert len(outputs) == 1
        assert outputs[0].finish_reason == "length"