run_batch.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import asyncio
  2. from io import StringIO
  3. from typing import Awaitable, List
  4. import aiohttp
  5. from loguru import logger
  6. from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
  7. from aphrodite.endpoints.logger import RequestLogger
  8. from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
  9. BatchRequestOutput,
  10. BatchResponseData,
  11. ChatCompletionResponse,
  12. ErrorResponse)
  13. from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
  14. from aphrodite.engine.args_tools import AsyncEngineArgs
  15. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  16. from aphrodite.version import __version__ as APHRODITE_VERSION
  17. def parse_args():
  18. parser = FlexibleArgumentParser(
  19. description="Aphrodite OpenAI-Compatible batch runner.")
  20. parser.add_argument(
  21. "-i",
  22. "--input-file",
  23. required=True,
  24. type=str,
  25. help=
  26. "The path or url to a single input file. Currently supports local file "
  27. "paths, or the http protocol (http or https). If a URL is specified, "
  28. "the file should be available via HTTP GET.")
  29. parser.add_argument(
  30. "-o",
  31. "--output-file",
  32. required=True,
  33. type=str,
  34. help="The path or url to a single output file. Currently supports "
  35. "local file paths, or web (http or https) urls. If a URL is specified,"
  36. " the file should be available via HTTP PUT.")
  37. parser.add_argument("--response-role",
  38. type=str,
  39. default="assistant",
  40. help="The role name to return if "
  41. "`request.add_generation_prompt=true`.")
  42. parser.add_argument("--max-log-len",
  43. type=int,
  44. default=0,
  45. help="Max number of prompt characters or prompt "
  46. "ID numbers being printed in log."
  47. "\n\nDefault: 0")
  48. parser = AsyncEngineArgs.add_cli_args(parser)
  49. return parser.parse_args()
  50. async def read_file(path_or_url: str) -> str:
  51. if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
  52. async with aiohttp.ClientSession() as session, \
  53. session.get(path_or_url) as resp:
  54. return await resp.text()
  55. else:
  56. with open(path_or_url, "r", encoding="utf-8") as f:
  57. return f.read()
  58. async def write_file(path_or_url: str, data: str) -> None:
  59. if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
  60. async with aiohttp.ClientSession() as session, \
  61. session.put(path_or_url, data=data.encode("utf-8")):
  62. pass
  63. else:
  64. # We should make this async, but as long as this is always run as a
  65. # standalone program, blocking the event loop won't effect performance
  66. # in this particular case.
  67. with open(path_or_url, "w", encoding="utf-8") as f:
  68. f.write(data)
  69. async def run_request(chat_serving: OpenAIServingChat,
  70. request: BatchRequestInput) -> BatchRequestOutput:
  71. chat_request = request.body
  72. chat_response = await chat_serving.create_chat_completion(chat_request)
  73. if isinstance(chat_response, ChatCompletionResponse):
  74. batch_output = BatchRequestOutput(
  75. id=f"aphrodite-{random_uuid()}",
  76. custom_id=request.custom_id,
  77. response=BatchResponseData(
  78. body=chat_response,
  79. request_id=f"aphrodite-batch-{random_uuid()}"),
  80. error=None,
  81. )
  82. elif isinstance(chat_response, ErrorResponse):
  83. batch_output = BatchRequestOutput(
  84. id=f"aphrodite-{random_uuid()}",
  85. custom_id=request.custom_id,
  86. response=BatchResponseData(
  87. status_code=chat_response.code,
  88. request_id=f"aphrodite-batch-{random_uuid()}"),
  89. error=chat_response,
  90. )
  91. else:
  92. raise ValueError("Request must not be sent in stream mode")
  93. return batch_output
  94. async def main(args):
  95. if args.served_model_name is not None:
  96. served_model_names = args.served_model_name
  97. else:
  98. served_model_names = [args.model]
  99. engine_args = AsyncEngineArgs.from_cli_args(args)
  100. engine = AsyncAphrodite.from_engine_args(engine_args)
  101. # When using single Aphrodite without engine_use_ray
  102. model_config = await engine.get_model_config()
  103. if args.disable_log_requests:
  104. request_logger = None
  105. else:
  106. request_logger = RequestLogger(max_log_len=args.max_log_len)
  107. openai_serving_chat = OpenAIServingChat(
  108. engine,
  109. model_config,
  110. served_model_names,
  111. args.response_role,
  112. lora_modules=None,
  113. prompt_adapters=None,
  114. request_logger=request_logger,
  115. chat_template=None,
  116. )
  117. # Submit all requests in the file to the engine "concurrently".
  118. response_futures: List[Awaitable[BatchRequestOutput]] = []
  119. for request_json in (await read_file(args.input_file)).strip().split("\n"):
  120. request = BatchRequestInput.model_validate_json(request_json)
  121. response_futures.append(run_request(openai_serving_chat, request))
  122. responses = await asyncio.gather(*response_futures)
  123. output_buffer = StringIO()
  124. for response in responses:
  125. print(response.model_dump_json(), file=output_buffer)
  126. output_buffer.seek(0)
  127. await write_file(args.output_file, output_buffer.read().strip())
  128. if __name__ == "__main__":
  129. args = parse_args()
  130. logger.info(f"Aphrodite API server version {APHRODITE_VERSION}")
  131. logger.info(f"args: {args}")
  132. asyncio.run(main(args))