1
0

test_pipeline_parallel.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """
  2. WARNING: This test runs in both single-node (4 GPUs) and multi-node
  3. (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
  4. important to set the distributed backend to "mp" to avoid Ray scheduling
  5. all workers in a node other than the head node, which can cause the test
  6. to fail.
  7. """
  8. import os
  9. import pytest
  10. from loguru import logger
  11. from packaging import version
  12. from transformers import __version__ as transformers_version
  13. from ..utils import compare_two_settings, fork_new_process_for_each_test
  14. APHRODITE_MULTI_NODE = os.getenv("APHRODITE_MULTI_NODE", "0") == "1"
  15. @pytest.mark.parametrize(
  16. ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
  17. "MODEL_NAME, DIST_BACKEND"),
  18. [
  19. (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
  20. (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
  21. (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
  22. (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
  23. (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
  24. (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
  25. (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
  26. (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
  27. (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
  28. (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
  29. (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
  30. (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
  31. ],
  32. )
  33. @fork_new_process_for_each_test
  34. def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
  35. TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
  36. if APHRODITE_MULTI_NODE and DIST_BACKEND == "mp":
  37. pytest.skip("Skipping multi-node pipeline parallel test for "
  38. "multiprocessing distributed backend")
  39. # Skip tests that require transformers>=4.45.0
  40. if "Qwen2-VL" in MODEL_NAME and version.parse(
  41. transformers_version) < version.parse("4.45.0.dev0"):
  42. pytest.skip("This test requires transformers>=4.45.0")
  43. pp_args = [
  44. # use half precision for speed and memory savings in CI environment
  45. "--dtype",
  46. "float16",
  47. "--pipeline-parallel-size",
  48. str(PP_SIZE),
  49. "--tensor-parallel-size",
  50. str(TP_SIZE),
  51. "--distributed-executor-backend",
  52. DIST_BACKEND,
  53. ]
  54. # compare without pipeline parallelism
  55. # NOTE: use mp backend for TP
  56. # PP tests might involve multiple nodes, and ray might
  57. # schedule all workers in a node other than the head node,
  58. # which can cause the test to fail.
  59. tp_args = [
  60. # use half precision for speed and memory savings in CI environment
  61. "--dtype",
  62. "bfloat16",
  63. "--tensor-parallel-size",
  64. str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
  65. "--distributed-executor-backend",
  66. "mp",
  67. ]
  68. if CHUNKED_PREFILL:
  69. pp_args.append("--enable-chunked-prefill")
  70. tp_args.append("--enable-chunked-prefill")
  71. if EAGER_MODE:
  72. pp_args.append("--enforce-eager")
  73. tp_args.append("--enforce-eager")
  74. if TRUST_REMOTE_CODE:
  75. pp_args.append("--trust-remote-code")
  76. tp_args.append("--trust-remote-code")
  77. pp_env = None
  78. if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
  79. and CHUNKED_PREFILL):
  80. # Test Ray ADAG for a subset of the tests
  81. pp_env = {
  82. "APHRODITE_USE_RAY_COMPILED_DAG": "1",
  83. "APHRODITE_USE_RAY_SPMD_WORKER": "1",
  84. "APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
  85. }
  86. # Temporary. Currently when zeromq + SPMD is used, it does not properly
  87. # terminate because of aDAG issue.
  88. pp_args.append("--disable-frontend-multiprocessing")
  89. tp_args.append("--disable-frontend-multiprocessing")
  90. try:
  91. compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
  92. except Exception:
  93. if pp_env is None:
  94. raise
  95. else:
  96. # Ray ADAG tests are flaky, so we don't want to fail the test
  97. logger.exception("Ray ADAG tests failed")