test_chunked_prefill_distributed.py

  1. """Compare the outputs of HF and distributed Aphrodite when using greedy
  2. sampling. Aphrodite will allocate all the available memory, so we need to
  3. run the tests one by one. The solution is to pass arguments (model name)
  4. by environment variables.
  5. Run:
  6. ```sh
  7. TEST_DIST_MODEL=alpindale/gemma-2b pytest \
  8. test_chunked_prefill_distributed.py
  9. TEST_DIST_MODEL=mistralai/Mistral-7B-Instruct-v0.2 \
  10. test_chunked_prefill_distributed.py
  11. ```
  12. """
import os

import pytest
import torch

MODELS = [
    os.environ["TEST_DIST_MODEL"],
]

@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    # Add a chunked prefill config.
    max_num_seqs = min(chunked_prefill_token_size, 256)
    assert chunked_prefill_token_size != -1
    enable_chunked_prefill = True
    max_num_batched_tokens = chunked_prefill_token_size

    # Run the HF reference first and delete it so the distributed engine can
    # claim the GPU memory afterwards.
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        tensor_parallel_size=2,
        max_num_seqs=max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model

    # Greedy sampling should make the two backends agree token for token.
    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")