1
0

test_custom_executor.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import asyncio
  2. import os
  3. import pytest
  4. from aphrodite.common.sampling_params import SamplingParams
  5. from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
  6. from aphrodite.engine.async_aphrodite import AphroditeEngine, AsyncAphrodite
  7. from aphrodite.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
  8. class Mock:
  9. ...
  10. class CustomGPUExecutor(GPUExecutor):
  11. def execute_model(self, *args, **kwargs):
  12. # Drop marker to show that this was ran
  13. with open(".marker", "w"):
  14. ...
  15. return super().execute_model(*args, **kwargs)
  16. class CustomGPUExecutorAsync(GPUExecutorAsync):
  17. async def execute_model_async(self, *args, **kwargs):
  18. with open(".marker", "w"):
  19. ...
  20. return await super().execute_model_async(*args, **kwargs)
  21. @pytest.mark.parametrize("model", ["facebook/opt-125m"])
  22. def test_custom_executor_type_checking(model):
  23. with pytest.raises(ValueError):
  24. engine_args = EngineArgs(model=model,
  25. distributed_executor_backend=Mock)
  26. AphroditeEngine.from_engine_args(engine_args)
  27. with pytest.raises(ValueError):
  28. engine_args = AsyncEngineArgs(model=model,
  29. distributed_executor_backend=Mock)
  30. AsyncAphrodite.from_engine_args(engine_args)
  31. with pytest.raises(TypeError):
  32. engine_args = AsyncEngineArgs(
  33. model=model, distributed_executor_backend=CustomGPUExecutor)
  34. AsyncAphrodite.from_engine_args(engine_args)
  35. @pytest.mark.parametrize("model", ["facebook/opt-125m"])
  36. def test_custom_executor(model, tmpdir):
  37. cwd = os.path.abspath(".")
  38. os.chdir(tmpdir)
  39. try:
  40. assert not os.path.exists(".marker")
  41. engine_args = EngineArgs(
  42. model=model, distributed_executor_backend=CustomGPUExecutor)
  43. engine = AphroditeEngine.from_engine_args(engine_args)
  44. sampling_params = SamplingParams(max_tokens=1)
  45. engine.add_request("0", "foo", sampling_params)
  46. engine.step()
  47. assert os.path.exists(".marker")
  48. finally:
  49. os.chdir(cwd)
  50. @pytest.mark.parametrize("model", ["facebook/opt-125m"])
  51. def test_custom_executor_async(model, tmpdir):
  52. cwd = os.path.abspath(".")
  53. os.chdir(tmpdir)
  54. try:
  55. assert not os.path.exists(".marker")
  56. engine_args = AsyncEngineArgs(
  57. model=model, distributed_executor_backend=CustomGPUExecutorAsync)
  58. engine = AsyncAphrodite.from_engine_args(engine_args)
  59. sampling_params = SamplingParams(max_tokens=1)
  60. async def t():
  61. stream = await engine.add_request("0", "foo", sampling_params)
  62. async for x in stream:
  63. ...
  64. asyncio.run(t())
  65. assert os.path.exists(".marker")
  66. finally:
  67. os.chdir(cwd)