test_api_server.py
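# End-to-end test: launches api_server_async_engine.py in a subprocess and
# exercises its /generate and /stats endpoints from a pool of client
# processes, including cancellation of in-flight requests.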
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests
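

# Send one blocking request to the server's /generate endpoint.
# temperature=0 keeps decoding deterministic, and ignore_eos makes the
# server generate exactly `max_tokens` tokens instead of stopping at EOS.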
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
    response = requests.post("http://localhost:2242/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": max_tokens,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    response.raise_for_status()
    return response.json()
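

# A longer request (500 tokens) used below to keep generations in flight
# long enough to be cancelled.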
def _query_server_long(prompt: str) -> dict:
    return _query_server(prompt, max_tokens=500)
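

# Starts the API server in a subprocess and terminates it when the test is
# done. The tokenizer_pool_size and worker_use_ray values come from the
# parametrize decorators on the test below.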
@pytest.fixture
def api_server(tokenizer_pool_size: int, worker_use_ray: bool):
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m", "--host",
        "127.0.0.1", "--tokenizer-pool-size",
        str(tokenizer_pool_size)
    ]
    if worker_use_ray:
        commands.append("--worker-use-ray")
    uvicorn_process = subprocess.Popen(commands)
    yield
    uvicorn_process.terminate()
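

# Exercise all four combinations: tokenizer pool disabled (0) vs. a pool of
# size 2, and model workers with and without Ray.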
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("worker_use_ray", [False, True])
def test_api_server(api_server, tokenizer_pool_size: int,
                    worker_use_ray: bool):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["warm up"] * 1
        result = None
        while not result:
            try:
                for r in pool.map(_query_server, prompts):
                    result = r
                    break
            except requests.exceptions.ConnectionError:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result
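
        # Baseline: no request has been cancelled yet, so the abort counter
        # reported by the server's /stats endpoint should be zero.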
        num_aborted_requests = requests.get(
            "http://localhost:2242/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

    with Pool(32) as pool:
        # Cancel requests
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server_long, prompts)
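        # map_async returns immediately, so the long-running requests are
        # still in flight when the pool is terminated below; killing the
        # client processes mid-request triggers the server-side aborts.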
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # Check the cancellation stats; give the server some time to
        # update them.
        time.sleep(1)
        num_aborted_requests = requests.get(
            "http://localhost:2242/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # Check that the server still runs after the cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result