import asyncio
import multiprocessing
from typing import Callable, Tuple, Union

from aphrodite import SamplingParams
from aphrodite.common.outputs import RequestOutput
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
from aphrodite.engine.multiprocessing.engine import MQAphroditeEngine


async def generate(
        client: MQAphroditeEngineClient,
        request_id: str,
        num_tokens: int,
        return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
    """Stream a greedy completion of a fixed prompt through *client*.

    Args:
        client: Engine client used to submit the request.
        request_id: Identifier passed through to ``client.generate``.
        num_tokens: ``max_tokens`` for the request; sampling is greedy
            (``temperature=0``).
        return_output: If true, return the last streamed output instead of
            the ``(count, request_id)`` summary.

    Returns:
        When ``return_output`` is true, the last output streamed back
        (``None`` if nothing was streamed). Otherwise ``(count, request_id)``,
        where ``count`` is the number of outputs streamed back.
    """
    final_output = None
    count = 0
    async for out in client.generate(
            request_id=request_id,
            prompt="Hello my name is Robert and",
            sampling_params=SamplingParams(max_tokens=num_tokens,
                                           temperature=0)):
        count += 1
        final_output = out
        # Yield to the event loop so concurrent generate() calls interleave.
        await asyncio.sleep(0.)

    if return_output:
        return final_output

    # No check happens here: the caller is expected to compare `count` with
    # num_tokens. Presumably one output is streamed per generated token --
    # verify against the client's streaming behavior.
    return count, request_id


def run_normal(engine_args: AsyncEngineArgs, ipc_path: str) -> None:
    """Build an ``MQAphroditeEngine`` bound to *ipc_path* and start it.

    This is the default ``run_fn`` of ``RemoteMQAphroditeEngine`` and runs
    inside the spawned child process. It returns only when
    ``engine.start()`` returns.
    """
    engine = MQAphroditeEngine.from_engine_args(
        engine_args=engine_args, ipc_path=ipc_path)
    engine.start()


class RemoteMQAphroditeEngine:
    """Run an engine function in a spawned child process.

    The child process is started in ``__init__``. Used as a context manager,
    it is killed and then reaped on exit. Call :meth:`make_client` to get a
    client connected over ``ipc_path``.
    """

    # Upper bound on waiting for the killed child to be reaped, so that a
    # process that cannot exit promptly does not hang the caller forever.
    _JOIN_TIMEOUT_S = 10.0

    def __init__(self,
                 engine_args: AsyncEngineArgs,
                 ipc_path: str,
                 run_fn: Callable = run_normal) -> None:
        """Spawn a process running ``run_fn(engine_args, ipc_path)``."""
        self.engine_args = engine_args
        self.ipc_path = ipc_path
        # The "spawn" start method launches a fresh interpreter for the
        # child rather than forking this process.
        context = multiprocessing.get_context("spawn")
        self.proc = context.Process(target=run_fn,
                                    args=(engine_args, ipc_path))
        self.proc.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.kill()
        # Reap the killed child; without join() it lingers as a zombie.
        self.proc.join(timeout=self._JOIN_TIMEOUT_S)

    async def make_client(self) -> MQAphroditeEngineClient:
        """Create a client for the child engine and wait for setup to finish.

        ``client.setup()`` is retried each time it times out, as long as the
        engine process is still alive.

        Raises:
            RuntimeError: If ``client.setup()`` times out after the engine
                process has already exited.
        """
        engine_config = self.engine_args.create_engine_config()
        client = MQAphroditeEngineClient(self.ipc_path, engine_config)

        while True:
            try:
                await client.setup()
            except (TimeoutError, asyncio.TimeoutError):
                # asyncio.TimeoutError is an alias of TimeoutError only on
                # Python 3.11+; catch both for older interpreters.
                # Use an explicit check rather than `assert`, which `-O`
                # strips and which would then leave this loop spinning
                # forever against a dead process.
                if not self.proc.is_alive():
                    raise RuntimeError(
                        "Engine process exited (exit code "
                        f"{self.proc.exitcode}) before client setup "
                        "completed.")
            else:
                return client