# backend_request_func.py: async request functions for LLM serving backends.


import asyncio
import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm
  8. AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
  9. @dataclass
  10. class RequestFuncInput:
  11. prompt: str
  12. api_url: str
  13. prompt_len: int
  14. output_len: int
  15. model: str
  16. best_of: int = 1
  17. use_beam_search: bool = False
  18. @dataclass
  19. class RequestFuncOutput:
  20. generated_text: str = ""
  21. success: bool = False
  22. latency: float = 0
  23. ttft: float = 0
  24. prompt_len: int = 0
  25. async def async_request_tgi(
  26. request_func_input: RequestFuncInput,
  27. pbar: Optional[tqdm] = None,
  28. ) -> RequestFuncOutput:
  29. api_url = request_func_input.api_url
  30. assert api_url.endswith("generate_stream")
  31. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  32. assert not request_func_input.use_beam_search
  33. params = {
  34. "best_of": request_func_input.best_of,
  35. "max_new_tokens": request_func_input.output_len,
  36. "do_sample": True,
  37. "temperature": 0.01, # TGI does not accept 0.0 temperature.
  38. "top_p": 0.99, # TGI does not accept 1.0 top_p.
  39. }
  40. payload = {
  41. "inputs": request_func_input.prompt,
  42. "parameters": params,
  43. }
  44. output = RequestFuncOutput()
  45. output.prompt_len = request_func_input.prompt_len
  46. ttft = 0
  47. st = time.perf_counter()
  48. try:
  49. async with session.post(url=api_url, json=payload) as response:
  50. if response.status == 200:
  51. async for data in response.content.iter_any():
  52. if ttft == 0:
  53. ttft = time.perf_counter() - st
  54. output.ttft = ttft
  55. output.latency = time.perf_counter() - st
  56. body = data.decode("utf-8").lstrip("data:")
  57. output.generated_text = json.loads(body)["generated_text"]
  58. output.success = True
  59. else:
  60. output.success = False
  61. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  62. output.success = False
  63. if pbar:
  64. pbar.update(1)
  65. return output
  66. async def async_request_aphrodite(
  67. request_func_input: RequestFuncInput,
  68. pbar: Optional[tqdm] = None,
  69. ) -> RequestFuncOutput:
  70. api_url = request_func_input.api_url
  71. assert api_url.endswith("generate")
  72. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  73. payload = {
  74. "prompt": request_func_input.prompt,
  75. "n": 1,
  76. "best_of": request_func_input.best_of,
  77. "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
  78. "top_p": 1.0,
  79. "min_p": 0.06,
  80. "seed": 42,
  81. "max_tokens": request_func_input.output_len,
  82. "ignore_eos": True,
  83. "stream": True,
  84. }
  85. output = RequestFuncOutput()
  86. output.prompt_len = request_func_input.prompt_len
  87. ttft = 0
  88. st = time.perf_counter()
  89. try:
  90. async with session.post(url=api_url, json=payload) as response:
  91. if response.status == 200:
  92. async for data in response.content.iter_any():
  93. if ttft == 0:
  94. ttft = time.perf_counter() - st
  95. output.ttft = ttft
  96. output.latency = time.perf_counter() - st
  97. # When streaming, '\0' is appended to the end of the
  98. # response.
  99. body = data.decode("utf-8").strip("\0")
  100. output.generated_text = json.loads(
  101. body)["text"][0][len(request_func_input.prompt):]
  102. output.success = True
  103. else:
  104. output.success = False
  105. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  106. output.success = False
  107. if pbar:
  108. pbar.update(1)
  109. return output
  110. async def async_request_vllm(
  111. request_func_input: RequestFuncInput,
  112. pbar: Optional[tqdm] = None,
  113. ) -> RequestFuncOutput:
  114. api_url = request_func_input.api_url
  115. assert api_url.endswith("generate")
  116. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  117. payload = {
  118. "prompt": request_func_input.prompt,
  119. "n": 1,
  120. "best_of": request_func_input.best_of,
  121. "use_beam_search": request_func_input.use_beam_search,
  122. "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
  123. "top_p": 1.0,
  124. "max_tokens": request_func_input.output_len,
  125. "ignore_eos": True,
  126. "stream": True,
  127. }
  128. output = RequestFuncOutput()
  129. output.prompt_len = request_func_input.prompt_len
  130. ttft = 0
  131. st = time.perf_counter()
  132. try:
  133. async with session.post(url=api_url, json=payload) as response:
  134. if response.status == 200:
  135. async for data in response.content.iter_any():
  136. if ttft == 0:
  137. ttft = time.perf_counter() - st
  138. output.ttft = ttft
  139. output.latency = time.perf_counter() - st
  140. # When streaming, '\0' is appended to the end of the
  141. # response.
  142. body = data.decode("utf-8").strip("\0")
  143. output.generated_text = json.loads(
  144. body)["text"][0][len(request_func_input.prompt):]
  145. output.success = True
  146. else:
  147. output.success = False
  148. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  149. output.success = False
  150. if pbar:
  151. pbar.update(1)
  152. return output
  153. async def async_request_trt_llm(
  154. request_func_input: RequestFuncInput,
  155. pbar: Optional[tqdm] = None,
  156. ) -> RequestFuncOutput:
  157. api_url = request_func_input.api_url
  158. assert api_url.endswith("generate_stream")
  159. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  160. assert not request_func_input.use_beam_search
  161. assert request_func_input.best_of == 1
  162. payload = {
  163. "accumulate_tokens": True,
  164. "text_input": request_func_input.prompt,
  165. "temperature": 0.0,
  166. "top_p": 1.0,
  167. "max_tokens": request_func_input.output_len,
  168. "stream": True,
  169. }
  170. output = RequestFuncOutput()
  171. output.prompt_len = request_func_input.prompt_len
  172. ttft = 0
  173. st = time.perf_counter()
  174. try:
  175. async with session.post(url=api_url, json=payload) as resp:
  176. if resp.status == 200:
  177. async for data in resp.content.iter_any():
  178. if ttft == 0:
  179. ttft = time.perf_counter() - st
  180. output.ttft = ttft
  181. output.latency = time.perf_counter() - st
  182. body = data.decode("utf-8").lstrip("data:")
  183. output.generated_text = json.loads(body)["text_output"]
  184. output.success = True
  185. else:
  186. output.success = False
  187. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  188. output.success = False
  189. if pbar:
  190. pbar.update(1)
  191. return output
  192. async def async_request_deepspeed_mii(
  193. request_func_input: RequestFuncInput,
  194. pbar: Optional[tqdm] = None,
  195. ) -> RequestFuncOutput:
  196. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  197. assert request_func_input.best_of == 1
  198. assert not request_func_input.use_beam_search
  199. payload = {
  200. "prompts": request_func_input.prompt,
  201. "max_new_tokens": request_func_input.output_len,
  202. "ignore_eos": True,
  203. "do_sample": True,
  204. "temperature":
  205. 0.01, # deepspeed-mii does not accept 0.0 temperature.
  206. "top_p": 1.0,
  207. }
  208. output = RequestFuncOutput()
  209. output.prompt_len = request_func_input.prompt_len
  210. # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use
  211. # 0 as placeholder.
  212. # https://github.com/microsoft/DeepSpeed-MII/pull/311
  213. output.ttft = 0
  214. st = time.perf_counter()
  215. try:
  216. async with session.post(url=request_func_input.api_url,
  217. json=payload) as resp:
  218. if resp.status == 200:
  219. parsed_resp = await resp.json()
  220. output.latency = time.perf_counter() - st
  221. output.generated_text = parsed_resp[0]["generated_text"]
  222. output.success = True
  223. else:
  224. output.success = False
  225. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  226. output.success = False
  227. if pbar:
  228. pbar.update(1)
  229. return output
  230. async def async_request_openai_completions(
  231. request_func_input: RequestFuncInput,
  232. pbar: Optional[tqdm] = None,
  233. ) -> RequestFuncOutput:
  234. api_url = request_func_input.api_url
  235. assert api_url.endswith("v1/completions")
  236. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  237. assert not request_func_input.use_beam_search
  238. payload = {
  239. "model": request_func_input.model,
  240. "prompt": request_func_input.prompt,
  241. "temperature": 0.0,
  242. "best_of": request_func_input.best_of,
  243. "max_tokens": request_func_input.output_len,
  244. "stream": True,
  245. }
  246. headers = {
  247. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
  248. }
  249. output = RequestFuncOutput()
  250. output.prompt_len = request_func_input.prompt_len
  251. generated_text = ""
  252. ttft = 0
  253. st = time.perf_counter()
  254. try:
  255. async with session.post(url=api_url, json=payload,
  256. headers=headers) as response:
  257. if response.status == 200:
  258. async for chunk in response.content:
  259. if ttft == 0:
  260. ttft = time.perf_counter() - st
  261. output.ttft = ttft
  262. chunk = chunk.strip()
  263. if not chunk:
  264. continue
  265. chunk = chunk.decode("utf-8").lstrip("data: ")
  266. if chunk == "[DONE]":
  267. latency = time.perf_counter() - st
  268. else:
  269. body = json.loads(chunk)
  270. generated_text += body["choices"][0]["text"]
  271. output.generated_text = generated_text
  272. output.success = True
  273. output.latency = latency
  274. else:
  275. output.success = False
  276. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  277. output.success = False
  278. if pbar:
  279. pbar.update(1)
  280. return output
  281. ASYNC_REQUEST_FUNCS = {
  282. "tgi": async_request_tgi,
  283. "aphrodite": async_request_aphrodite,
  284. "vllm": async_request_vllm,
  285. "deepspeed-mii": async_request_deepspeed_mii,
  286. "openai": async_request_openai_completions,
  287. "tensorrt-llm": async_request_trt_llm,
  288. }