test_completion.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. # imports for guided decoding tests
  2. import json
  3. import re
  4. import shutil
  5. from tempfile import TemporaryDirectory
  6. from typing import Dict, List
  7. import jsonschema
  8. import openai # use the official client for correctness check
  9. import pytest
  10. # downloading lora to test lora requests
  11. from huggingface_hub import snapshot_download
  12. from openai import BadRequestError
  13. from transformers import AutoTokenizer
  14. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  15. from ...utils import RemoteOpenAIServer
# Base model served for every test in this module.
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# Repo ids handed to snapshot_download() by the adapter fixtures below.
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# Added to the expected prompt token counts when a prompt adapter is used
# (see test_single_completion).
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
  25. @pytest.fixture(scope="module")
  26. def zephyr_lora_files():
  27. return snapshot_download(repo_id=LORA_NAME)
  28. @pytest.fixture(scope="module")
  29. def zephyr_lora_added_tokens_files(zephyr_lora_files):
  30. tmp_dir = TemporaryDirectory()
  31. tmp_model_dir = f"{tmp_dir.name}/zephyr"
  32. shutil.copytree(zephyr_lora_files, tmp_model_dir)
  33. tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  34. # Copy tokenizer to adapter and add some unique tokens
  35. # 32000, 32001, 32002
  36. added = tokenizer.add_tokens(["aphrodite1", "aphrodite2", "aphrodite3"],
  37. special_tokens=True)
  38. assert added == 3
  39. tokenizer.save_pretrained(tmp_model_dir)
  40. yield tmp_model_dir
  41. tmp_dir.cleanup()
  42. @pytest.fixture(scope="module")
  43. def zephyr_pa_files():
  44. return snapshot_download(repo_id=PA_NAME)
  45. @pytest.fixture(scope="module")
  46. def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
  47. zephyr_pa_files):
  48. return [
  49. # use half precision for speed and memory savings in CI environment
  50. "--dtype",
  51. "bfloat16",
  52. "--max-model-len",
  53. "8192",
  54. "--max-num-seqs",
  55. "128",
  56. "--enforce-eager",
  57. # lora config
  58. "--enable-lora",
  59. "--lora-modules",
  60. f"zephyr-lora={zephyr_lora_files}",
  61. f"zephyr-lora2={zephyr_lora_added_tokens_files}",
  62. "--max-lora-rank",
  63. "64",
  64. "--max-cpu-loras",
  65. "2",
  66. # pa config
  67. "--enable-prompt-adapter",
  68. "--prompt-adapters",
  69. f"zephyr-pa={zephyr_pa_files}",
  70. f"zephyr-pa2={zephyr_pa_files}",
  71. "--max-prompt-adapters",
  72. "2",
  73. "--max-prompt-adapter-token",
  74. "128",
  75. ]
  76. @pytest.fixture(scope="module",
  77. params=["", "--disable-frontend-multiprocessing"])
  78. def client(default_server_args, request):
  79. if request.param:
  80. default_server_args.append(request.param)
  81. with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
  82. yield remote_server.get_async_client()
  83. @pytest.mark.asyncio
  84. @pytest.mark.parametrize(
  85. # first test base model, then test loras, then test prompt adapters
  86. "model_name,num_virtual_tokens",
  87. [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
  88. ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
  89. ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
  90. )
  91. async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
  92. num_virtual_tokens: int):
  93. completion = await client.completions.create(model=model_name,
  94. prompt="Hello, my name is",
  95. max_tokens=5,
  96. temperature=0.0)
  97. assert completion.id is not None
  98. assert completion.choices is not None and len(completion.choices) == 1
  99. choice = completion.choices[0]
  100. assert len(choice.text) >= 5
  101. assert choice.finish_reason == "length"
  102. assert completion.usage == openai.types.CompletionUsage(
  103. completion_tokens=5,
  104. prompt_tokens=6 + num_virtual_tokens,
  105. total_tokens=11 + num_virtual_tokens)
  106. # test using token IDs
  107. completion = await client.completions.create(
  108. model=model_name,
  109. prompt=[0, 0, 0, 0, 0],
  110. max_tokens=5,
  111. temperature=0.0,
  112. )
  113. assert len(completion.choices[0].text) >= 1
  114. assert completion.choices[0].prompt_logprobs is None
  115. @pytest.mark.asyncio
  116. async def test_added_lora_tokens(client: openai.AsyncOpenAI):
  117. # test using token IDs
  118. completion = await client.completions.create(
  119. model="zephyr-lora2",
  120. prompt=[0, 0, 32000, 32001, 32002],
  121. echo=True,
  122. max_tokens=5,
  123. temperature=0.0,
  124. )
  125. # Added tokens should appear in tokenized prompt
  126. assert completion.choices[0].text.startswith(
  127. "<unk><unk>aphrodite1aphrodite2aphrodite3")
  128. @pytest.mark.asyncio
  129. async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
  130. # test using token IDs
  131. completion = await client.completions.create(
  132. model=MODEL_NAME,
  133. prompt=[0, 0, 32000, 32001, 32002],
  134. echo=True,
  135. max_tokens=5,
  136. temperature=0.0,
  137. )
  138. # Added tokens should not appear in tokenized prompt
  139. assert "aphrodite" not in completion.choices[0].text
  140. @pytest.mark.asyncio
  141. @pytest.mark.parametrize(
  142. # first test base model, then test loras, then test prompt adapters
  143. "model_name",
  144. [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
  145. )
  146. async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
  147. # test using token IDs
  148. completion = await client.completions.create(
  149. model=model_name,
  150. prompt=[0, 0, 0, 0, 0],
  151. max_tokens=5,
  152. temperature=0.0,
  153. logprobs=None,
  154. )
  155. choice = completion.choices[0]
  156. assert choice.logprobs is None
  157. @pytest.mark.asyncio
  158. @pytest.mark.parametrize(
  159. # just test 1 lora and 1 pa hereafter
  160. "model_name",
  161. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  162. )
  163. async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
  164. # test using token IDs
  165. completion = await client.completions.create(
  166. model=model_name,
  167. prompt=[0, 0, 0, 0, 0],
  168. max_tokens=5,
  169. temperature=0.0,
  170. logprobs=0,
  171. )
  172. choice = completion.choices[0]
  173. assert choice.logprobs is not None
  174. assert choice.logprobs.token_logprobs is not None
  175. assert choice.logprobs.top_logprobs is not None
  176. assert len(choice.logprobs.top_logprobs[0]) == 1
  177. @pytest.mark.asyncio
  178. @pytest.mark.parametrize(
  179. "model_name",
  180. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  181. )
  182. async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
  183. # test using token IDs
  184. completion = await client.completions.create(
  185. model=model_name,
  186. prompt=[0, 0, 0, 0, 0],
  187. max_tokens=5,
  188. temperature=0.0,
  189. logprobs=5,
  190. )
  191. choice = completion.choices[0]
  192. assert choice.logprobs is not None
  193. assert choice.logprobs.token_logprobs is not None
  194. assert choice.logprobs.top_logprobs is not None
  195. assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
  196. @pytest.mark.asyncio
  197. @pytest.mark.parametrize(
  198. "model_name",
  199. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  200. )
  201. async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
  202. model_name: str):
  203. with pytest.raises(
  204. (openai.BadRequestError, openai.APIError)): # test using token IDs
  205. await client.completions.create(
  206. model=model_name,
  207. prompt=[0, 0, 0, 0, 0],
  208. max_tokens=5,
  209. temperature=0.0,
  210. # Aphrodite has higher default max_logprobs (20 instead of 5)
  211. # to support both Completion API and Chat Completion API
  212. logprobs=21,
  213. )
  214. ...
  215. with pytest.raises(
  216. (openai.BadRequestError, openai.APIError)): # test using token IDs
  217. stream = await client.completions.create(
  218. model=model_name,
  219. prompt=[0, 0, 0, 0, 0],
  220. max_tokens=5,
  221. temperature=0.0,
  222. # Aphrodite has higher default max_logprobs (20 instead of 5)
  223. # to support both Completion API and Chat Completion API
  224. logprobs=30,
  225. stream=True,
  226. )
  227. async for chunk in stream:
  228. ...
  229. # the server should still work afterwards
  230. completion = await client.completions.create(
  231. model=model_name,
  232. prompt=[0, 0, 0, 0, 0],
  233. max_tokens=5,
  234. temperature=0.0,
  235. )
  236. assert len(completion.choices[0].text) >= 0
  237. @pytest.mark.asyncio
  238. @pytest.mark.parametrize(
  239. "model_name, prompt_logprobs",
  240. [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
  241. )
  242. async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
  243. model_name: str, prompt_logprobs: int):
  244. params: Dict = {
  245. "messages": [{
  246. "role": "system",
  247. "content": "You are a helpful assistant."
  248. }, {
  249. "role": "user",
  250. "content": "Who won the world series in 2020?"
  251. }, {
  252. "role":
  253. "assistant",
  254. "content":
  255. "The Los Angeles Dodgers won the World Series in 2020."
  256. }, {
  257. "role": "user",
  258. "content": "Where was it played?"
  259. }],
  260. "model":
  261. model_name
  262. }
  263. if prompt_logprobs is not None:
  264. params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
  265. if prompt_logprobs and prompt_logprobs < 0:
  266. with pytest.raises(BadRequestError) as err_info:
  267. await client.chat.completions.create(**params)
  268. expected_err_string = (
  269. "Error code: 400 - {'object': 'error', 'message': "
  270. "'Prompt_logprobs set to invalid negative value: -1',"
  271. " 'type': 'BadRequestError', 'param': None, 'code': 400}")
  272. assert str(err_info.value) == expected_err_string
  273. else:
  274. completion = await client.chat.completions.create(**params)
  275. if prompt_logprobs and prompt_logprobs > 0:
  276. assert completion.prompt_logprobs is not None
  277. assert len(completion.prompt_logprobs) > 0
  278. else:
  279. assert completion.prompt_logprobs is None
  280. @pytest.mark.asyncio
  281. @pytest.mark.parametrize(
  282. "model_name",
  283. [MODEL_NAME],
  284. )
  285. async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
  286. model_name: str):
  287. params: Dict = {
  288. "messages": [{
  289. "role": "system",
  290. "content": "You are a helpful assistant."
  291. }, {
  292. "role": "user",
  293. "content": "Who won the world series in 2020?"
  294. }, {
  295. "role":
  296. "assistant",
  297. "content":
  298. "The Los Angeles Dodgers won the World Series in 2020."
  299. }, {
  300. "role": "user",
  301. "content": "Where was it played?"
  302. }],
  303. "model":
  304. model_name,
  305. "extra_body": {
  306. "prompt_logprobs": 1
  307. }
  308. }
  309. completion_1 = await client.chat.completions.create(**params)
  310. params["extra_body"] = {"prompt_logprobs": 2}
  311. completion_2 = await client.chat.completions.create(**params)
  312. assert len(completion_1.prompt_logprobs[3]) == 1
  313. assert len(completion_2.prompt_logprobs[3]) == 2
  314. @pytest.mark.asyncio
  315. @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
  316. (MODEL_NAME, 0),
  317. (MODEL_NAME, 1),
  318. (MODEL_NAME, None)])
  319. async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
  320. model_name: str,
  321. prompt_logprobs: int):
  322. params: Dict = {
  323. "prompt": ["A robot may not injure another robot", "My name is"],
  324. "model": model_name,
  325. }
  326. if prompt_logprobs is not None:
  327. params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
  328. if prompt_logprobs and prompt_logprobs < 0:
  329. with pytest.raises(BadRequestError) as err_info:
  330. await client.completions.create(**params)
  331. expected_err_string = (
  332. "Error code: 400 - {'object': 'error', 'message': "
  333. "'Prompt_logprobs set to invalid negative value: -1',"
  334. " 'type': 'BadRequestError', 'param': None, 'code': 400}")
  335. assert str(err_info.value) == expected_err_string
  336. else:
  337. completion = await client.completions.create(**params)
  338. if prompt_logprobs and prompt_logprobs > 0:
  339. assert completion.choices[0].prompt_logprobs is not None
  340. assert len(completion.choices[0].prompt_logprobs) > 0
  341. assert completion.choices[1].prompt_logprobs is not None
  342. assert len(completion.choices[1].prompt_logprobs) > 0
  343. else:
  344. assert completion.choices[0].prompt_logprobs is None
  345. @pytest.mark.asyncio
  346. @pytest.mark.parametrize(
  347. "model_name",
  348. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  349. )
  350. async def test_completion_streaming(client: openai.AsyncOpenAI,
  351. model_name: str):
  352. prompt = "What is an LLM?"
  353. single_completion = await client.completions.create(
  354. model=model_name,
  355. prompt=prompt,
  356. max_tokens=5,
  357. temperature=0.0,
  358. )
  359. single_output = single_completion.choices[0].text
  360. stream = await client.completions.create(model=model_name,
  361. prompt=prompt,
  362. max_tokens=5,
  363. temperature=0.0,
  364. stream=True)
  365. chunks: List[str] = []
  366. finish_reason_count = 0
  367. async for chunk in stream:
  368. chunks.append(chunk.choices[0].text)
  369. if chunk.choices[0].finish_reason is not None:
  370. finish_reason_count += 1
  371. # finish reason should only return in last block
  372. assert finish_reason_count == 1
  373. assert chunk.choices[0].finish_reason == "length"
  374. assert chunk.choices[0].text
  375. assert "".join(chunks) == single_output
  376. @pytest.mark.asyncio
  377. @pytest.mark.parametrize(
  378. "model_name",
  379. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  380. )
  381. async def test_completion_stream_options(client: openai.AsyncOpenAI,
  382. model_name: str):
  383. prompt = "What is the capital of France?"
  384. # Test stream=True, stream_options=
  385. # {"include_usage": False, "continuous_usage_stats": False}
  386. stream = await client.completions.create(model=model_name,
  387. prompt=prompt,
  388. max_tokens=5,
  389. temperature=0.0,
  390. stream=True,
  391. stream_options={
  392. "include_usage": False,
  393. "continuous_usage_stats":
  394. False,
  395. })
  396. async for chunk in stream:
  397. assert chunk.usage is None
  398. # Test stream=True, stream_options=
  399. # {"include_usage": False, "continuous_usage_stats": True}
  400. stream = await client.completions.create(model=model_name,
  401. prompt=prompt,
  402. max_tokens=5,
  403. temperature=0.0,
  404. stream=True,
  405. stream_options={
  406. "include_usage": False,
  407. "continuous_usage_stats":
  408. True,
  409. })
  410. async for chunk in stream:
  411. assert chunk.usage is None
  412. # Test stream=True, stream_options=
  413. # {"include_usage": True, "continuous_usage_stats": False}
  414. stream = await client.completions.create(model=model_name,
  415. prompt=prompt,
  416. max_tokens=5,
  417. temperature=0.0,
  418. stream=True,
  419. stream_options={
  420. "include_usage": True,
  421. "continuous_usage_stats":
  422. False,
  423. })
  424. async for chunk in stream:
  425. if chunk.choices[0].finish_reason is None:
  426. assert chunk.usage is None
  427. else:
  428. assert chunk.usage is None
  429. final_chunk = await stream.__anext__()
  430. assert final_chunk.usage is not None
  431. assert final_chunk.usage.prompt_tokens > 0
  432. assert final_chunk.usage.completion_tokens > 0
  433. assert final_chunk.usage.total_tokens == (
  434. final_chunk.usage.prompt_tokens +
  435. final_chunk.usage.completion_tokens)
  436. assert final_chunk.choices == []
  437. # Test stream=True, stream_options=
  438. # {"include_usage": True, "continuous_usage_stats": True}
  439. stream = await client.completions.create(model=model_name,
  440. prompt=prompt,
  441. max_tokens=5,
  442. temperature=0.0,
  443. stream=True,
  444. stream_options={
  445. "include_usage": True,
  446. "continuous_usage_stats":
  447. True,
  448. })
  449. async for chunk in stream:
  450. assert chunk.usage is not None
  451. assert chunk.usage.prompt_tokens > 0
  452. assert chunk.usage.completion_tokens > 0
  453. assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
  454. chunk.usage.completion_tokens)
  455. if chunk.choices[0].finish_reason is not None:
  456. final_chunk = await stream.__anext__()
  457. assert final_chunk.usage is not None
  458. assert final_chunk.usage.prompt_tokens > 0
  459. assert final_chunk.usage.completion_tokens > 0
  460. assert final_chunk.usage.total_tokens == (
  461. final_chunk.usage.prompt_tokens +
  462. final_chunk.usage.completion_tokens)
  463. assert final_chunk.choices == []
  464. # Test stream=False, stream_options=
  465. # {"include_usage": None}
  466. with pytest.raises(BadRequestError):
  467. await client.completions.create(model=model_name,
  468. prompt=prompt,
  469. max_tokens=5,
  470. temperature=0.0,
  471. stream=False,
  472. stream_options={"include_usage": None})
  473. # Test stream=False, stream_options=
  474. # {"include_usage": True}
  475. with pytest.raises(BadRequestError):
  476. await client.completions.create(model=model_name,
  477. prompt=prompt,
  478. max_tokens=5,
  479. temperature=0.0,
  480. stream=False,
  481. stream_options={"include_usage": True})
  482. # Test stream=False, stream_options=
  483. # {"continuous_usage_stats": None}
  484. with pytest.raises(BadRequestError):
  485. await client.completions.create(
  486. model=model_name,
  487. prompt=prompt,
  488. max_tokens=5,
  489. temperature=0.0,
  490. stream=False,
  491. stream_options={"continuous_usage_stats": None})
  492. # Test stream=False, stream_options=
  493. # {"continuous_usage_stats": True}
  494. with pytest.raises(BadRequestError):
  495. await client.completions.create(
  496. model=model_name,
  497. prompt=prompt,
  498. max_tokens=5,
  499. temperature=0.0,
  500. stream=False,
  501. stream_options={"continuous_usage_stats": True})
  502. @pytest.mark.asyncio
  503. @pytest.mark.parametrize(
  504. "model_name",
  505. [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
  506. )
  507. async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
  508. # test both text and token IDs
  509. for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
  510. # test simple list
  511. batch = await client.completions.create(
  512. model=model_name,
  513. prompt=prompts,
  514. max_tokens=5,
  515. temperature=0.0,
  516. )
  517. assert len(batch.choices) == 2
  518. assert batch.choices[0].text == batch.choices[1].text
  519. # test n = 2
  520. batch = await client.completions.create(
  521. model=model_name,
  522. prompt=prompts,
  523. n=2,
  524. max_tokens=5,
  525. temperature=0.0,
  526. extra_body=dict(
  527. # NOTE: this has to be true for n > 1 in Aphrodite, but
  528. # not necessary for official client.
  529. use_beam_search=True),
  530. )
  531. assert len(batch.choices) == 4
  532. assert batch.choices[0].text != batch.choices[
  533. 1].text, "beam search should be different"
  534. assert batch.choices[0].text == batch.choices[
  535. 2].text, "two copies of the same prompt should be the same"
  536. assert batch.choices[1].text == batch.choices[
  537. 3].text, "two copies of the same prompt should be the same"
  538. # test streaming
  539. batch = await client.completions.create(
  540. model=model_name,
  541. prompt=prompts,
  542. max_tokens=5,
  543. temperature=0.0,
  544. stream=True,
  545. )
  546. texts = [""] * 2
  547. async for chunk in batch:
  548. assert len(chunk.choices) == 1
  549. choice = chunk.choices[0]
  550. texts[choice.index] += choice.text
  551. assert texts[0] == texts[1]
  552. @pytest.mark.asyncio
  553. async def test_logits_bias(client: openai.AsyncOpenAI):
  554. prompt = "Hello, my name is"
  555. max_tokens = 5
  556. tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
  557. # Test exclusive selection
  558. token_id = 1000
  559. completion = await client.completions.create(
  560. model=MODEL_NAME,
  561. prompt=prompt,
  562. max_tokens=max_tokens,
  563. temperature=0.0,
  564. logit_bias={str(token_id): 100},
  565. seed=42,
  566. )
  567. assert len(completion.choices[0].text) >= 5
  568. response_tokens = tokenizer(completion.choices[0].text,
  569. add_special_tokens=False)["input_ids"]
  570. expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
  571. add_special_tokens=False)["input_ids"]
  572. assert all([
  573. response == expected
  574. for response, expected in zip(response_tokens, expected_tokens)
  575. ])
  576. # Test ban
  577. completion = await client.completions.create(
  578. model=MODEL_NAME,
  579. prompt=prompt,
  580. max_tokens=max_tokens,
  581. temperature=0.0,
  582. )
  583. response_tokens = tokenizer(completion.choices[0].text,
  584. add_special_tokens=False)["input_ids"]
  585. first_response = completion.choices[0].text
  586. completion = await client.completions.create(
  587. model=MODEL_NAME,
  588. prompt=prompt,
  589. max_tokens=max_tokens,
  590. temperature=0.0,
  591. logit_bias={str(token): -100
  592. for token in response_tokens},
  593. )
  594. assert first_response != completion.choices[0].text
  595. @pytest.mark.asyncio
  596. async def test_allowed_token_ids(client: openai.AsyncOpenAI):
  597. prompt = "Hello, my name is"
  598. max_tokens = 1
  599. tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
  600. # Test exclusive selection
  601. allowed_ids = [21555, 21557, 21558]
  602. completion = await client.completions.create(
  603. model=MODEL_NAME,
  604. prompt=prompt,
  605. max_tokens=max_tokens,
  606. temperature=0.0,
  607. seed=42,
  608. extra_body=dict(allowed_token_ids=allowed_ids),
  609. logprobs=1,
  610. )
  611. response_tokens = completion.choices[0].logprobs.tokens
  612. assert len(response_tokens) == 1
  613. assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
  614. @pytest.mark.asyncio
  615. @pytest.mark.parametrize("guided_decoding_backend",
  616. ["outlines", "lm-format-enforcer"])
  617. async def test_guided_json_completion(client: openai.AsyncOpenAI,
  618. guided_decoding_backend: str,
  619. sample_json_schema):
  620. completion = await client.completions.create(
  621. model=MODEL_NAME,
  622. prompt=f"Give an example JSON for an employee profile "
  623. f"that fits this schema: {sample_json_schema}",
  624. n=3,
  625. temperature=1.0,
  626. max_tokens=500,
  627. extra_body=dict(guided_json=sample_json_schema,
  628. guided_decoding_backend=guided_decoding_backend))
  629. assert completion.id is not None
  630. assert len(completion.choices) == 3
  631. for i in range(3):
  632. output_json = json.loads(completion.choices[i].text)
  633. jsonschema.validate(instance=output_json, schema=sample_json_schema)
  634. @pytest.mark.asyncio
  635. @pytest.mark.parametrize("guided_decoding_backend",
  636. ["outlines", "lm-format-enforcer"])
  637. async def test_guided_regex_completion(client: openai.AsyncOpenAI,
  638. guided_decoding_backend: str,
  639. sample_regex):
  640. completion = await client.completions.create(
  641. model=MODEL_NAME,
  642. prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
  643. n=3,
  644. temperature=1.0,
  645. max_tokens=20,
  646. extra_body=dict(guided_regex=sample_regex,
  647. guided_decoding_backend=guided_decoding_backend))
  648. assert completion.id is not None
  649. assert len(completion.choices) == 3
  650. for i in range(3):
  651. assert re.fullmatch(sample_regex,
  652. completion.choices[i].text) is not None
  653. @pytest.mark.asyncio
  654. @pytest.mark.parametrize("guided_decoding_backend",
  655. ["outlines", "lm-format-enforcer"])
  656. async def test_guided_choice_completion(client: openai.AsyncOpenAI,
  657. guided_decoding_backend: str,
  658. sample_guided_choice):
  659. completion = await client.completions.create(
  660. model=MODEL_NAME,
  661. prompt="The best language for type-safe systems programming is ",
  662. n=2,
  663. temperature=1.0,
  664. max_tokens=10,
  665. extra_body=dict(guided_choice=sample_guided_choice,
  666. guided_decoding_backend=guided_decoding_backend))
  667. assert completion.id is not None
  668. assert len(completion.choices) == 2
  669. for i in range(2):
  670. assert completion.choices[i].text in sample_guided_choice
  671. @pytest.mark.asyncio
  672. async def test_guided_grammar(client: openai.AsyncOpenAI,
  673. sample_sql_statements):
  674. completion = await client.completions.create(
  675. model=MODEL_NAME,
  676. prompt=("Generate a sql state that select col_1 from "
  677. "table_1 where it is equals to 1"),
  678. temperature=1.0,
  679. max_tokens=500,
  680. extra_body=dict(guided_grammar=sample_sql_statements))
  681. content = completion.choices[0].text
  682. # use Lark to parse the output, and make sure it's a valid parse tree
  683. from lark import Lark
  684. parser = Lark(sample_sql_statements)
  685. parser.parse(content)
  686. # remove spaces for comparison b/c we removed them in the grammar
  687. ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
  688. assert content.strip() == ground_truth
  689. @pytest.mark.asyncio
  690. @pytest.mark.parametrize(
  691. # first test base model, then test loras
  692. "model_name",
  693. [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
  694. )
  695. @pytest.mark.parametrize("logprobs_arg", [1, 0])
  696. async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
  697. model_name: str, logprobs_arg: int):
  698. tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
  699. # test using text and token IDs
  700. for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
  701. completion = await client.completions.create(model=model_name,
  702. prompt=prompt,
  703. max_tokens=5,
  704. temperature=0.0,
  705. echo=True,
  706. logprobs=logprobs_arg)
  707. prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
  708. list) else prompt
  709. assert re.search(r"^" + prompt_text, completion.choices[0].text)
  710. logprobs = completion.choices[0].logprobs
  711. assert logprobs is not None
  712. assert len(logprobs.text_offset) > 5
  713. assert (len(logprobs.token_logprobs) > 5
  714. and logprobs.token_logprobs[0] is None)
  715. assert (len(logprobs.top_logprobs) > 5
  716. and logprobs.top_logprobs[0] is None)
  717. for top_logprobs in logprobs.top_logprobs[1:]:
  718. assert max(logprobs_arg,
  719. 1) <= len(top_logprobs) <= logprobs_arg + 1
  720. assert len(logprobs.tokens) > 5
  721. @pytest.mark.asyncio
  722. @pytest.mark.parametrize("guided_decoding_backend",
  723. ["outlines", "lm-format-enforcer"])
  724. async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
  725. guided_decoding_backend: str,
  726. sample_json_schema, sample_regex):
  727. with pytest.raises(openai.BadRequestError):
  728. _ = await client.completions.create(
  729. model=MODEL_NAME,
  730. prompt="Give an example JSON that fits this schema: 42",
  731. extra_body=dict(guided_json=42,
  732. guided_decoding_backend=guided_decoding_backend))
  733. with pytest.raises(openai.BadRequestError):
  734. _ = await client.completions.create(
  735. model=MODEL_NAME,
  736. prompt="Give an example string that fits this regex",
  737. extra_body=dict(guided_regex=sample_regex,
  738. guided_json=sample_json_schema))