1
0

backend_request_func.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. import json
  2. import os
  3. import sys
  4. import time
  5. import traceback
  6. from dataclasses import dataclass, field
  7. from typing import List, Optional, Union
  8. import aiohttp
  9. import huggingface_hub.constants
  10. from tqdm.asyncio import tqdm
  11. from transformers import (AutoTokenizer, PreTrainedTokenizer,
  12. PreTrainedTokenizerFast)
  13. AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
  14. @dataclass
  15. class RequestFuncInput:
  16. prompt: str
  17. api_url: str
  18. prompt_len: int
  19. output_len: int
  20. model: str
  21. best_of: int = 1
  22. use_beam_search: bool = False
  23. @dataclass
  24. class RequestFuncOutput:
  25. generated_text: str = ""
  26. success: bool = False
  27. latency: float = 0.0
  28. ttft: float = 0.0 # Time to first token
  29. itl: List[float] = field(
  30. default_factory=list) # List of inter-token latencies
  31. prompt_len: int = 0
  32. error: str = ""
  33. async def async_request_tgi(
  34. request_func_input: RequestFuncInput,
  35. pbar: Optional[tqdm] = None,
  36. ) -> RequestFuncOutput:
  37. api_url = request_func_input.api_url
  38. assert api_url.endswith("generate_stream")
  39. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  40. assert not request_func_input.use_beam_search
  41. params = {
  42. "best_of": request_func_input.best_of,
  43. "max_new_tokens": request_func_input.output_len,
  44. "do_sample": True,
  45. "temperature": 0.01, # TGI does not accept 0.0 temperature.
  46. "top_p": 0.99, # TGI does not accept 1.0 top_p.
  47. }
  48. payload = {
  49. "inputs": request_func_input.prompt,
  50. "parameters": params,
  51. }
  52. output = RequestFuncOutput()
  53. output.prompt_len = request_func_input.prompt_len
  54. ttft = 0.0
  55. st = time.perf_counter()
  56. most_recent_timestamp = st
  57. try:
  58. async with session.post(url=api_url, json=payload) as response:
  59. if response.status == 200:
  60. async for chunk_bytes in response.content:
  61. chunk_bytes = chunk_bytes.strip()
  62. if not chunk_bytes:
  63. continue
  64. chunk_bytes = chunk_bytes.decode("utf-8")
  65. #NOTE: Sometimes TGI returns a ping response without
  66. # any data, we should skip it.
  67. if chunk_bytes.startswith(":"):
  68. continue
  69. chunk = remove_prefix(chunk_bytes, "data:")
  70. data = json.loads(chunk)
  71. timestamp = time.perf_counter()
  72. # First token
  73. if ttft == 0.0:
  74. ttft = time.perf_counter() - st
  75. output.ttft = ttft
  76. # Decoding phase
  77. else:
  78. output.itl.append(timestamp -
  79. most_recent_timestamp)
  80. most_recent_timestamp = timestamp
  81. output.latency = most_recent_timestamp - st
  82. output.success = True
  83. output.generated_text = data["generated_text"]
  84. else:
  85. output.error = response.reason or ""
  86. output.success = False
  87. except Exception:
  88. output.success = False
  89. exc_info = sys.exc_info()
  90. output.error = "".join(traceback.format_exception(*exc_info))
  91. if pbar:
  92. pbar.update(1)
  93. return output
  94. async def async_request_trt_llm(
  95. request_func_input: RequestFuncInput,
  96. pbar: Optional[tqdm] = None,
  97. ) -> RequestFuncOutput:
  98. api_url = request_func_input.api_url
  99. assert api_url.endswith("generate_stream")
  100. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  101. assert not request_func_input.use_beam_search
  102. assert request_func_input.best_of == 1
  103. payload = {
  104. "accumulate_tokens": True,
  105. "text_input": request_func_input.prompt,
  106. "temperature": 0.0,
  107. "top_p": 1.0,
  108. "max_tokens": request_func_input.output_len,
  109. "stream": True,
  110. }
  111. output = RequestFuncOutput()
  112. output.prompt_len = request_func_input.prompt_len
  113. ttft = 0.0
  114. st = time.perf_counter()
  115. most_recent_timestamp = st
  116. try:
  117. async with session.post(url=api_url, json=payload) as response:
  118. if response.status == 200:
  119. async for chunk_bytes in response.content:
  120. chunk_bytes = chunk_bytes.strip()
  121. if not chunk_bytes:
  122. continue
  123. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  124. "data:")
  125. data = json.loads(chunk)
  126. output.generated_text += data["text_output"]
  127. timestamp = time.perf_counter()
  128. # First token
  129. if ttft == 0.0:
  130. ttft = time.perf_counter() - st
  131. output.ttft = ttft
  132. # Decoding phase
  133. else:
  134. output.itl.append(timestamp -
  135. most_recent_timestamp)
  136. most_recent_timestamp = timestamp
  137. output.latency = most_recent_timestamp - st
  138. output.success = True
  139. else:
  140. output.error = response.reason or ""
  141. output.success = False
  142. except Exception:
  143. output.success = False
  144. exc_info = sys.exc_info()
  145. output.error = "".join(traceback.format_exception(*exc_info))
  146. if pbar:
  147. pbar.update(1)
  148. return output
  149. async def async_request_deepspeed_mii(
  150. request_func_input: RequestFuncInput,
  151. pbar: Optional[tqdm] = None,
  152. ) -> RequestFuncOutput:
  153. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  154. assert request_func_input.best_of == 1
  155. assert not request_func_input.use_beam_search
  156. payload = {
  157. "prompt": request_func_input.prompt,
  158. "max_tokens": request_func_input.output_len,
  159. "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
  160. "top_p": 1.0,
  161. }
  162. output = RequestFuncOutput()
  163. output.prompt_len = request_func_input.prompt_len
  164. # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
  165. # will use 0 as placeholder.
  166. # See https://github.com/microsoft/DeepSpeed-MII/pull/311
  167. output.ttft = 0
  168. st = time.perf_counter()
  169. try:
  170. async with session.post(url=request_func_input.api_url,
  171. json=payload) as response:
  172. if response.status == 200:
  173. parsed_resp = await response.json()
  174. output.latency = time.perf_counter() - st
  175. output.generated_text = parsed_resp["text"][0]
  176. output.success = True
  177. else:
  178. output.error = response.reason or ""
  179. output.success = False
  180. except Exception:
  181. output.success = False
  182. exc_info = sys.exc_info()
  183. output.error = "".join(traceback.format_exception(*exc_info))
  184. if pbar:
  185. pbar.update(1)
  186. return output
  187. async def async_request_openai_completions(
  188. request_func_input: RequestFuncInput,
  189. pbar: Optional[tqdm] = None,
  190. ) -> RequestFuncOutput:
  191. api_url = request_func_input.api_url
  192. assert api_url.endswith(
  193. "completions"
  194. ), "OpenAI Completions API URL must end with 'completions'."
  195. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  196. assert not request_func_input.use_beam_search
  197. payload = {
  198. "model": request_func_input.model,
  199. "prompt": request_func_input.prompt,
  200. "temperature": 0.0,
  201. "best_of": request_func_input.best_of,
  202. "max_tokens": request_func_input.output_len,
  203. "stream": True,
  204. }
  205. headers = {
  206. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
  207. }
  208. output = RequestFuncOutput()
  209. output.prompt_len = request_func_input.prompt_len
  210. generated_text = ""
  211. ttft = 0.0
  212. st = time.perf_counter()
  213. most_recent_timestamp = st
  214. try:
  215. async with session.post(url=api_url, json=payload,
  216. headers=headers) as response:
  217. if response.status == 200:
  218. async for chunk_bytes in response.content:
  219. chunk_bytes = chunk_bytes.strip()
  220. if not chunk_bytes:
  221. continue
  222. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  223. "data: ")
  224. if chunk == "[DONE]":
  225. latency = time.perf_counter() - st
  226. else:
  227. data = json.loads(chunk)
  228. # NOTE: Some completion API might have a last
  229. # usage summary response without a token so we
  230. # want to check a token was generated
  231. if data["choices"][0]["text"]:
  232. timestamp = time.perf_counter()
  233. # First token
  234. if ttft == 0.0:
  235. ttft = time.perf_counter() - st
  236. output.ttft = ttft
  237. # Decoding phase
  238. output.itl.append(timestamp -
  239. most_recent_timestamp)
  240. most_recent_timestamp = timestamp
  241. generated_text += data["choices"][0]["text"]
  242. output.generated_text = generated_text
  243. output.success = True
  244. output.latency = latency
  245. else:
  246. output.error = response.reason or ""
  247. output.success = False
  248. except Exception:
  249. output.success = False
  250. exc_info = sys.exc_info()
  251. output.error = "".join(traceback.format_exception(*exc_info))
  252. if pbar:
  253. pbar.update(1)
  254. return output
  255. async def async_request_openai_chat_completions(
  256. request_func_input: RequestFuncInput,
  257. pbar: Optional[tqdm] = None,
  258. ) -> RequestFuncOutput:
  259. api_url = request_func_input.api_url
  260. assert api_url.endswith(
  261. "chat/completions"
  262. ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
  263. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  264. assert not request_func_input.use_beam_search
  265. payload = {
  266. "model": request_func_input.model,
  267. "messages": [
  268. {
  269. "role": "user",
  270. "content": request_func_input.prompt,
  271. },
  272. ],
  273. "temperature": 0.0,
  274. "max_tokens": request_func_input.output_len,
  275. "stream": True,
  276. }
  277. headers = {
  278. "Content-Type": "application/json",
  279. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
  280. }
  281. output = RequestFuncOutput()
  282. output.prompt_len = request_func_input.prompt_len
  283. generated_text = ""
  284. ttft = 0.0
  285. st = time.perf_counter()
  286. most_recent_timestamp = st
  287. try:
  288. async with session.post(url=api_url, json=payload,
  289. headers=headers) as response:
  290. if response.status == 200:
  291. async for chunk_bytes in response.content:
  292. chunk_bytes = chunk_bytes.strip()
  293. if not chunk_bytes:
  294. continue
  295. chunk = remove_prefix(chunk_bytes.decode("utf-8"),
  296. "data: ")
  297. if chunk == "[DONE]":
  298. latency = time.perf_counter() - st
  299. else:
  300. timestamp = time.perf_counter()
  301. data = json.loads(chunk)
  302. delta = data["choices"][0]["delta"]
  303. if delta.get("content", None):
  304. # First token
  305. if ttft == 0.0:
  306. ttft = time.perf_counter() - st
  307. output.ttft = ttft
  308. # Decoding phase
  309. else:
  310. output.itl.append(timestamp -
  311. most_recent_timestamp)
  312. generated_text += delta["content"]
  313. most_recent_timestamp = timestamp
  314. output.generated_text = generated_text
  315. output.success = True
  316. output.latency = latency
  317. else:
  318. output.error = response.reason or ""
  319. output.success = False
  320. except Exception:
  321. output.success = False
  322. exc_info = sys.exc_info()
  323. output.error = "".join(traceback.format_exception(*exc_info))
  324. if pbar:
  325. pbar.update(1)
  326. return output
  327. # Since aphrodite must support Python 3.8, we can't use str.removeprefix(prefix)
  328. # introduced in Python 3.9
  329. def remove_prefix(text: str, prefix: str) -> str:
  330. if text.startswith(prefix):
  331. return text[len(prefix):]
  332. return text
  333. def get_model(pretrained_model_name_or_path: str) -> str:
  334. if os.getenv('APHRODITE_USE_MODELSCOPE', 'False').lower() == 'true':
  335. from modelscope import snapshot_download
  336. model_path = snapshot_download(
  337. model_id=pretrained_model_name_or_path,
  338. local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
  339. ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
  340. return model_path
  341. return pretrained_model_name_or_path
  342. def get_tokenizer(
  343. pretrained_model_name_or_path: str, trust_remote_code: bool
  344. ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
  345. if pretrained_model_name_or_path is not None and not os.path.exists(
  346. pretrained_model_name_or_path):
  347. pretrained_model_name_or_path = get_model(
  348. pretrained_model_name_or_path)
  349. return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
  350. trust_remote_code=trust_remote_code)
  351. ASYNC_REQUEST_FUNCS = {
  352. "tgi": async_request_tgi,
  353. "aphrodite": async_request_openai_completions,
  354. "vllm": async_request_openai_completions,
  355. "lmdeploy": async_request_openai_completions,
  356. "deepspeed-mii": async_request_deepspeed_mii,
  357. "openai": async_request_openai_completions,
  358. "openai-chat": async_request_openai_chat_completions,
  359. "tensorrt-llm": async_request_trt_llm,
  360. "scalellm": async_request_openai_completions,
  361. }