1
0

test_api_server.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import os
  2. import subprocess
  3. import sys
  4. import time
  5. from multiprocessing import Pool
  6. from pathlib import Path
  7. import pytest
  8. import requests
  9. def _query_server(prompt: str, max_tokens: int = 5) -> dict:
  10. response = requests.post("http://localhost:2242/generate",
  11. json={
  12. "prompt": prompt,
  13. "max_tokens": max_tokens,
  14. "temperature": 0,
  15. "ignore_eos": True
  16. })
  17. response.raise_for_status()
  18. return response.json()
  19. def _query_server_long(prompt: str) -> dict:
  20. return _query_server(prompt, max_tokens=500)
  21. @pytest.fixture
  22. def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
  23. worker_use_ray: bool):
  24. script_path = Path(__file__).parent.joinpath(
  25. "api_server_async_engine.py").absolute()
  26. commands = [
  27. sys.executable, "-u",
  28. str(script_path), "--model", "facebook/opt-125m", "--host",
  29. "127.0.0.1", "--tokenizer-pool-size",
  30. str(tokenizer_pool_size)
  31. ]
  32. # Copy the environment variables and append
  33. # `APHRODITE_ALLOW_ENGINE_USE_RAY=1` to prevent
  34. # `--engine-use-ray` raises an exception due to it deprecation
  35. env_vars = os.environ.copy()
  36. env_vars["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"
  37. if engine_use_ray:
  38. commands.append("--engine-use-ray")
  39. if worker_use_ray:
  40. commands.append("--worker-use-ray")
  41. uvicorn_process = subprocess.Popen(commands, env=env_vars)
  42. yield
  43. uvicorn_process.terminate()
  44. @pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
  45. @pytest.mark.parametrize("worker_use_ray", [False, True])
  46. @pytest.mark.parametrize("engine_use_ray", [False, True])
  47. def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
  48. engine_use_ray: bool):
  49. """
  50. Run the API server and test it.
  51. We run both the server and requests in separate processes.
  52. We test that the server can handle incoming requests, including
  53. multiple requests at the same time, and that it can handle requests
  54. being cancelled without crashing.
  55. """
  56. with Pool(32) as pool:
  57. # Wait until the server is ready
  58. prompts = ["warm up"] * 1
  59. result = None
  60. while not result:
  61. try:
  62. for r in pool.map(_query_server, prompts):
  63. result = r
  64. break
  65. except requests.exceptions.ConnectionError:
  66. time.sleep(1)
  67. # Actual tests start here
  68. # Try with 1 prompt
  69. for result in pool.map(_query_server, prompts):
  70. assert result
  71. num_aborted_requests = requests.get(
  72. "http://localhost:2242/stats").json()["num_aborted_requests"]
  73. assert num_aborted_requests == 0
  74. # Try with 100 prompts
  75. prompts = ["test prompt"] * 100
  76. for result in pool.map(_query_server, prompts):
  77. assert result
  78. with Pool(32) as pool:
  79. # Cancel requests
  80. prompts = ["canceled requests"] * 100
  81. pool.map_async(_query_server_long, prompts)
  82. time.sleep(0.01)
  83. pool.terminate()
  84. pool.join()
  85. # check cancellation stats
  86. # give it some times to update the stats
  87. time.sleep(1)
  88. num_aborted_requests = requests.get(
  89. "http://localhost:2242/stats").json()["num_aborted_requests"]
  90. assert num_aborted_requests > 0
  91. # check that server still runs after cancellations
  92. with Pool(32) as pool:
  93. # Try with 100 prompts
  94. prompts = ["test prompt after canceled"] * 100
  95. for result in pool.map(_query_server, prompts):
  96. assert result