123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import argparse
- import json
- from typing import Iterable, List
- import requests
- def clear_line(n: int = 1) -> None:
- LINE_UP = '\033[1A'
- LINE_CLEAR = '\x1b[2K'
- for _ in range(n):
- print(LINE_UP, end=LINE_CLEAR, flush=True)
- def post_http_request(prompt: str, api_url: str, n: int = 1,
- stream: bool = False) -> requests.Response:
- headers = {"User-Agent": "Test Client"}
- pload = {
- "prompt": prompt,
- "n": n,
- "use_beam_search": True,
- "temperature": 0.0,
- "max_tokens": 28,
- "stream": stream,
- }
- response = requests.post(api_url, headers=headers, json=pload, stream=True)
- return response
- def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
- for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
- if chunk:
- data = json.loads(chunk.decode("utf-8"))
- output = data["text"]
- yield output
- def get_response(response: requests.Response) -> List[str]:
- data = json.loads(response.content)
- output = data["text"]
- return output
- if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("--host", type=str, default="localhost")
- parser.add_argument("--port", type=int, default=8000)
- parser.add_argument("--n", type=int, default=4)
- parser.add_argument("--prompt", type=str, default="What is a man? A")
- parser.add_argument("--stream", action="store_true")
- args = parser.parse_args()
- prompt = args.prompt
- api_url = f"http://{args.host}:{args.port}/generate"
- n = args
- stream = args.stream
- print(f"Prompt: {prompt!r}\n", flush=True)
- response = post_http_request(prompt, api_url, n, stream)
- if stream:
- num_printed_lines = 0
- for h in get_streaming_response(response):
- clear_line(num_printed_lines)
- num_printed_lines = 0
- for i, line in enumerate(h):
- num_printed_lines += 1
- print(f"Beam candidate {i}: {line!r}", flush=True)
- else:
- output = get_response(response)
- for i, line in enumerate(output):
- print(f"Beam candidate {i}: {line!r}", flush=True)
|