test_correctness_async_llm.py

# Test the AsyncLLMEngine with multi-step-decoding
from typing import List, Optional

import pytest
import torch

from ..models.utils import check_logprobs_close
from ..utils import (completions_with_server_args, get_client_text_generations,
                     get_client_text_logprob_generations)

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]

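# Default CLI arguments shared by both the single-step reference server
# and the multi-step server under test.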
DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


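# Sweep tensor-/pipeline-parallel degree, eager mode, multi-step count,
# logprobs on/off, and async vs. sync output processing.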
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    ("tp_size, pp_size"),
    [
        (1, 1),
        (2, 2),
    ],
)
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(
    example_prompts,
    model: str,
    tp_size: int,
    pp_size: int,
    eager_mode: bool,
    num_scheduler_steps: int,
    num_prompts: int,
    is_async: bool,
    num_logprobs: Optional[int],
) -> None:
    """Test Aphrodite engine with multi-step scheduling in an OpenAI-protocol
    client/server environment.

    Set up an engine with single-step scheduling as a ground-truth reference.
    Send a completions API request to both engines with the same prompts.
    Validate:
    * Generated tokens match
    * Generated logprobs are all very close

    Args:
      example_prompts: test fixture providing example prompts
      model: model under test (same for single- and multi-step engines)
      tp_size: degree of tensor-parallelism
      pp_size: degree of pipeline-parallelism
      eager_mode: whether the multi-step engine is launched with
                  `--enforce-eager` (CUDA graphs disabled)
      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
                           GPU -> CPU output transfer
      num_prompts: number of example prompts under test
      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
                    completions endpoint; `None` -> no logprobs
    """
    if (tp_size > 1 or pp_size > 1) and torch.cuda.device_count() == 1:
        pytest.skip("Skipping multi-GPU tests on single GPU system")

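    # Tile the example prompts until at least `num_prompts` are available,
    # then truncate to exactly `num_prompts`.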
    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

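    # Reference server: single-step scheduling, always eager. The server
    # under test enables multi-step scheduling via --num-scheduler-steps.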
    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
    ms_server_args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps",
        f"{num_scheduler_steps}",
    ]

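    # `is_async=False` disables asynchronous output processing on the
    # multi-step engine; `eager_mode=True` additionally disables CUDA graphs.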
    if not is_async:
        ms_server_args += ["--disable-async-output-proc"]

    if eager_mode:
        ms_server_args.append("--enforce-eager")

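    # Both servers share the same tensor-/pipeline-parallel configuration.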
    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]

    # Spin up client/server & issue completion API requests.
    # Default `max_wait_seconds` is 240, but it was empirically
    # raised 3x to 720 *just for this test* due to
    # observed timeouts in GHA CI.
    ref_completions = await completions_with_server_args(
        prompts,
        model,
        server_args + distributed_args,
        num_logprobs,
        max_wait_seconds=3 * 240)
    test_completions = await completions_with_server_args(
        prompts,
        model,
        ms_server_args + distributed_args,
        num_logprobs,
        max_wait_seconds=3 * 240)

    # Assert multi-step scheduling produces identical tokens
    # to single-step scheduling.
    ref_generations = get_client_text_generations(ref_completions)
    test_generations = get_client_text_generations(test_completions)
    assert ref_generations == test_generations

    # Assert multi-step scheduling produces nearly-identical logprobs
    # to single-step scheduling.
    ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
    test_text_logprobs = get_client_text_logprob_generations(test_completions)
    check_logprobs_close(
        outputs_0_lst=ref_text_logprobs,
        outputs_1_lst=test_text_logprobs,
        name_0="hf",
        name_1="aphrodite",
    )