1
0

utils.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import asyncio
  2. import multiprocessing
  3. from typing import Callable, Tuple, Union
  4. from aphrodite import SamplingParams
  5. from aphrodite.common.outputs import RequestOutput
  6. from aphrodite.engine.args_tools import AsyncEngineArgs
  7. from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
  8. from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine
  9. async def generate(
  10. client: MQAphroditeEngineClient,
  11. request_id: str,
  12. num_tokens: int,
  13. return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
  14. final_output = None
  15. count = 0
  16. async for out in client.generate(
  17. request_id=request_id,
  18. prompt="Hello my name is Robert and",
  19. sampling_params=SamplingParams(max_tokens=num_tokens,
  20. temperature=0)):
  21. count += 1
  22. final_output = out
  23. await asyncio.sleep(0.)
  24. if return_output:
  25. return final_output
  26. # Confirm we generated all the tokens we expected.
  27. return count, request_id
  28. def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
  29. # Make engine.
  30. engine = MQAphroditeEngine.from_engine_args(
  31. engine_args=engine_args,
  32. ipc_path=ipc_path)
  33. # Run engine.
  34. engine.start()
  35. class RemoteMQAphroditeEngine:
  36. def __init__(self,
  37. engine_args: AsyncEngineArgs,
  38. ipc_path: str,
  39. run_fn: Callable = run_normal) -> None:
  40. self.engine_args = engine_args
  41. self.ipc_path = ipc_path
  42. context = multiprocessing.get_context("spawn")
  43. self.proc = context.Process(target=run_fn,
  44. args=(engine_args, ipc_path))
  45. self.proc.start()
  46. def __enter__(self):
  47. return self
  48. def __exit__(self, exc_type, exc_value, traceback):
  49. self.proc.kill()
  50. async def make_client(self) -> MQAphroditeEngineClient:
  51. engine_config = self.engine_args.create_engine_config()
  52. client = MQAphroditeEngineClient(self.ipc_path, engine_config)
  53. while True:
  54. try:
  55. await client.setup()
  56. break
  57. except TimeoutError:
  58. assert self.proc.is_alive()
  59. return client