1
0

backend_request_func.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
  13. AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
  14. @dataclass
  15. class RequestFuncInput:
  16. prompt: str
  17. api_url: str
  18. prompt_len: int
  19. output_len: int
  20. model: str
  21. best_of: int = 1
  22. use_beam_search: bool = False
  23. @dataclass
  24. class RequestFuncOutput:
  25. generated_text: str = ""
  26. success: bool = False
  27. latency: float = 0.0
  28. ttft: float = 0.0 # Time to first token
  29. itl: List[float] = field(
  30. default_factory=list) # List of inter-token latencies
  31. prompt_len: int = 0
  32. error: str = ""
  33. async def async_request_tgi(
  34. request_func_input: RequestFuncInput,
  35. pbar: Optional[tqdm] = None,
  36. ) -> RequestFuncOutput:
  37. api_url = request_func_input.api_url
  38. assert api_url.endswith("generate_stream")
  39. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  40. assert not request_func_input.use_beam_search
  41. params = {
  42. "best_of": request_func_input.best_of,
  43. "max_new_tokens": request_func_input.output_len,
  44. "do_sample": True,
  45. "temperature": 0.01, # TGI does not accept 0.0 temperature.
  46. "top_p": 0.99, # TGI does not accept 1.0 top_p.
  47. }
  48. payload = {
  49. "inputs": request_func_input.prompt,
  50. "parameters": params,
  51. }
  52. output = RequestFuncOutput()
  53. output.prompt_len = request_func_input.prompt_len
  54. ttft = 0.0
  55. st = time.perf_counter()
  56. most_recent_timestamp = st
  57. try:
  58. async with session.post(url=api_url, json=payload) as response:
  59. if response.status == 200:
  60. async for chunk_bytes in response.content:
  61. chunk_bytes = chunk_bytes.strip()
  62. if not chunk_bytes:
  63. continue
  64. chunk_bytes = chunk_bytes.decode("utf-8")
  65. #NOTE: Sometimes TGI returns a ping response without
  66. # any data, we should skip it.
  67. if chunk_bytes.startswith(":"):
  68. continue
  69. chunk = remove_prefix(chunk_bytes, "data:")
  70. data = json.loads(chunk)
  71. timestamp = time.perf_counter()
  72. # First token
  73. if ttft == 0.0:
  74. ttft = time.perf_counter() - st
  75. output.ttft = ttft
  76. # Decoding phase
  77. else:
  78. output.itl.append(timestamp -
  79. most_recent_timestamp)
  80. most_recent_timestamp = timestamp
  81. output.latency = most_recent_timestamp - st
  82. output.success = True
  83. output.generated_text = data["generated_text"]
  84. else:
  85. output.error = response.reason or ""
  86. output.success = False
  87. except Exception:
  88. output.success = False
  89. exc_info = sys.exc_info()
  90. output.error = "".join(traceback.format_exception(*exc_info))
  91. if pbar:
  92. pbar.update(1)
  93. return output
  94. async def async_request_trt_llm(
  95. request_func_input: RequestFuncInput,
  96. pbar: Optional[tqdm] = None,
  97. ) -> RequestFuncOutput:
  98. api_url = request_func_input.api_url
  99. assert api_url.endswith("generate_stream")
  100. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  101. assert not request_func_input.use_beam_search
  102. assert request_func_input.best_of == 1
  103. payload = {
  104. "accumulate_tokens": True,
  105. "text_input": request_func_input.prompt,
  106. "temperature": 0.0,
  107. "top_p": 1.0,
  108. "max_tokens": request_func_input.output_len,
  109. "stream": True,
  110. }
  111. output = RequestFuncOutput()
  112. output.prompt_len = request_func_input.prompt_len
  113. ttft = 0.0
  114. st = time.perf_counter()
  115. most_recent_timestamp = st
  116. try:
  117. async with session.post(url=api_url, json=payload) as response:
  118. if response.status == 200:
  119. async for chunk_bytes in response.content:
  120. chunk_bytes = chunk_bytes.strip()
  121. if not chunk_bytes:
  122. continue
  123. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  124. "data:")
  125. data = json.loads(chunk)
  126. output.generated_text += data["text_output"]
  127. timestamp = time.perf_counter()
  128. # First token
  129. if ttft == 0.0:
  130. ttft = time.perf_counter() - st
  131. output.ttft = ttft
  132. # Decoding phase
  133. else:
  134. output.itl.append(timestamp -
  135. most_recent_timestamp)
  136. most_recent_timestamp = timestamp
  137. output.latency = most_recent_timestamp - st
  138. output.success = True
  139. else:
  140. output.error = response.reason or ""
  141. output.success = False
  142. except Exception:
  143. output.success = False
  144. exc_info = sys.exc_info()
  145. output.error = "".join(traceback.format_exception(*exc_info))
  146. if pbar:
  147. pbar.update(1)
  148. return output
  149. async def async_request_deepspeed_mii(
  150. request_func_input: RequestFuncInput,
  151. pbar: Optional[tqdm] = None,
  152. ) -> RequestFuncOutput:
  153. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  154. assert request_func_input.best_of == 1
  155. assert not request_func_input.use_beam_search
  156. payload = {
  157. "prompt": request_func_input.prompt,
  158. "max_tokens": request_func_input.output_len,
  159. "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
  160. "top_p": 1.0,
  161. }
  162. output = RequestFuncOutput()
  163. output.prompt_len = request_func_input.prompt_len
  164. # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
  165. # will use 0 as placeholder.
  166. # See https://github.com/microsoft/DeepSpeed-MII/pull/311
  167. output.ttft = 0
  168. st = time.perf_counter()
  169. try:
  170. async with session.post(url=request_func_input.api_url,
  171. json=payload) as response:
  172. if response.status == 200:
  173. parsed_resp = await response.json()
  174. output.latency = time.perf_counter() - st
  175. output.generated_text = parsed_resp["text"][0]
  176. output.success = True
  177. else:
  178. output.error = response.reason or ""
  179. output.success = False
  180. except Exception:
  181. output.success = False
  182. exc_info = sys.exc_info()
  183. output.error = "".join(traceback.format_exception(*exc_info))
  184. if pbar:
  185. pbar.update(1)
  186. return output
  187. async def async_request_openai_completions(
  188. request_func_input: RequestFuncInput,
  189. pbar: Optional[tqdm] = None,
  190. ) -> RequestFuncOutput:
  191. api_url = request_func_input.api_url
  192. assert api_url.endswith(
  193. "completions"
  194. ), "OpenAI Completions API URL must end with 'completions'."
  195. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  196. assert not request_func_input.use_beam_search
  197. payload = {
  198. "model": request_func_input.model,
  199. "prompt": request_func_input.prompt,
  200. "temperature": 0.0,
  201. "best_of": request_func_input.best_of,
  202. "max_tokens": request_func_input.output_len,
  203. "stream": True,
  204. }
  205. headers = {
  206. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
  207. }
  208. output = RequestFuncOutput()
  209. output.prompt_len = request_func_input.prompt_len
  210. generated_text = ""
  211. ttft = 0.0
  212. st = time.perf_counter()
  213. most_recent_timestamp = st
  214. try:
  215. async with session.post(url=api_url, json=payload,
  216. headers=headers) as response:
  217. if response.status == 200:
  218. async for chunk_bytes in response.content:
  219. chunk_bytes = chunk_bytes.strip()
  220. if not chunk_bytes:
  221. continue
  222. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  223. "data: ")
  224. if chunk == "[DONE]":
  225. latency = time.perf_counter() - st
  226. else:
  227. data = json.loads(chunk)
  228. # NOTE: Some completion API might have a last
  229. # usage summary response without a token so we
  230. # want to check a token was generated
  231. if data["choices"][0]["text"]:
  232. timestamp = time.perf_counter()
  233. # First token
  234. if ttft == 0.0:
  235. ttft = time.perf_counter() - st
  236. output.ttft = ttft
  237. # Decoding phase
  238. else:
  239. output.itl.append(timestamp -
  240. most_recent_timestamp)
  241. most_recent_timestamp = timestamp
  242. generated_text += data["choices"][0]["text"]
  243. output.generated_text = generated_text
  244. output.success = True
  245. output.latency = latency
  246. else:
  247. output.error = response.reason or ""
  248. output.success = False
  249. except Exception:
  250. output.success = False
  251. exc_info = sys.exc_info()
  252. output.error = "".join(traceback.format_exception(*exc_info))
  253. if pbar:
  254. pbar.update(1)
  255. return output
  256. async def async_request_openai_chat_completions(
  257. request_func_input: RequestFuncInput,
  258. pbar: Optional[tqdm] = None,
  259. ) -> RequestFuncOutput:
  260. api_url = request_func_input.api_url
  261. assert api_url.endswith(
  262. "chat/completions"
  263. ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
  264. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  265. assert not request_func_input.use_beam_search
  266. payload = {
  267. "model": request_func_input.model,
  268. "messages": [
  269. {
  270. "role": "user",
  271. "content": request_func_input.prompt,
  272. },
  273. ],
  274. "temperature": 0.0,
  275. "max_tokens": request_func_input.output_len,
  276. "stream": True,
  277. }
  278. headers = {
  279. "Content-Type": "application/json",
  280. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
  281. }
  282. output = RequestFuncOutput()
  283. output.prompt_len = request_func_input.prompt_len
  284. generated_text = ""
  285. ttft = 0.0
  286. st = time.perf_counter()
  287. most_recent_timestamp = st
  288. try:
  289. async with session.post(url=api_url, json=payload,
  290. headers=headers) as response:
  291. if response.status == 200:
  292. async for chunk_bytes in response.content:
  293. chunk_bytes = chunk_bytes.strip()
  294. if not chunk_bytes:
  295. continue
  296. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  297. "data: ")
  298. if chunk == "[DONE]":
  299. latency = time.perf_counter() - st
  300. else:
  301. timestamp = time.perf_counter()
  302. data = json.loads(chunk)
  303. delta = data["choices"][0]["delta"]
  304. if delta.get("content", None):
  305. # First token
  306. if ttft == 0.0:
  307. ttft = time.perf_counter() - st
  308. output.ttft = ttft
  309. # Decoding phase
  310. else:
  311. output.itl.append(timestamp -
  312. most_recent_timestamp)
  313. generated_text += delta["content"]
  314. most_recent_timestamp = timestamp
  315. output.generated_text = generated_text
  316. output.success = True
  317. output.latency = latency
  318. else:
  319. output.error = response.reason or ""
  320. output.success = False
  321. except Exception:
  322. output.success = False
  323. exc_info = sys.exc_info()
  324. output.error = "".join(traceback.format_exception(*exc_info))
  325. if pbar:
  326. pbar.update(1)
  327. return output
  328. # Since aphrodite must support Python 3.8, we can't use str.removeprefix(prefix)
  329. # introduced in Python 3.9
  330. def remove_prefix(text: str, prefix: str) -> str:
  331. if text.startswith(prefix):
  332. return text[len(prefix):]
  333. return text
  334. def get_model(pretrained_model_name_or_path: str) -> str:
  335. if os.getenv('APHRODITE_USE_MODELSCOPE', 'False').lower() == 'true':
  336. from modelscope import snapshot_download
  337. model_path = snapshot_download(
  338. model_id=pretrained_model_name_or_path,
  339. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  340. ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
  341. return model_path
  342. return pretrained_model_name_or_path
  343. def get_tokenizer(
  344. pretrained_model_name_or_path: str, trust_remote_code: bool
  345. ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
  346. if pretrained_model_name_or_path is not None and not os.path.exists(
  347. pretrained_model_name_or_path):
  348. pretrained_model_name_or_path = get_model(
  349. pretrained_model_name_or_path)
  350. return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
  351. trust_remote_code=trust_remote_code)
  352. ASYNC_REQUEST_FUNCS = {
  353. "tgi": async_request_tgi,
  354. "aphrodite": async_request_openai_completions,
  355. "vllm": async_request_openai_completions,
  356. "lmdeploy": async_request_openai_completions,
  357. "deepspeed-mii": async_request_deepspeed_mii,
  358. "openai": async_request_openai_completions,
  359. "openai-chat": async_request_openai_chat_completions,
  360. "tensorrt-llm": async_request_trt_llm,
  361. "scalellm": async_request_openai_completions,
  362. }