1
0

test_guided_generate.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import json
  2. import re
  3. import weakref
  4. import jsonschema
  5. import pytest
  6. from aphrodite.common.outputs import RequestOutput
  7. from aphrodite.common.sampling_params import SamplingParams
  8. from aphrodite.endpoints.llm import LLM
  9. from ...conftest import cleanup
  10. MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
  11. @pytest.fixture(scope="module")
  12. def llm():
  13. # pytest caches the fixture so we use weakref.proxy to
  14. # enable garbage collection
  15. llm = LLM(model=MODEL_NAME, max_model_len=1024)
  16. with llm.deprecate_legacy_api():
  17. yield weakref.proxy(llm)
  18. del llm
  19. cleanup()
  20. @pytest.mark.skip_global_cleanup
  21. def test_guided_regex(sample_regex, llm):
  22. sampling_params = SamplingParams(
  23. temperature=0.8,
  24. top_p=0.95,
  25. )
  26. outputs = llm.generate(
  27. prompts=[
  28. f"Give an example IPv4 address with this regex: {sample_regex}"
  29. ] * 2,
  30. sampling_params=sampling_params,
  31. use_tqdm=True,
  32. guided_options_request=dict(guided_regex=sample_regex))
  33. assert outputs is not None
  34. for output in outputs:
  35. assert output is not None
  36. assert isinstance(output, RequestOutput)
  37. prompt = output.prompt
  38. generated_text = output.outputs[0].text
  39. print(generated_text)
  40. assert generated_text is not None
  41. assert re.fullmatch(sample_regex, generated_text) is not None
  42. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
  43. @pytest.mark.skip_global_cleanup
  44. def test_guided_json_completion(sample_json_schema, llm):
  45. sampling_params = SamplingParams(
  46. temperature=1.0,
  47. max_tokens=1000,
  48. )
  49. outputs = llm.generate(
  50. prompts=[
  51. f"Give an example JSON for an employee profile "
  52. f"that fits this schema: {sample_json_schema}"
  53. ] * 2,
  54. sampling_params=sampling_params,
  55. use_tqdm=True,
  56. guided_options_request=dict(guided_json=sample_json_schema))
  57. assert outputs is not None
  58. for output in outputs:
  59. assert output is not None
  60. assert isinstance(output, RequestOutput)
  61. prompt = output.prompt
  62. generated_text = output.outputs[0].text
  63. assert generated_text is not None
  64. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
  65. output_json = json.loads(generated_text)
  66. jsonschema.validate(instance=output_json, schema=sample_json_schema)
  67. @pytest.mark.skip_global_cleanup
  68. def test_guided_choice_completion(sample_guided_choice, llm):
  69. sampling_params = SamplingParams(
  70. temperature=0.8,
  71. top_p=0.95,
  72. )
  73. outputs = llm.generate(
  74. prompts="The best language for type-safe systems programming is ",
  75. sampling_params=sampling_params,
  76. use_tqdm=True,
  77. guided_options_request=dict(guided_choice=sample_guided_choice))
  78. assert outputs is not None
  79. for output in outputs:
  80. assert output is not None
  81. assert isinstance(output, RequestOutput)
  82. prompt = output.prompt
  83. generated_text = output.outputs[0].text
  84. print(generated_text)
  85. assert generated_text is not None
  86. assert generated_text in sample_guided_choice
  87. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
  88. @pytest.mark.skip_global_cleanup
  89. def test_guided_grammar(sample_sql_statements, llm):
  90. sampling_params = SamplingParams(
  91. temperature=0.8,
  92. top_p=0.95,
  93. max_tokens=1000,
  94. )
  95. outputs = llm.generate(
  96. prompts=("Generate a sql state that select col_1 from "
  97. "table_1 where it is equals to 1"),
  98. sampling_params=sampling_params,
  99. use_tqdm=True,
  100. guided_options_request=dict(guided_grammar=sample_sql_statements))
  101. assert outputs is not None
  102. for output in outputs:
  103. assert output is not None
  104. assert isinstance(output, RequestOutput)
  105. prompt = output.prompt
  106. generated_text = output.outputs[0].text
  107. assert generated_text is not None
  108. # use Lark to parse the output, and make sure it's a valid parse tree
  109. from lark import Lark
  110. parser = Lark(sample_sql_statements)
  111. parser.parse(generated_text)
  112. # remove spaces for comparison b/c we removed them in the grammar
  113. ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
  114. " ", "")
  115. assert generated_text.strip() == ground_truth
  116. print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")