import asyncio
import importlib
import inspect
import json
import multiprocessing
import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Set, Tuple

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                               StreamingResponse)
from loguru import logger
from starlette.routing import Mount

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import (FlexibleArgumentParser,
                                    get_open_zmq_ipc_path, random_uuid)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser
# yapf: disable
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
                                                 DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 EmbeddingRequest,
                                                 ErrorResponse,
                                                 KAIGenerationInputSchema,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
    OpenAIServingCompletion)
from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       PromptAdapterPath)
from aphrodite.endpoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
from aphrodite.version import __version__ as APHRODITE_VERSION
TIMEOUT_KEEP_ALIVE = 5  # seconds

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

router = APIRouter()
kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""
sampler_json = ""
gen_cache: dict = {}
prometheus_multiproc_dir: tempfile.TemporaryDirectory

_running_tasks: Set[asyncio.Task] = set()
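# Builds a throwaway ModelConfig solely to read its embedding_mode flag, so the
# frontend can decide below whether the model must run in-process.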
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    yield
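# Two modes for the engine client: embedding models (not yet supported over
# RPC) and --disable-frontend-multiprocessing run AsyncAphrodite in-process;
# otherwise the engine lives in a spawned process behind run_rpc_server and
# the frontend talks to it through an AsyncEngineRPCClient over ZMQ IPC.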
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncAphrodite in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncAphrodite.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            # cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between Aphrodite runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and Aphrodite will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info(
            f"Multiprocessing frontend to use {rpc_path} for RPC Path.")

        # Start RPCServer in separate process (holds the AsyncAphrodite).
        context = multiprocessing.get_context("spawn")
        # the current process might have CUDA context,
        # so we need to spawn a new process
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, rpc_path))
        rpc_server_process.start()
        logger.info(
            f"Started engine process with PID {rpc_server_process.pid}")

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        async_engine_client = AsyncEngineRPCClient(rpc_path)

        try:
            while True:
                try:
                    await async_engine_client.setup()
                    break
                except TimeoutError as e:
                    if not rpc_server_process.is_alive():
                        raise RuntimeError(
                            "The server process died before "
                            "responding to the readiness probe") from e

            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)
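# Mounts the prometheus_client ASGI app at /metrics. When
# PROMETHEUS_MULTIPROC_DIR is set, a multiprocess CollectorRegistry is used so
# metrics written by other processes sharing that directory (e.g. the spawned
# engine process) are collected as well.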
def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info(f"Aphrodite to use {prometheus_multiproc_dir_path} "
                    "as PROMETHEUS_MULTIPROC_DIR")
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await async_engine_client.check_health()
    return Response(status_code=200)


@router.post("/v1/tokenize")
async def tokenize(request: TokenizeRequest):
    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/detokenize")
async def detokenize(request: DetokenizeRequest):
    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@router.get("/v1/models")
async def show_available_models():
    models = await openai_serving_completion.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": APHRODITE_VERSION}
    return JSONResponse(content=ver)
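# OpenAI-compatible inference routes. Serving-layer errors come back as
# ErrorResponse and are returned with their status code; streaming requests
# are answered with server-sent events via StreamingResponse, non-streaming
# requests with a single JSON body.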
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        assert isinstance(generator, ChatCompletionResponse)
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    generator = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


@router.post("/v1/lora/load")
async def load_lora(lora: LoRAModulePath):
    openai_serving_completion.add_lora(lora)
    if engine_args.enable_lora is False:
        logger.error("LoRA is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-lora flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/lora/unload")
async def unload_lora(lora_name: str):
    openai_serving_completion.remove_lora(lora_name)
    return JSONResponse(content={"status": "success"})


@router.post("/v1/soft_prompt/load")
async def load_soft_prompt(soft_prompt: PromptAdapterPath):
    openai_serving_completion.add_prompt_adapter(soft_prompt)
    if engine_args.enable_prompt_adapter is False:
        logger.error("Prompt Adapter is not enabled in the engine. "
                     "Please start the server with the "
                     "--enable-prompt-adapter flag!")
    return JSONResponse(content={"status": "success"})


@router.delete("/v1/soft_prompt/unload")
async def unload_soft_prompt(soft_prompt_name: str):
    openai_serving_completion.remove_prompt_adapter(soft_prompt_name)
    return JSONResponse(content={"status": "success"})


# ============ KoboldAI API ============ #
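# Collects the token ids used as default "bad words" bans for KoboldAI
# requests: either bad_words_ids from the HF config if present, or every
# vocab entry whose string form contains "[" or "]" (minus the pad token),
# plus the EOS token.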
def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""

    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    # if kai_payload.max_context_length > engine_args.max_model_len:
    #     raise ValueError(
    #         f"max_context_length ({kai_payload.max_context_length}) "
    #         "must be less than or equal to "
    #         f"max_model_len ({engine_args.max_model_len})")

    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0.0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens
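# Illustrative /api/v1/generate request body (values are made up; the full
# field set is defined by KAIGenerationInputSchema):
#   {"prompt": "Once upon a time", "genkey": "kai-example",
#    "max_context_length": 2048, "max_length": 120,
#    "temperature": 0.7, "rep_pen": 1.1, "top_p": 0.9}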
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    final_res: RequestOutput = None
    previous_output = ""
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = async_engine_client.generate(
        {
            "prompt": kai_payload.prompt,
            "prompt_token_ids": input_tokens,
        },
        sampling_params,
        kai_payload.genkey,
    )

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive",
                             },
                             media_type="text/event-stream")
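# Polling endpoints used by Kobold clients: /generate/check returns whatever
# partial text is currently cached in gen_cache for the given genkey, and
# /abort cancels the matching request on the engine.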
@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await async_engine_client.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: TokenizeRequest):
    """Tokenize string and return token count"""

    generator = await openai_serving_tokenization.create_tokenize(request)
    return JSONResponse({"value": generator.model_dump()["tokens"]})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    max_context_length = engine_args.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})


@router.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk
    on demand if needed."""
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "../kobold/klite.embd")
        klitepath = os.path.normpath(klitepath)  # Normalize the path
        if os.path.exists(klitepath):
            with open(klitepath, "r") as f:
                kobold_lite_ui = f.read()
        else:
            logger.error("Kobold Lite UI not found at " + klitepath)
    return HTMLResponse(content=kobold_lite_ui)
# ============ KoboldAI API ============ #
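# Assembles the FastAPI app: OpenAI routes, optional KoboldAI routes,
# /metrics, CORS, the optional API-key authentication middleware, and any
# user-supplied --middleware entries (a class is added via add_middleware,
# a coroutine function is registered as an HTTP middleware).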
def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path
    if args.launch_kobold_api:
        logger.warning("Launching Kobold API server in addition to OpenAI. "
                       "Keep in mind that the Kobold API routes are NOT "
                       "protected via the API key.")
        app.include_router(kai_api, prefix="/api/v1")
        app.include_router(kai_api,
                           prefix="/api/latest",
                           include_in_schema=False)
        app.include_router(extra_api, prefix="/api/extra")

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_completion.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key

        if admin_key is None:
            logger.warning("Admin key not provided. Admin operations will "
                           "be disabled.")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            excluded_paths = ["/api"]
            if any(
                    request.url.path.startswith(path)
                    for path in excluded_paths):
                return await call_next(request)
            if not request.url.path.startswith("/v1"):
                return await call_next(request)

            # Browsers may send OPTIONS requests to check CORS headers
            # before sending the actual request. We should allow these
            # requests to pass through without authentication.
            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
            if request.method == "OPTIONS":
                return await call_next(request)

            auth_header = request.headers.get("Authorization")
            api_key_header = request.headers.get("x-api-key")

            if request.url.path.startswith(("/v1/lora", "/v1/soft_prompt")):
                if admin_key is not None and api_key_header == admin_key:
                    return await call_next(request)
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            if auth_header != "Bearer " + token and api_key_header != token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)

            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    app = build_app(args)

    logger.debug(f"args: {args}")

    global served_model_names
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.uvloop:
        uvloop.install()

    global tokenizer

    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    tokenizer = get_tokenizer(
        tokenizer_name=engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=engine_args.trust_remote_code,
        revision=engine_args.revision,
    )

    if args.launch_kobold_api:
        _set_badwords(tokenizer, model_config.hf_config)

    return app
async def run_server(args, **uvicorn_kwargs) -> None:
    async with build_async_engine_client(args) as async_engine_client:
        app = await init_app(async_engine_client, args)

        shutdown_task = await serve_http(
            app,
            engine=async_engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task


if __name__ == "__main__":
    # NOTE:
    # This section should be in sync with aphrodite/endpoints/cli.py
    # for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="Aphrodite OpenAI-Compatible RESTful API Server")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    asyncio.run(run_server(args))
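# Typical invocation (the available flags are defined by make_arg_parser;
# model and port values below are placeholders):
#   python -m aphrodite.endpoints.openai.api_server --model <model-id> --port <port>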