- """Compare the outputs of HF and distributed Aphrodite when using greedy
- sampling. Aphrodite will allocate all the available memory, so we need to
- run the tests one by one. The solution is to pass arguments (model name)
- by environment variables.
- Run:
- ```sh
- TEST_DIST_MODEL=alpindale/gemma-2b pytest \
- test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=mistralai/Mistral-7B-Instruct-v0.2 \
- test_chunked_prefill_distributed.py
- ```
- """
import os

import pytest
import torch

MODELS = [
    os.environ["TEST_DIST_MODEL"],
]


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
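    # With enable_chunked_prefill, at most max_num_batched_tokens tokens are
    # scheduled per engine step, so prompts longer than
    # chunked_prefill_token_size are prefilled in chunks across steps.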
    # Add a chunked prefill config.
    max_num_seqs = min(chunked_prefill_token_size, 256)
    assert chunked_prefill_token_size != -1
    enable_chunked_prefill = True
    max_num_batched_tokens = chunked_prefill_token_size

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

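    # Run the distributed engine with tensor parallelism across 2 GPUs. The HF
    # model was deleted above so its GPU memory is free for the engine.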
    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        tensor_parallel_size=2,
        max_num_seqs=max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model

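    # Greedy sampling is deterministic, so the HF and distributed outputs are
    # expected to match exactly, both the token ids and the decoded strings.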
    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")