  1. """Compare the outputs of HF and distributed Aphrodite when using greedy
  2. sampling.
  3. Run:
  4. ```sh
  5. pytest -s -v test_multimodal_broadcast.py
  6. ```
  7. """
import pytest

from aphrodite.common.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend", [
    ("llava-hf/llava-1.5-7b-hf", "ray"),
    ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
    ("facebook/chameleon-7b", "ray"),
    ("llava-hf/llava-1.5-7b-hf", "mp"),
    ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
    ("facebook/chameleon-7b", "mp"),
])
# Run each parametrization in a freshly forked process so the distributed
# executor (Ray or multiprocessing) and CUDA state from one case cannot
# leak into the next.
@fork_new_process_for_each_test
def test_models(hf_runner, aphrodite_runner, image_assets, model: str,
                distributed_executor_backend: str) -> None:
    dtype = "half"
    max_tokens = 5
    num_logprobs = 5
    tensor_parallel_size = 2

    # Each supported model family exposes its own `models` list and
    # `run_test` helper in the matching test module; reuse them here.
    if model.startswith("llava-hf/llava-1.5"):
        from ..models.test_llava import models, run_test
    elif model.startswith("llava-hf/llava-v1.6"):
        from ..models.test_llava_next import run_test  # type: ignore[no-redef]
        from ..models.test_llava_next import models
    elif model.startswith("facebook/chameleon"):
        from ..models.test_chameleon import run_test  # type: ignore[no-redef]
        from ..models.test_chameleon import models
    else:
        raise NotImplementedError(f"Unsupported model: {model}")

    run_test(
        hf_runner,
        aphrodite_runner,
        image_assets,
        model=models[0],
        # Use multiple size factors so that the LLaVA-NeXT processor may
        # return a nested list
        size_factors=[0.25, 0.5, 1.0],
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )