# test_async_aphrodite.py

  1. import asyncio
  2. import os
  3. from dataclasses import dataclass
  4. import pytest
  5. import torch
  6. from aphrodite import SamplingParams
  7. from aphrodite.common.config import ParallelConfig
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite, AsyncEngineArgs
  9. from ..utils import wait_for_gpu_memory_to_clear
  10. @dataclass
  11. class RequestOutput:
  12. request_id: int
  13. finished: bool = False
  14. class MockEngine:
  15. def __init__(self):
  16. self.step_calls = 0
  17. self.add_request_calls = 0
  18. self.abort_request_calls = 0
  19. self.request_id = None
  20. # Ugly, remove dependency when possible
  21. self.parallel_config = ParallelConfig(1, 1, False)
  22. async def step_async(self, virtual_engine):
  23. # PP size is 1, ignore virtual engine
  24. self.step_calls += 1
  25. return [RequestOutput(
  26. request_id=self.request_id)] if self.request_id else []
  27. async def process_model_inputs_async(self, *args, **kwargs):
  28. pass
  29. async def stop_remote_worker_execution_loop_async(self):
  30. pass
  31. def generate(self, request_id):
  32. self.request_id = request_id
  33. def stop_generating(self):
  34. self.request_id = None
  35. def add_request(self, **kwargs):
  36. del kwargs # Unused
  37. self.add_request_calls += 1
  38. print(f'Request calls: {self.add_request_calls}')
  39. async def add_request_async(self, **kwargs):
  40. self.add_request_calls += 1
  41. return
  42. def abort_request(self, request_id):
  43. del request_id # Unused
  44. self.abort_request_calls += 1
  45. def has_unfinished_requests(self):
  46. return self.request_id is not None
  47. def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
  48. return self.request_id is not None
  49. class MockAsyncAphrodite(AsyncAphrodite):
  50. def _init_engine(self, *args, **kwargs):
  51. return MockEngine()
  52. @pytest.mark.asyncio
  53. async def test_new_requests_event():
  54. engine = MockAsyncAphrodite(worker_use_ray=False, engine_use_ray=False)
  55. engine.start_background_loop()
  56. await asyncio.sleep(0.01)
  57. assert engine.engine.step_calls == 0
  58. await engine.add_request("1", "", None)
  59. await asyncio.sleep(0.01)
  60. assert engine.engine.add_request_calls == 1
  61. assert engine.engine.step_calls == 1
  62. await engine.add_request("2", "", None)
  63. engine.engine.generate("2")
  64. await asyncio.sleep(0)
  65. await asyncio.sleep(0)
  66. await asyncio.sleep(0)
  67. assert engine.engine.add_request_calls == 2
  68. assert engine.engine.step_calls >= 2
  69. await asyncio.sleep(0.001)
  70. assert engine.engine.step_calls >= 3
  71. engine.engine.stop_generating()
  72. await asyncio.sleep(0.001)
  73. old_step_calls = engine.engine.step_calls
  74. await asyncio.sleep(0.001)
  75. assert engine.engine.step_calls == old_step_calls
  76. await engine.add_request("3", "", None)
  77. await asyncio.sleep(0.01)
  78. assert engine.engine.add_request_calls == 3
  79. assert engine.engine.step_calls == old_step_calls + 1
  80. await asyncio.sleep(0.01)
  81. assert engine.engine.add_request_calls == 3
  82. assert engine.engine.step_calls == old_step_calls + 1
  83. # Allow deprecated engine_use_ray to not raise exception
  84. os.environ["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"
  85. engine = MockAsyncAphrodite(worker_use_ray=True, engine_use_ray=True)
  86. assert engine.get_model_config() is not None
  87. assert engine.get_tokenizer() is not None
  88. assert engine.get_decoding_config() is not None
  89. os.environ.pop("APHRODITE_ALLOW_ENGINE_USE_RAY")
  90. def test_asyncio_run():
  91. wait_for_gpu_memory_to_clear(
  92. devices=list(range(torch.cuda.device_count())),
  93. threshold_bytes=2 * 2**30,
  94. timeout_s=60,
  95. )
  96. engine = AsyncAphrodite.from_engine_args(
  97. AsyncEngineArgs(model="facebook/opt-125m"))
  98. async def run(prompt: str):
  99. sampling_params = SamplingParams(
  100. temperature=0,
  101. max_tokens=32,
  102. )
  103. async for output in engine.generate(prompt,
  104. sampling_params,
  105. request_id=prompt):
  106. final_output = output
  107. return final_output
  108. async def generate():
  109. return await asyncio.gather(
  110. run("test0"),
  111. run("test1"),
  112. )
  113. results = asyncio.run(generate())
  114. assert len(results) == 2