# test_api_server.py
  1. import subprocess
  2. import sys
  3. import time
  4. from multiprocessing import Pool
  5. from pathlib import Path
  6. import pytest
  7. import requests
  8. def _query_server(prompt: str, max_tokens: int = 5) -> dict:
  9. response = requests.post("http://localhost:2242/generate",
  10. json={
  11. "prompt": prompt,
  12. "max_tokens": max_tokens,
  13. "temperature": 0,
  14. "ignore_eos": True
  15. })
  16. response.raise_for_status()
  17. return response.json()
  18. def _query_server_long(prompt: str) -> dict:
  19. return _query_server(prompt, max_tokens=500)
  20. @pytest.fixture
  21. def api_server():
  22. script_path = Path(__file__).parent.joinpath(
  23. "api_server_async_engine.py").absolute()
  24. uvicorn_process = subprocess.Popen([
  25. sys.executable, "-u",
  26. str(script_path), "--model", "EleutherAI/pythia-70m-deduped"
  27. ])
  28. yield
  29. uvicorn_process.terminate()
  30. def test_api_server(api_server):
  31. """
  32. Run the API server and test it.
  33. We run both the server and requests in separate processes.
  34. We test that the server can handle incoming requests, including
  35. multiple requests at the same time, and that it can handle requests
  36. being cancelled without crashing.
  37. """
  38. with Pool(32) as pool:
  39. # Wait until the server is ready
  40. prompts = ["warm up"] * 1
  41. result = None
  42. while not result:
  43. try:
  44. for r in pool.map(_query_server, prompts):
  45. result = r
  46. break
  47. except requests.exceptions.ConnectionError:
  48. time.sleep(1)
  49. # Actual tests start here
  50. # Try with 1 prompt
  51. for result in pool.map(_query_server, prompts):
  52. assert result
  53. num_aborted_requests = requests.get(
  54. "http://localhost:8000/stats").json()["num_aborted_requests"]
  55. assert num_aborted_requests == 0
  56. # Try with 100 prompts
  57. prompts = ["test prompt"] * 100
  58. for result in pool.map(_query_server, prompts):
  59. assert result
  60. with Pool(32) as pool:
  61. # Cancel requests
  62. prompts = ["canceled requests"] * 100
  63. pool.map_async(_query_server_long, prompts)
  64. time.sleep(0.01)
  65. pool.terminate()
  66. pool.join()
  67. # check cancellation stats
  68. num_aborted_requests = requests.get(
  69. "http://localhost:8000/stats").json()["num_aborted_requests"]
  70. assert num_aborted_requests > 0
  71. # check that server still runs after cancellations
  72. with Pool(32) as pool:
  73. # Try with 100 prompts
  74. prompts = ["test prompt after canceled"] * 100
  75. for result in pool.map(_query_server, prompts):
  76. assert result