test_correctness.py

# Test the AsyncLLMEngine with multi-step decoding
from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]
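

# Helper: start a vLLM OpenAI-compatible server with the given CLI args,
# send one batched completions request covering all prompts, and return the
# response. Greedy sampling (temperature=0) keeps the outputs deterministic,
# so the reference and multi-step servers can be compared exactly.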
async def completions_with_server_args(
    prompts: List[str], model_name: str, server_cli_args: List[str]
):
    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(
            model=model_name,
            prompt=prompts,
            temperature=0,
            stream=False,
            max_tokens=5,
        )
    assert outputs is not None
    return outputs
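

# Correctness test: a multi-step server (--num-scheduler-steps > 1) must
# produce exactly the same completions as a single-step reference server,
# across every tensor/pipeline-parallel layout parametrized below.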
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "tp_size, pp_size",
    [
        (1, 1),
        (2, 2),
    ],
)
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(
    example_prompts,
    model: str,
    tp_size: int,
    pp_size: int,
    eager_mode: bool,
    num_scheduler_steps: int,
    num_prompts: int,
):
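    # Replicate the example prompts until at least `num_prompts` are
    # available, then truncate to exactly `num_prompts`.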
    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts
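
    # Reference server: single-step scheduling, always eager. The multi-step
    # server differs only in --num-scheduler-steps (plus --enforce-eager when
    # the eager_mode case is being tested).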
    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
    ms_server_args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps",
        f"{num_scheduler_steps}",
    ]
    if eager_mode:
        ms_server_args.append("--enforce-eager")

    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]
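
    # Each call below starts a fresh server, collects the completions for all
    # prompts, and shuts the server down before the next one starts.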
    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args
    )
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args
    )
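
    # Extract the generated text for each choice and require an exact match
    # between the multi-step server and the single-step reference.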
    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations
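

# Usage sketch (an assumption, not part of the original file): run with
# pytest from the repo's test tree so the relative ..utils import resolves.
# The (2, 2) case needs tp_size * pp_size = 4 GPUs and Ray workers
# (--worker-use-ray); pytest-asyncio must be installed for
# @pytest.mark.asyncio to take effect.
#
#   pytest -x -s test_correctness.py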