run_batch.py

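"""Aphrodite OpenAI-compatible batch runner.

Reads a JSONL file of batch requests (local path or HTTP URL), runs each
request against an in-process engine, and writes one JSON result per line
to the output file (local path or HTTP PUT URL).

An illustrative input line (the exact schema is defined by BatchRequestInput
in the protocol module imported below; the custom_id and body values here
are made up):

    {"custom_id": "request-1", "url": "/v1/chat/completions",
     "body": {"model": "my-model",
              "messages": [{"role": "user", "content": "Hello!"}]}}
"""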

import asyncio
from http import HTTPStatus
from io import StringIO
from typing import Awaitable, Callable, List, Optional

import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm

from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
from aphrodite.endpoints.logger import RequestLogger, logger
# yapf: disable
from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
                                                 BatchRequestOutput,
                                                 BatchResponseData,
                                                 ChatCompletionResponse,
                                                 EmbeddingResponse,
                                                 ErrorResponse)
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import BaseModelPath
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.version import __version__ as APHRODITE_VERSION
  25. def parse_args():
  26. parser = FlexibleArgumentParser(
  27. description="Aphrodite OpenAI-Compatible batch runner.")
  28. parser.add_argument(
  29. "-i",
  30. "--input-file",
  31. required=True,
  32. type=str,
  33. help=
  34. "The path or url to a single input file. Currently supports local file "
  35. "paths, or the http protocol (http or https). If a URL is specified, "
  36. "the file should be available via HTTP GET.")
  37. parser.add_argument(
  38. "-o",
  39. "--output-file",
  40. required=True,
  41. type=str,
  42. help="The path or url to a single output file. Currently supports "
  43. "local file paths, or web (http or https) urls. If a URL is specified,"
  44. " the file should be available via HTTP PUT.")
  45. parser.add_argument("--response-role",
  46. type=str,
  47. default="assistant",
  48. help="The role name to return if "
  49. "`request.add_generation_prompt=True`.")
  50. parser = AsyncEngineArgs.add_cli_args(parser)
  51. parser.add_argument('--max-log-len',
  52. type=int,
  53. default=None,
  54. help='Max number of prompt characters or prompt '
  55. 'ID numbers being printed in log.'
  56. '\n\nDefault: Unlimited')
  57. parser.add_argument("--enable-metrics",
  58. action="store_true",
  59. help="Enable Prometheus metrics")
  60. parser.add_argument(
  61. "--url",
  62. type=str,
  63. default="0.0.0.0",
  64. help="URL to the Prometheus metrics server "
  65. "(only needed if enable-metrics is set).",
  66. )
  67. parser.add_argument(
  68. "--port",
  69. type=int,
  70. default=8000,
  71. help="Port number for the Prometheus metrics server "
  72. "(only needed if enable-metrics is set).",
  73. )
  74. return parser.parse_args()
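

# Example invocation (file paths and model name are illustrative, not part of
# this script): read requests from a local JSONL file and write the results
# to another local file.
#
#   python run_batch.py -i batch_input.jsonl -o batch_output.jsonl \
#       --model my-model --enable-metrics
#
# Any engine flag registered by AsyncEngineArgs.add_cli_args() above can be
# passed as well.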


# Explicitly use a pure-text format with a newline at the end. This makes the
# progress bar's in-place animation impossible to see, but avoids interfering
# with Ray or multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:
    """Tracks submitted vs. completed requests for the progress bar."""

    def __init__(self):
        self._total = 0
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        self._total += 1

    def completed(self):
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        # Only show the bar on rank 0 when running under torch.distributed.
        enable_tqdm = not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar


async def read_file(path_or_url: str) -> str:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                   session.get(path_or_url) as resp:
            return await resp.text()
    else:
        with open(path_or_url, "r", encoding="utf-8") as f:
            return f.read()


async def write_file(path_or_url: str, data: str) -> None:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                   session.put(path_or_url, data=data.encode("utf-8")):
            pass
    else:
        # We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't affect performance
        # in this particular case.
        with open(path_or_url, "w", encoding="utf-8") as f:
            f.write(data)


def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    batch_output = BatchRequestOutput(
        id=f"aphrodite-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"aphrodite-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
    return batch_output


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    # Async wrapper so error outputs can be gathered alongside the real
    # request futures in main().
    return make_error_request_output(request, error_msg)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    response = await serving_engine_func(request.body)

    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
        batch_output = BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response, request_id=f"aphrodite-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        batch_output = BatchRequestOutput(
            id=f"aphrodite-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"aphrodite-batch-{random_uuid()}"),
            error=response,
        )
    else:
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


async def main(args):
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the OpenAI serving objects.
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=request_logger,
        chat_template=None,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )

    tracker = BatchProgressTracker()
    logger.info(f"Reading batch from {args.input_file}...")

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        if request.url == "/v1/chat/completions":
            response_futures.append(
                run_request(openai_serving_chat.create_chat_completion,
                            request, tracker))
            tracker.submitted()
        elif request.url == "/v1/embeddings":
            response_futures.append(
                run_request(openai_serving_embedding.create_embedding, request,
                            tracker))
            tracker.submitted()
        else:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg="Only /v1/chat/completions and "
                    "/v1/embeddings are supported in the batch endpoint.",
                ))

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())
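

# Each output line is a BatchRequestOutput serialized with model_dump_json().
# An abbreviated, illustrative example (field values are made up; the exact
# schema comes from the protocol module imported above):
#
#   {"id": "aphrodite-...", "custom_id": "request-1",
#    "response": {"status_code": 200, "request_id": "aphrodite-batch-...",
#                 "body": {...}}, "error": null}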


if __name__ == "__main__":
    args = parse_args()

    logger.info(f"Aphrodite batch processing API version {APHRODITE_VERSION}")
    logger.debug(f"args: {args}")

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus
    # client to publish metrics at the /metrics endpoint.
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    asyncio.run(main(args))