backend_request_func.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

# Very generous overall client timeout (6 hours) so that long benchmark runs
# are never cut short by the HTTP layer.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
    """Parameters describing one benchmark request to a serving backend."""

    prompt: str      # raw prompt text sent to the backend
    api_url: str     # full endpoint URL (each backend asserts its own suffix)
    prompt_len: int  # prompt length, copied into the output for bookkeeping
                     # (presumably a token count — verify against the caller)
    output_len: int  # number of new tokens to request (max_tokens /
                     # max_new_tokens in the backend payloads)
    model: str       # model name (used by the OpenAI-compatible endpoint)
    best_of: int = 1                # forwarded where the backend supports it
    use_beam_search: bool = False   # rejected (assert) by several backends
@dataclass
class RequestFuncOutput:
    """Measurements filled in by a request function for one request."""

    generated_text: str = ""  # text produced by the backend
    success: bool = False     # True only on HTTP 200 without connection errors
    latency: float = 0        # total request latency in seconds
    ttft: float = 0           # time to first streamed chunk, seconds
                              # (stays 0 for non-streaming backends)
    prompt_len: int = 0       # copied from RequestFuncInput.prompt_len
  25. async def async_request_tgi(
  26. request_func_input: RequestFuncInput,
  27. pbar: Optional[tqdm] = None,
  28. ) -> RequestFuncOutput:
  29. api_url = request_func_input.api_url
  30. assert api_url.endswith("generate_stream")
  31. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  32. assert not request_func_input.use_beam_search
  33. params = {
  34. "best_of": request_func_input.best_of,
  35. "max_new_tokens": request_func_input.output_len,
  36. "do_sample": True,
  37. "temperature": 0.01, # TGI does not accept 0.0 temperature.
  38. "top_p": 0.99, # TGI does not accept 1.0 top_p.
  39. }
  40. payload = {
  41. "inputs": request_func_input.prompt,
  42. "parameters": params,
  43. }
  44. output = RequestFuncOutput()
  45. output.prompt_len = request_func_input.prompt_len
  46. ttft = 0
  47. st = time.perf_counter()
  48. try:
  49. async with session.post(url=api_url, json=payload) as response:
  50. if response.status == 200:
  51. async for data in response.content.iter_any():
  52. if ttft == 0:
  53. ttft = time.perf_counter() - st
  54. output.ttft = ttft
  55. output.latency = time.perf_counter() - st
  56. body = data.decode("utf-8").lstrip("data:")
  57. output.generated_text = json.loads(body)["generated_text"]
  58. output.success = True
  59. else:
  60. output.success = False
  61. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  62. output.success = False
  63. if pbar:
  64. pbar.update(1)
  65. return output
  66. async def async_request_aphrodite(
  67. request_func_input: RequestFuncInput,
  68. pbar: Optional[tqdm] = None,
  69. ) -> RequestFuncOutput:
  70. api_url = request_func_input.api_url
  71. assert api_url.endswith("generate")
  72. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  73. payload = {
  74. "prompt": request_func_input.prompt,
  75. "n": 1,
  76. "best_of": request_func_input.best_of,
  77. "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
  78. "top_p": 1.0,
  79. "min_p": 0.06,
  80. "seed": 42,
  81. "max_tokens": request_func_input.output_len,
  82. "ignore_eos": True,
  83. "stream": True,
  84. }
  85. output = RequestFuncOutput()
  86. output.prompt_len = request_func_input.prompt_len
  87. ttft = 0
  88. st = time.perf_counter()
  89. try:
  90. async with session.post(url=api_url, json=payload) as response:
  91. if response.status == 200:
  92. async for data in response.content.iter_any():
  93. if ttft == 0:
  94. ttft = time.perf_counter() - st
  95. output.ttft = ttft
  96. output.latency = time.perf_counter() - st
  97. # When streaming, '\0' is appended to the end of the response.
  98. body = data.decode("utf-8").strip("\0")
  99. output.generated_text = json.loads(
  100. body)["text"][0][len(request_func_input.prompt):]
  101. output.success = True
  102. else:
  103. output.success = False
  104. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  105. output.success = False
  106. if pbar:
  107. pbar.update(1)
  108. return output
  109. async def async_request_vllm(
  110. request_func_input: RequestFuncInput,
  111. pbar: Optional[tqdm] = None,
  112. ) -> RequestFuncOutput:
  113. api_url = request_func_input.api_url
  114. assert api_url.endswith("generate")
  115. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  116. payload = {
  117. "prompt": request_func_input.prompt,
  118. "n": 1,
  119. "best_of": request_func_input.best_of,
  120. "use_beam_search": request_func_input.use_beam_search,
  121. "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
  122. "top_p": 1.0,
  123. "max_tokens": request_func_input.output_len,
  124. "ignore_eos": True,
  125. "stream": True,
  126. }
  127. output = RequestFuncOutput()
  128. output.prompt_len = request_func_input.prompt_len
  129. ttft = 0
  130. st = time.perf_counter()
  131. try:
  132. async with session.post(url=api_url, json=payload) as response:
  133. if response.status == 200:
  134. async for data in response.content.iter_any():
  135. if ttft == 0:
  136. ttft = time.perf_counter() - st
  137. output.ttft = ttft
  138. output.latency = time.perf_counter() - st
  139. # When streaming, '\0' is appended to the end of the response.
  140. body = data.decode("utf-8").strip("\0")
  141. output.generated_text = json.loads(
  142. body)["text"][0][len(request_func_input.prompt):]
  143. output.success = True
  144. else:
  145. output.success = False
  146. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  147. output.success = False
  148. if pbar:
  149. pbar.update(1)
  150. return output
  151. async def async_request_trt_llm(
  152. request_func_input: RequestFuncInput,
  153. pbar: Optional[tqdm] = None,
  154. ) -> RequestFuncOutput:
  155. api_url = request_func_input.api_url
  156. assert api_url.endswith("generate_stream")
  157. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  158. assert not request_func_input.use_beam_search
  159. assert request_func_input.best_of == 1
  160. payload = {
  161. "accumulate_tokens": True,
  162. "text_input": request_func_input.prompt,
  163. "temperature": 0.0,
  164. "top_p": 1.0,
  165. "max_tokens": request_func_input.output_len,
  166. "stream": True,
  167. }
  168. output = RequestFuncOutput()
  169. output.prompt_len = request_func_input.prompt_len
  170. ttft = 0
  171. st = time.perf_counter()
  172. try:
  173. async with session.post(url=api_url, json=payload) as resp:
  174. if resp.status == 200:
  175. async for data in resp.content.iter_any():
  176. if ttft == 0:
  177. ttft = time.perf_counter() - st
  178. output.ttft = ttft
  179. output.latency = time.perf_counter() - st
  180. body = data.decode("utf-8").lstrip("data:")
  181. output.generated_text = json.loads(body)["text_output"]
  182. output.success = True
  183. else:
  184. output.success = False
  185. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  186. output.success = False
  187. if pbar:
  188. pbar.update(1)
  189. return output
  190. async def async_request_deepspeed_mii(
  191. request_func_input: RequestFuncInput,
  192. pbar: Optional[tqdm] = None,
  193. ) -> RequestFuncOutput:
  194. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  195. assert request_func_input.best_of == 1
  196. assert not request_func_input.use_beam_search
  197. payload = {
  198. "prompts": request_func_input.prompt,
  199. "max_new_tokens": request_func_input.output_len,
  200. "ignore_eos": True,
  201. "do_sample": True,
  202. "temperature":
  203. 0.01, # deepspeed-mii does not accept 0.0 temperature.
  204. "top_p": 1.0,
  205. }
  206. output = RequestFuncOutput()
  207. output.prompt_len = request_func_input.prompt_len
  208. # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
  209. # https://github.com/microsoft/DeepSpeed-MII/pull/311
  210. output.ttft = 0
  211. st = time.perf_counter()
  212. try:
  213. async with session.post(url=request_func_input.api_url,
  214. json=payload) as resp:
  215. if resp.status == 200:
  216. parsed_resp = await resp.json()
  217. output.latency = time.perf_counter() - st
  218. output.generated_text = parsed_resp[0]["generated_text"]
  219. output.success = True
  220. else:
  221. output.success = False
  222. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  223. output.success = False
  224. if pbar:
  225. pbar.update(1)
  226. return output
  227. async def async_request_openai_completions(
  228. request_func_input: RequestFuncInput,
  229. pbar: Optional[tqdm] = None,
  230. ) -> RequestFuncOutput:
  231. api_url = request_func_input.api_url
  232. assert api_url.endswith("v1/completions")
  233. async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
  234. assert not request_func_input.use_beam_search
  235. payload = {
  236. "model": request_func_input.model,
  237. "prompt": request_func_input.prompt,
  238. "temperature": 0.0,
  239. "best_of": request_func_input.best_of,
  240. "max_tokens": request_func_input.output_len,
  241. "stream": True,
  242. }
  243. headers = {
  244. "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
  245. }
  246. output = RequestFuncOutput()
  247. output.prompt_len = request_func_input.prompt_len
  248. generated_text = ""
  249. ttft = 0
  250. st = time.perf_counter()
  251. try:
  252. async with session.post(url=api_url, json=payload,
  253. headers=headers) as response:
  254. if response.status == 200:
  255. async for chunk in response.content:
  256. if ttft == 0:
  257. ttft = time.perf_counter() - st
  258. output.ttft = ttft
  259. chunk = chunk.strip()
  260. if not chunk:
  261. continue
  262. chunk = chunk.decode("utf-8").lstrip("data: ")
  263. if chunk == "[DONE]":
  264. latency = time.perf_counter() - st
  265. else:
  266. body = json.loads(chunk)
  267. generated_text += body["choices"][0]["text"]
  268. output.generated_text = generated_text
  269. output.success = True
  270. output.latency = latency
  271. else:
  272. output.success = False
  273. except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
  274. output.success = False
  275. if pbar:
  276. pbar.update(1)
  277. return output
# Maps a backend identifier to the coroutine that drives that backend's
# HTTP API. All entries share the (RequestFuncInput, pbar) -> RequestFuncOutput
# signature.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "aphrodite": async_request_aphrodite,
    "vllm": async_request_vllm,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "tensorrt-llm": async_request_trt_llm,
}