import json
import re
import weakref

import jsonschema
import pytest

from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.llm import LLM

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture, so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
        del llm

    cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_regex=sample_regex))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=[
            f"Give an example JSON for an employee profile "
            f"that fits this schema: {sample_json_schema}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_json=sample_json_schema))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_choice=sample_guided_choice))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=("Generate a SQL statement that selects col_1 from "
                 "table_1 where it is equal to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Use Lark to parse the output and make sure it is a valid parse tree.
        from lark import Lark

        parser = Lark(sample_sql_statements)
        parser.parse(generated_text)
        # Remove spaces for comparison because the grammar removes them.
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
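
# For reference, the ``sample_*`` fixtures consumed above come from the shared
# ``conftest.py`` and are not shown in this file. A minimal sketch of the
# value shapes they provide (illustrative assumptions only -- the real
# definitions may differ; kept commented out so they do not shadow the
# conftest fixtures):
#
# @pytest.fixture
# def sample_regex():
#     # guided_regex takes a regex string (here, an IPv4-shaped pattern).
#     return r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"
#
# @pytest.fixture
# def sample_guided_choice():
#     # guided_choice takes a list of allowed completion strings.
#     return ["Python", "Rust", "Java", "Go"]
#
# Likewise, ``sample_json_schema`` is a dict holding a JSON schema (validated
# against the output with jsonschema above), and ``sample_sql_statements`` is
# a Lark grammar string that test_guided_grammar also uses to parse the
# model's output.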