# Test the AsyncLLMEngine with multi-step decoding.

from typing import List

import pytest

from ..utils import RemoteOpenAIServer
MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]
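
# Common CLI arguments shared by the reference and multi-step servers.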
DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


async def completions_with_server_args(
    prompts: List[str], model_name: str, server_cli_args: List[str]
):
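    """Start a RemoteOpenAIServer with the given CLI args and return one
    batched (non-streaming) completions response for all prompts."""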
    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(
            model=model_name,
            prompt=prompts,
            temperature=0,  # greedy sampling, so outputs are deterministic
            stream=False,
            max_tokens=5,
        )
    assert outputs is not None
    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "tp_size, pp_size",
    [
        (1, 1),
        (2, 2),
    ],
)
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(
    example_prompts,
    model: str,
    tp_size: int,
    pp_size: int,
    eager_mode: bool,
    num_scheduler_steps: int,
    num_prompts: int,
):
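    """Verify that multi-step scheduling matches single-step output.

    Launches a reference server (single-step, enforced eager) and a test
    server (multi-step scheduling) with identical parallelism settings,
    then asserts that greedy completions for the same prompts match.
    """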
    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts
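
    # Reference server: single-step scheduling with eager execution.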
    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
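    # Server under test: multi-step scheduling with the configured step count.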
    ms_server_args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps",
        f"{num_scheduler_steps}",
    ]
    if eager_mode:
        ms_server_args.append("--enforce-eager")
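
    # Both servers get identical tensor/pipeline parallelism settings.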
    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]
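
    # Each helper call starts a fresh server, so the two runs are independent.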
    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args
    )
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args
    )
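
    # A batched completions request returns one choice per prompt, in order.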
    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations