test_openai_server.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. import os
  2. import subprocess
  3. import time
  4. import sys
  5. import pytest
  6. import requests
  7. import ray # using Ray for overall ease of process management, parallel requests, and debugging.
  8. import openai # use the official client for correctness check
  9. from huggingface_hub import snapshot_download # downloading lora to test lora requests
  10. # imports for guided decoding tests
  11. import json
  12. import jsonschema
  13. import re
  14. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  15. MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
  16. MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
  17. LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
  18. TEST_SCHEMA = {
  19. "type": "object",
  20. "properties": {
  21. "name": {
  22. "type": "string"
  23. },
  24. "age": {
  25. "type": "integer"
  26. },
  27. "skills": {
  28. "type": "array",
  29. "items": {
  30. "type": "string",
  31. "maxLength": 10
  32. },
  33. "minItems": 3
  34. },
  35. "work history": {
  36. "type": "array",
  37. "items": {
  38. "type": "object",
  39. "properties": {
  40. "company": {
  41. "type": "string"
  42. },
  43. "duration": {
  44. "type": "string"
  45. },
  46. "position": {
  47. "type": "string"
  48. }
  49. },
  50. "required": ["company", "position"]
  51. }
  52. }
  53. },
  54. "required": ["name", "age", "skills", "work history"]
  55. }
  56. TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
  57. r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
  58. TEST_CHOICE = [
  59. "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
  60. "Swift", "Kotlin"
  61. ]
  62. pytestmark = pytest.mark.asyncio
  63. @ray.remote(num_gpus=1)
  64. class ServerRunner:
  65. def __init__(self, args):
  66. env = os.environ.copy()
  67. env["PYTHONUNBUFFERED"] = "1"
  68. self.proc = subprocess.Popen(
  69. ["python3", "-m", "aphrodite.endpoints.openai.api_server"] + args,
  70. env=env,
  71. stdout=sys.stdout,
  72. stderr=sys.stderr,
  73. )
  74. self._wait_for_server()
  75. def ready(self):
  76. return True
  77. def _wait_for_server(self):
  78. # run health check
  79. start = time.time()
  80. while True:
  81. try:
  82. if requests.get(
  83. "http://localhost:2242/health").status_code == 200:
  84. break
  85. except Exception as err:
  86. if self.proc.poll() is not None:
  87. raise RuntimeError("Server exited unexpectedly.") from err
  88. time.sleep(0.5)
  89. if time.time() - start > MAX_SERVER_START_WAIT_S:
  90. raise RuntimeError(
  91. "Server failed to start in time.") from err
  92. def __del__(self):
  93. if hasattr(self, "proc"):
  94. self.proc.terminate()
  95. @pytest.fixture(scope="session")
  96. def zephyr_lora_files():
  97. return snapshot_download(repo_id=LORA_NAME)
  98. @pytest.fixture(scope="session")
  99. def server(zephyr_lora_files):
  100. ray.init()
  101. server_runner = ServerRunner.remote([
  102. "--model",
  103. MODEL_NAME,
  104. "--dtype",
  105. "bfloat16", # use half precision for speed and memory savings in CI environment
  106. "--max-model-len",
  107. "8192",
  108. "--enforce-eager",
  109. # lora config below
  110. "--enable-lora",
  111. "--lora-modules",
  112. f"zephyr-lora={zephyr_lora_files}",
  113. f"zephyr-lora2={zephyr_lora_files}",
  114. "--max-lora-rank",
  115. "64",
  116. "--max-cpu-loras",
  117. "2",
  118. "--max-num-seqs",
  119. "128"
  120. ])
  121. ray.get(server_runner.ready.remote())
  122. yield server_runner
  123. ray.shutdown()
  124. @pytest.fixture(scope="session")
  125. def client():
  126. client = openai.AsyncOpenAI(
  127. base_url="http://localhost:2242/v1",
  128. api_key="",
  129. )
  130. yield client
  131. async def test_check_models(server, client: openai.AsyncOpenAI):
  132. models = await client.models.list()
  133. models = models.data
  134. served_model = models[0]
  135. lora_models = models[1:]
  136. assert served_model.id == MODEL_NAME
  137. assert all(model.root == MODEL_NAME for model in models)
  138. assert lora_models[0].id == "zephyr-lora"
  139. assert lora_models[1].id == "zephyr-lora2"
  140. @pytest.mark.parametrize(
  141. # first test base model, then test loras
  142. "model_name",
  143. [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
  144. )
  145. async def test_single_completion(server, client: openai.AsyncOpenAI,
  146. model_name: str):
  147. completion = await client.completions.create(model=model_name,
  148. prompt="Hello, my name is",
  149. max_tokens=5,
  150. temperature=0.0)
  151. assert completion.id is not None
  152. assert completion.choices is not None and len(completion.choices) == 1
  153. assert completion.choices[0].text is not None and len(
  154. completion.choices[0].text) >= 5
  155. assert completion.choices[0].finish_reason == "length"
  156. assert completion.usage == openai.types.CompletionUsage(
  157. completion_tokens=5, prompt_tokens=6, total_tokens=11)
  158. # test using token IDs
  159. completion = await client.completions.create(
  160. model=MODEL_NAME,
  161. prompt=[0, 0, 0, 0, 0],
  162. max_tokens=5,
  163. temperature=0.0,
  164. )
  165. assert completion.choices[0].text is not None and len(
  166. completion.choices[0].text) >= 5
  167. @pytest.mark.parametrize(
  168. # just test 1 lora hereafter
  169. "model_name",
  170. [MODEL_NAME, "zephyr-lora"],
  171. )
  172. async def test_single_chat_session(server, client: openai.AsyncOpenAI,
  173. model_name: str):
  174. messages = [{
  175. "role": "system",
  176. "content": "you are a helpful assistant"
  177. }, {
  178. "role": "user",
  179. "content": "what is 1+1?"
  180. }]
  181. # test single completion
  182. chat_completion = await client.chat.completions.create(model=model_name,
  183. messages=messages,
  184. max_tokens=10,
  185. logprobs=True,
  186. top_logprobs=10)
  187. assert chat_completion.id is not None
  188. assert chat_completion.choices is not None and len(
  189. chat_completion.choices) == 1
  190. assert chat_completion.choices[0].message is not None
  191. assert chat_completion.choices[0].logprobs is not None
  192. assert chat_completion.choices[0].logprobs.top_logprobs is not None
  193. assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
  194. message = chat_completion.choices[0].message
  195. assert message.content is not None and len(message.content) >= 10
  196. assert message.role == "assistant"
  197. messages.append({"role": "assistant", "content": message.content})
  198. # test multi-turn dialogue
  199. messages.append({"role": "user", "content": "express your result in json"})
  200. chat_completion = await client.chat.completions.create(
  201. model=MODEL_NAME,
  202. messages=messages,
  203. max_tokens=10,
  204. )
  205. message = chat_completion.choices[0].message
  206. assert message.content is not None and len(message.content) >= 0
  207. @pytest.mark.parametrize(
  208. # just test 1 lora hereafter
  209. "model_name",
  210. [MODEL_NAME, "zephyr-lora"],
  211. )
  212. async def test_completion_streaming(server, client: openai.AsyncOpenAI,
  213. model_name: str):
  214. prompt = "What is an LLM?"
  215. single_completion = await client.completions.create(
  216. model=model_name,
  217. prompt=prompt,
  218. max_tokens=5,
  219. temperature=0.0,
  220. )
  221. single_output = single_completion.choices[0].text
  222. single_usage = single_completion.usage
  223. stream = await client.completions.create(model=model_name,
  224. prompt=prompt,
  225. max_tokens=5,
  226. temperature=0.0,
  227. stream=True)
  228. chunks = []
  229. async for chunk in stream:
  230. chunks.append(chunk.choices[0].text)
  231. assert chunk.choices[0].finish_reason == "length"
  232. assert chunk.usage == single_usage
  233. assert "".join(chunks) == single_output
  234. @pytest.mark.parametrize(
  235. # just test 1 lora hereafter
  236. "model_name",
  237. [MODEL_NAME, "zephyr-lora"],
  238. )
  239. async def test_chat_streaming(server, client: openai.AsyncOpenAI,
  240. model_name: str):
  241. messages = [{
  242. "role": "system",
  243. "content": "you are a helpful assistant"
  244. }, {
  245. "role": "user",
  246. "content": "what is 1+1?"
  247. }]
  248. # test single completion
  249. chat_completion = await client.chat.completions.create(
  250. model=model_name,
  251. messages=messages,
  252. max_tokens=10,
  253. temperature=0.0,
  254. )
  255. output = chat_completion.choices[0].message.content
  256. stop_reason = chat_completion.choices[0].finish_reason
  257. # test streaming
  258. stream = await client.chat.completions.create(
  259. model=model_name,
  260. messages=messages,
  261. max_tokens=10,
  262. temperature=0.0,
  263. stream=True,
  264. )
  265. chunks = []
  266. async for chunk in stream:
  267. delta = chunk.choices[0].delta
  268. if delta.role:
  269. assert delta.role == "assistant"
  270. if delta.content:
  271. chunks.append(delta.content)
  272. assert chunk.choices[0].finish_reason == stop_reason
  273. assert "".join(chunks) == output
  274. @pytest.mark.parametrize(
  275. # just test 1 lora hereafter
  276. "model_name",
  277. [MODEL_NAME, "zephyr-lora"],
  278. )
  279. async def test_batch_completions(server, client: openai.AsyncOpenAI,
  280. model_name: str):
  281. # test simple list
  282. batch = await client.completions.create(
  283. model=model_name,
  284. prompt=["Hello, my name is", "Hello, my name is"],
  285. max_tokens=5,
  286. temperature=0.0,
  287. )
  288. assert len(batch.choices) == 2
  289. assert batch.choices[0].text == batch.choices[1].text
  290. # test n = 2
  291. batch = await client.completions.create(
  292. model=model_name,
  293. prompt=["Hello, my name is", "Hello, my name is"],
  294. n=2,
  295. max_tokens=5,
  296. temperature=0.0,
  297. extra_body=dict(
  298. # NOTE: this has to be true for n > 1 in Aphrodite, but not necessary for official client.
  299. use_beam_search=True),
  300. )
  301. assert len(batch.choices) == 4
  302. assert batch.choices[0].text != batch.choices[
  303. 1].text, "beam search should be different"
  304. assert batch.choices[0].text == batch.choices[
  305. 2].text, "two copies of the same prompt should be the same"
  306. assert batch.choices[1].text == batch.choices[
  307. 3].text, "two copies of the same prompt should be the same"
  308. # test streaming
  309. batch = await client.completions.create(
  310. model=model_name,
  311. prompt=["Hello, my name is", "Hello, my name is"],
  312. max_tokens=5,
  313. temperature=0.0,
  314. stream=True,
  315. )
  316. texts = [""] * 2
  317. async for chunk in batch:
  318. assert len(chunk.choices) == 1
  319. choice = chunk.choices[0]
  320. texts[choice.index] += choice.text
  321. assert texts[0] == texts[1]
  322. async def test_logits_bias(server, client: openai.AsyncOpenAI):
  323. prompt = "Hello, my name is"
  324. max_tokens = 5
  325. tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
  326. # Test exclusive selection
  327. token_id = 1000
  328. completion = await client.completions.create(
  329. model=MODEL_NAME,
  330. prompt=prompt,
  331. max_tokens=max_tokens,
  332. temperature=0.0,
  333. logit_bias={str(token_id): 100},
  334. seed=42,
  335. )
  336. assert completion.choices[0].text is not None and len(
  337. completion.choices[0].text) >= 5
  338. response_tokens = tokenizer(completion.choices[0].text,
  339. add_special_tokens=False)["input_ids"]
  340. expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
  341. add_special_tokens=False)["input_ids"]
  342. assert all([
  343. response == expected
  344. for response, expected in zip(response_tokens, expected_tokens)
  345. ])
  346. # Test ban
  347. completion = await client.completions.create(
  348. model=MODEL_NAME,
  349. prompt=prompt,
  350. max_tokens=max_tokens,
  351. temperature=0.0,
  352. )
  353. response_tokens = tokenizer(completion.choices[0].text,
  354. add_special_tokens=False)["input_ids"]
  355. first_response = completion.choices[0].text
  356. completion = await client.completions.create(
  357. model=MODEL_NAME,
  358. prompt=prompt,
  359. max_tokens=max_tokens,
  360. temperature=0.0,
  361. logit_bias={str(token): -100
  362. for token in response_tokens},
  363. )
  364. assert first_response != completion.choices[0].text
  365. async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
  366. completion = await client.completions.create(
  367. model=MODEL_NAME,
  368. prompt=
  369. f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
  370. n=3,
  371. temperature=1.0,
  372. max_tokens=500,
  373. extra_body=dict(guided_json=TEST_SCHEMA))
  374. assert completion.id is not None
  375. assert completion.choices is not None and len(completion.choices) == 3
  376. for i in range(3):
  377. assert completion.choices[i].text is not None
  378. output_json = json.loads(completion.choices[i].text)
  379. jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
  380. async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
  381. messages = [{
  382. "role": "system",
  383. "content": "you are a helpful assistant"
  384. }, {
  385. "role": "user",
  386. "content": "Give an example JSON for an employee profile that " + \
  387. f"fits this schema: {TEST_SCHEMA}"
  388. }]
  389. chat_completion = await client.chat.completions.create(
  390. model=MODEL_NAME,
  391. messages=messages,
  392. max_tokens=500,
  393. extra_body=dict(guided_json=TEST_SCHEMA))
  394. message = chat_completion.choices[0].message
  395. assert message.content is not None
  396. json1 = json.loads(message.content)
  397. jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
  398. messages.append({"role": "assistant", "content": message.content})
  399. messages.append({
  400. "role":
  401. "user",
  402. "content":
  403. "Give me another one with a different name and age"
  404. })
  405. chat_completion = await client.chat.completions.create(
  406. model=MODEL_NAME,
  407. messages=messages,
  408. max_tokens=500,
  409. extra_body=dict(guided_json=TEST_SCHEMA))
  410. message = chat_completion.choices[0].message
  411. assert message.content is not None
  412. json2 = json.loads(message.content)
  413. jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
  414. assert json1["name"] != json2["name"]
  415. assert json1["age"] != json2["age"]
  416. async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
  417. completion = await client.completions.create(
  418. model=MODEL_NAME,
  419. prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
  420. n=3,
  421. temperature=1.0,
  422. max_tokens=20,
  423. extra_body=dict(guided_regex=TEST_REGEX))
  424. assert completion.id is not None
  425. assert completion.choices is not None and len(completion.choices) == 3
  426. for i in range(3):
  427. assert completion.choices[i].text is not None
  428. assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
  429. async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
  430. messages = [{
  431. "role": "system",
  432. "content": "you are a helpful assistant"
  433. }, {
  434. "role":
  435. "user",
  436. "content":
  437. f"Give an example IP address with this regex: {TEST_REGEX}"
  438. }]
  439. chat_completion = await client.chat.completions.create(
  440. model=MODEL_NAME,
  441. messages=messages,
  442. max_tokens=20,
  443. extra_body=dict(guided_regex=TEST_REGEX))
  444. ip1 = chat_completion.choices[0].message.content
  445. assert ip1 is not None
  446. assert re.fullmatch(TEST_REGEX, ip1) is not None
  447. messages.append({"role": "assistant", "content": ip1})
  448. messages.append({"role": "user", "content": "Give me a different one"})
  449. chat_completion = await client.chat.completions.create(
  450. model=MODEL_NAME,
  451. messages=messages,
  452. max_tokens=20,
  453. extra_body=dict(guided_regex=TEST_REGEX))
  454. ip2 = chat_completion.choices[0].message.content
  455. assert ip2 is not None
  456. assert re.fullmatch(TEST_REGEX, ip2) is not None
  457. assert ip1 != ip2
  458. async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
  459. completion = await client.completions.create(
  460. model=MODEL_NAME,
  461. prompt="The best language for type-safe systems programming is ",
  462. n=2,
  463. temperature=1.0,
  464. max_tokens=10,
  465. extra_body=dict(guided_choice=TEST_CHOICE))
  466. assert completion.id is not None
  467. assert completion.choices is not None and len(completion.choices) == 2
  468. for i in range(2):
  469. assert completion.choices[i].text in TEST_CHOICE
  470. async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
  471. messages = [{
  472. "role": "system",
  473. "content": "you are a helpful assistant"
  474. }, {
  475. "role":
  476. "user",
  477. "content":
  478. "The best language for type-safe systems programming is "
  479. }]
  480. chat_completion = await client.chat.completions.create(
  481. model=MODEL_NAME,
  482. messages=messages,
  483. max_tokens=10,
  484. extra_body=dict(guided_choice=TEST_CHOICE))
  485. choice1 = chat_completion.choices[0].message.content
  486. assert choice1 in TEST_CHOICE
  487. messages.append({"role": "assistant", "content": choice1})
  488. messages.append({
  489. "role": "user",
  490. "content": "I disagree, pick another one"
  491. })
  492. chat_completion = await client.chat.completions.create(
  493. model=MODEL_NAME,
  494. messages=messages,
  495. max_tokens=10,
  496. extra_body=dict(guided_choice=TEST_CHOICE))
  497. choice2 = chat_completion.choices[0].message.content
  498. assert choice2 in TEST_CHOICE
  499. assert choice1 != choice2
  500. async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
  501. with pytest.raises(openai.BadRequestError):
  502. _ = await client.completions.create(
  503. model=MODEL_NAME,
  504. prompt="Give an example JSON that fits this schema: 42",
  505. extra_body=dict(guided_json=42))
  506. messages = [{
  507. "role": "system",
  508. "content": "you are a helpful assistant"
  509. }, {
  510. "role":
  511. "user",
  512. "content":
  513. "The best language for type-safe systems programming is "
  514. }]
  515. with pytest.raises(openai.BadRequestError):
  516. _ = await client.chat.completions.create(model=MODEL_NAME,
  517. messages=messages,
  518. extra_body=dict(guided_regex={
  519. 1: "Python",
  520. 2: "C++"
  521. }))
  522. with pytest.raises(openai.BadRequestError):
  523. _ = await client.completions.create(
  524. model=MODEL_NAME,
  525. prompt="Give an example string that fits this regex",
  526. extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
  527. if __name__ == "__main__":
  528. pytest.main([__file__])