# test_api_server.py
  1. import os
  2. import subprocess
  3. import sys
  4. import time
  5. from multiprocessing import Pool
  6. from pathlib import Path
  7. import pytest
  8. import requests
  9. def _query_server(prompt: str, max_tokens: int = 5) -> dict:
  10. response = requests.post("http://localhost:2242/generate",
  11. json={
  12. "prompt": prompt,
  13. "max_tokens": max_tokens,
  14. "temperature": 0,
  15. "ignore_eos": True
  16. })
  17. response.raise_for_status()
  18. return response.json()
  19. def _query_server_long(prompt: str) -> dict:
  20. return _query_server(prompt, max_tokens=500)
  21. @pytest.fixture
  22. def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
  23. worker_use_ray: bool):
  24. script_path = Path(__file__).parent.joinpath(
  25. "api_server_async_engine.py").absolute()
  26. commands = [
  27. sys.executable, "-u",
  28. str(script_path), "--model", "facebook/opt-125m", "--host",
  29. "127.0.0.1", "--tokenizer-pool-size",
  30. str(tokenizer_pool_size)
  31. ]
  32. # Copy the environment variables and append `APHRODITE_ALLOW_ENGINE_USE_RAY=1`
  33. # to prevent `--engine-use-ray` raises an exception due to it deprecation
  34. env_vars = os.environ.copy()
  35. env_vars["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"
  36. if engine_use_ray:
  37. commands.append("--engine-use-ray")
  38. if worker_use_ray:
  39. commands.append("--worker-use-ray")
  40. uvicorn_process = subprocess.Popen(commands, env=env_vars)
  41. yield
  42. uvicorn_process.terminate()
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("worker_use_ray", [False, True])
@pytest.mark.parametrize("engine_use_ray", [False, True])
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
                    engine_use_ray: bool):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready: keep retrying a single warm-up
        # request until it succeeds instead of raising ConnectionError.
        prompts = ["warm up"] * 1
        result = None
        while not result:
            try:
                for r in pool.map(_query_server, prompts):
                    result = r
                    break
            except requests.exceptions.ConnectionError:
                # Server not listening yet; back off and retry.
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        # No request should have been aborted so far.
        num_aborted_requests = requests.get(
            "http://localhost:2242/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts — exercises concurrent request handling
        # (32 worker processes issuing requests in parallel).
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

    with Pool(32) as pool:
        # Cancel requests: fire off 100 long (500-token) generations
        # asynchronously, then kill the client pool mid-flight so the
        # server sees the connections drop.
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server_long, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # check cancellation stats
        # give it some times to update the stats
        time.sleep(1)

        num_aborted_requests = requests.get(
            "http://localhost:2242/stats").json()["num_aborted_requests"]
        # The dropped connections above must have been recorded as aborts.
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result