api_client.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import argparse
  2. import json
  3. from typing import Iterable, List
  4. import requests
  5. def clear_line(n: int = 1) -> None:
  6. LINE_UP = '\033[1A'
  7. LINE_CLEAR = '\x1b[2K'
  8. for _ in range(n):
  9. print(LINE_UP, end=LINE_CLEAR, flush=True)
  10. def post_http_request(prompt: str, api_url: str, n: int = 1,
  11. stream: bool = False) -> requests.Response:
  12. headers = {"User-Agent": "Test Client"}
  13. pload = {
  14. "prompt": prompt,
  15. "n": n,
  16. "use_beam_search": True,
  17. "temperature": 0.0,
  18. "max_tokens": 28,
  19. "stream": stream,
  20. }
  21. response = requests.post(api_url, headers=headers, json=pload, stream=True)
  22. return response
  23. def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
  24. for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
  25. if chunk:
  26. data = json.loads(chunk.decode("utf-8"))
  27. output = data["text"]
  28. yield output
  29. def get_response(response: requests.Response) -> List[str]:
  30. data = json.loads(response.content)
  31. output = data["text"]
  32. return output
  33. if __name__ == "__main__":
  34. parser = argparse.ArgumentParser()
  35. parser.add_argument("--host", type=str, default="localhost")
  36. parser.add_argument("--port", type=int, default=8000)
  37. parser.add_argument("--n", type=int, default=4)
  38. parser.add_argument("--prompt", type=str, default="What is a man? A")
  39. parser.add_argument("--stream", action="store_true")
  40. args = parser.parse_args()
  41. prompt = args.prompt
  42. api_url = f"http://{args.host}:{args.port}/generate"
  43. n = args
  44. stream = args.stream
  45. print(f"Prompt: {prompt!r}\n", flush=True)
  46. response = post_http_request(prompt, api_url, n, stream)
  47. if stream:
  48. num_printed_lines = 0
  49. for h in get_streaming_response(response):
  50. clear_line(num_printed_lines)
  51. num_printed_lines = 0
  52. for i, line in enumerate(h):
  53. num_printed_lines += 1
  54. print(f"Beam candidate {i}: {line!r}", flush=True)
  55. else:
  56. output = get_response(response)
  57. for i, line in enumerate(output):
  58. print(f"Beam candidate {i}: {line!r}", flush=True)