test_async_aphrodite.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import asyncio
  2. from dataclasses import dataclass
  3. import pytest
  4. import torch
  5. from aphrodite import SamplingParams
  6. from aphrodite.common.config import ParallelConfig
  7. from aphrodite.engine.async_aphrodite import AsyncAphrodite, AsyncEngineArgs
  8. from ..utils import wait_for_gpu_memory_to_clear
  9. @dataclass
  10. class RequestOutput:
  11. request_id: int
  12. finished: bool = False
  13. class MockEngine:
  14. def __init__(self):
  15. self.step_calls = 0
  16. self.add_request_calls = 0
  17. self.abort_request_calls = 0
  18. self.request_id = None
  19. # Ugly, remove dependency when possible
  20. self.parallel_config = ParallelConfig(1, 1, False)
  21. async def step_async(self, virtual_engine):
  22. # PP size is 1, ignore virtual engine
  23. self.step_calls += 1
  24. return [RequestOutput(
  25. request_id=self.request_id)] if self.request_id else []
  26. async def process_model_inputs_async(self, *args, **kwargs):
  27. pass
  28. async def stop_remote_worker_execution_loop_async(self):
  29. pass
  30. def generate(self, request_id):
  31. self.request_id = request_id
  32. def stop_generating(self):
  33. self.request_id = None
  34. def add_request(self, **kwargs):
  35. del kwargs # Unused
  36. self.add_request_calls += 1
  37. print(f'Request calls: {self.add_request_calls}')
  38. async def add_request_async(self, **kwargs):
  39. self.add_request_calls += 1
  40. return
  41. def abort_request(self, request_id):
  42. del request_id # Unused
  43. self.abort_request_calls += 1
  44. def has_unfinished_requests(self):
  45. return self.request_id is not None
  46. def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
  47. return self.request_id is not None
  48. class MockAsyncAphrodite(AsyncAphrodite):
  49. _engine_class = MockEngine
  50. @pytest.mark.asyncio
  51. async def test_new_requests_event():
  52. engine = MockAsyncAphrodite(worker_use_ray=False)
  53. engine.start_background_loop()
  54. await asyncio.sleep(0.01)
  55. assert engine.engine.step_calls == 0
  56. await engine.add_request("1", "", None)
  57. await asyncio.sleep(0.01)
  58. assert engine.engine.add_request_calls == 1
  59. assert engine.engine.step_calls == 1
  60. await engine.add_request("2", "", None)
  61. engine.engine.generate("2")
  62. await asyncio.sleep(0)
  63. await asyncio.sleep(0)
  64. await asyncio.sleep(0)
  65. assert engine.engine.add_request_calls == 2
  66. assert engine.engine.step_calls >= 2
  67. await asyncio.sleep(0.001)
  68. assert engine.engine.step_calls >= 3
  69. engine.engine.stop_generating()
  70. await asyncio.sleep(0.001)
  71. old_step_calls = engine.engine.step_calls
  72. await asyncio.sleep(0.001)
  73. assert engine.engine.step_calls == old_step_calls
  74. await engine.add_request("3", "", None)
  75. await asyncio.sleep(0.01)
  76. assert engine.engine.add_request_calls == 3
  77. assert engine.engine.step_calls == old_step_calls + 1
  78. await asyncio.sleep(0.01)
  79. assert engine.engine.add_request_calls == 3
  80. assert engine.engine.step_calls == old_step_calls + 1
  81. engine = MockAsyncAphrodite(worker_use_ray=True)
  82. assert engine.get_model_config() is not None
  83. assert engine.get_tokenizer() is not None
  84. assert engine.get_decoding_config() is not None
  85. def test_asyncio_run():
  86. wait_for_gpu_memory_to_clear(
  87. devices=list(range(torch.cuda.device_count())),
  88. threshold_bytes=2 * 2**30,
  89. timeout_s=60,
  90. )
  91. engine = AsyncAphrodite.from_engine_args(
  92. AsyncEngineArgs(model="facebook/opt-125m"))
  93. async def run(prompt: str):
  94. sampling_params = SamplingParams(
  95. temperature=0,
  96. max_tokens=32,
  97. )
  98. async for output in engine.generate(prompt,
  99. sampling_params,
  100. request_id=prompt):
  101. final_output = output
  102. return final_output
  103. async def generate():
  104. return await asyncio.gather(
  105. run("test0"),
  106. run("test1"),
  107. )
  108. results = asyncio.run(generate())
  109. assert len(results) == 2