1
0

run_batch.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import asyncio
  2. from io import StringIO
  3. from typing import Awaitable, Callable, List
  4. import aiohttp
  5. from loguru import logger
  6. from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
  7. from aphrodite.endpoints.logger import RequestLogger
  8. from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
  9. BatchRequestOutput,
  10. BatchResponseData,
  11. ChatCompletionResponse,
  12. EmbeddingResponse,
  13. ErrorResponse)
  14. from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
  15. from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
  16. from aphrodite.engine.args_tools import AsyncEngineArgs
  17. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  18. from aphrodite.version import __version__ as APHRODITE_VERSION
  19. def parse_args():
  20. parser = FlexibleArgumentParser(
  21. description="Aphrodite OpenAI-Compatible batch runner.")
  22. parser.add_argument(
  23. "-i",
  24. "--input-file",
  25. required=True,
  26. type=str,
  27. help=
  28. "The path or url to a single input file. Currently supports local file "
  29. "paths, or the http protocol (http or https). If a URL is specified, "
  30. "the file should be available via HTTP GET.")
  31. parser.add_argument(
  32. "-o",
  33. "--output-file",
  34. required=True,
  35. type=str,
  36. help="The path or url to a single output file. Currently supports "
  37. "local file paths, or web (http or https) urls. If a URL is specified,"
  38. " the file should be available via HTTP PUT.")
  39. parser.add_argument("--response-role",
  40. type=str,
  41. default="assistant",
  42. help="The role name to return if "
  43. "`request.add_generation_prompt=true`.")
  44. parser.add_argument("--max-log-len",
  45. type=int,
  46. default=0,
  47. help="Max number of prompt characters or prompt "
  48. "ID numbers being printed in log."
  49. "\n\nDefault: 0")
  50. parser = AsyncEngineArgs.add_cli_args(parser)
  51. return parser.parse_args()
  52. async def read_file(path_or_url: str) -> str:
  53. if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
  54. async with aiohttp.ClientSession() as session, \
  55. session.get(path_or_url) as resp:
  56. return await resp.text()
  57. else:
  58. with open(path_or_url, "r", encoding="utf-8") as f:
  59. return f.read()
  60. async def write_file(path_or_url: str, data: str) -> None:
  61. if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
  62. async with aiohttp.ClientSession() as session, \
  63. session.put(path_or_url, data=data.encode("utf-8")):
  64. pass
  65. else:
  66. # We should make this async, but as long as this is always run as a
  67. # standalone program, blocking the event loop won't effect performance
  68. # in this particular case.
  69. with open(path_or_url, "w", encoding="utf-8") as f:
  70. f.write(data)
  71. async def run_request(serving_engine_func: Callable,
  72. request: BatchRequestInput) -> BatchRequestOutput:
  73. response = await serving_engine_func(request.body)
  74. if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
  75. batch_output = BatchRequestOutput(
  76. id=f"aphrodite-{random_uuid()}",
  77. custom_id=request.custom_id,
  78. response=BatchResponseData(
  79. body=response, request_id=f"aphrodite-batch-{random_uuid()}"),
  80. error=None,
  81. )
  82. elif isinstance(response, ErrorResponse):
  83. batch_output = BatchRequestOutput(
  84. id=f"aphrodite-{random_uuid()}",
  85. custom_id=request.custom_id,
  86. response=BatchResponseData(
  87. status_code=response.code,
  88. request_id=f"aphrodite-batch-{random_uuid()}"),
  89. error=response,
  90. )
  91. else:
  92. raise ValueError("Request must not be sent in stream mode")
  93. return batch_output
  94. async def main(args):
  95. if args.served_model_name is not None:
  96. served_model_names = args.served_model_name
  97. else:
  98. served_model_names = [args.model]
  99. engine_args = AsyncEngineArgs.from_cli_args(args)
  100. engine = AsyncAphrodite.from_engine_args(engine_args)
  101. # When using single Aphrodite without engine_use_ray
  102. model_config = await engine.get_model_config()
  103. if args.disable_log_requests:
  104. request_logger = None
  105. else:
  106. request_logger = RequestLogger(max_log_len=args.max_log_len)
  107. # Create the OpenAI serving objects.
  108. openai_serving_chat = OpenAIServingChat(
  109. engine,
  110. model_config,
  111. served_model_names,
  112. args.response_role,
  113. lora_modules=None,
  114. prompt_adapters=None,
  115. request_logger=request_logger,
  116. chat_template=None,
  117. )
  118. openai_serving_embedding = OpenAIServingEmbedding(
  119. engine,
  120. model_config,
  121. served_model_names,
  122. request_logger=request_logger,
  123. )
  124. # Submit all requests in the file to the engine "concurrently".
  125. response_futures: List[Awaitable[BatchRequestOutput]] = []
  126. for request_json in (await read_file(args.input_file)).strip().split("\n"):
  127. # Skip empty lines.
  128. request_json = request_json.strip()
  129. if not request_json:
  130. continue
  131. request = BatchRequestInput.model_validate_json(request_json)
  132. # Determine the type of request and run it.
  133. if request.url == "/v1/chat/completions":
  134. response_futures.append(
  135. run_request(openai_serving_chat.create_chat_completion,
  136. request))
  137. elif request.url == "/v1/embeddings":
  138. response_futures.append(
  139. run_request(openai_serving_embedding.create_embedding,
  140. request))
  141. else:
  142. raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
  143. "supported in the batch endpoint.")
  144. responses = await asyncio.gather(*response_futures)
  145. output_buffer = StringIO()
  146. for response in responses:
  147. print(response.model_dump_json(), file=output_buffer)
  148. output_buffer.seek(0)
  149. await write_file(args.output_file, output_buffer.read().strip())
  150. if __name__ == "__main__":
  151. args = parse_args()
  152. logger.info(f"Aphrodite API server version {APHRODITE_VERSION}")
  153. logger.info(f"args: {args}")
  154. asyncio.run(main(args))