123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- from argparse import Namespace
- from dataclasses import dataclass
- import pytest
- from fastapi.testclient import TestClient
- from aphrodite.endpoints.openai.api_server import *
- # Define models, templates, and their corresponding expected outputs
- MODEL_TEMPLATE_GENERATON_OUTPUT = [
- ("EleutherAI/pythia-70m-deduped", None, True,
- "Hello</s>Hi there!</s>What is the capital of</s>"),
- ("EleutherAI/pythia-70m-deduped", None, False,
- "Hello</s>Hi there!</s>What is the capital of</s>"),
- ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
- True, """<|im_start|>user
- Hello<|im_end|>
- <|im_start|>assistant
- Hi there!<|im_end|>
- <|im_start|>user
- What is the capital of<|im_end|>
- <|im_start|>assistant
- """),
- ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
- False, """<|im_start|>user
- Hello<|im_end|>
- <|im_start|>assistant
- Hi there!<|im_end|>
- <|im_start|>user
- What is the capital of""")
- ]
- TEST_MESSAGES = [
- {
- 'role': 'user',
- 'content': 'Hello'
- },
- {
- 'role': 'assistant',
- 'content': 'Hi there!'
- },
- {
- 'role': 'user',
- 'content': 'What is the capital of'
- },
- ]
- client = TestClient(app)
- @dataclass
- class MockTokenizer:
- chat_template = None
- def test_load_chat_template():
- # Testing chatml template
- template = "../../examples/chatml_template.jinja"
- mock_args = Namespace(chat_template=template)
- tokenizer = MockTokenizer()
- # Call the function with the mocked args
- load_chat_template(mock_args, tokenizer)
- template_content = tokenizer.chat_template
- # Test assertions
- assert template_content is not None
- # Hard coded value for chatml_template.jinja
- assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
- {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
- def test_no_load_chat_template():
- # Testing chatml template
- template = "../../examples/does_not_exist"
- mock_args = Namespace(chat_template=template)
- tokenizer = MockTokenizer()
- # Call the function with the mocked args
- load_chat_template(mock_args, tokenizer=tokenizer)
- template_content = tokenizer.chat_template
- # Test assertions
- assert template_content is not None
- # Hard coded value for chatml_template.jinja
- assert template_content == """../../examples/does_not_exist"""
- @pytest.mark.asyncio
- @pytest.mark.parametrize(
- "model,template,add_generation_prompt,expected_output",
- MODEL_TEMPLATE_GENERATON_OUTPUT)
- async def test_get_gen_prompt(model, template, add_generation_prompt,
- expected_output):
- # Initialize the tokenizer
- tokenizer = get_tokenizer(tokenizer_name=model)
- mock_args = Namespace(chat_template=template)
- load_chat_template(mock_args, tokenizer)
- # Create a mock request object using keyword arguments
- mock_request = ChatCompletionRequest(
- model=model,
- messages=TEST_MESSAGES,
- add_generation_prompt=add_generation_prompt)
- # Call the function and get the result
- result = tokenizer.apply_chat_template(
- conversation=mock_request.messages,
- tokenize=False,
- add_generation_prompt=mock_request.add_generation_prompt)
- # Test assertion
- assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
- def test_health_endpoint():
- response = client.get("/health")
- assert response.status_code == 200
|