# test_openai_server.py
  1. from argparse import Namespace
  2. from dataclasses import dataclass
  3. import pytest
  4. from fastapi.testclient import TestClient
  5. from aphrodite.endpoints.openai.api_server import *
  6. # Define models, templates, and their corresponding expected outputs
  7. MODEL_TEMPLATE_GENERATON_OUTPUT = [
  8. ("EleutherAI/pythia-70m-deduped", None, True,
  9. "Hello</s>Hi there!</s>What is the capital of</s>"),
  10. ("EleutherAI/pythia-70m-deduped", None, False,
  11. "Hello</s>Hi there!</s>What is the capital of</s>"),
  12. ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
  13. True, """<|im_start|>user
  14. Hello<|im_end|>
  15. <|im_start|>assistant
  16. Hi there!<|im_end|>
  17. <|im_start|>user
  18. What is the capital of<|im_end|>
  19. <|im_start|>assistant
  20. """),
  21. ("EleutherAI/pythia-70m-deduped", "../../examples/template_chatml.jinja",
  22. False, """<|im_start|>user
  23. Hello<|im_end|>
  24. <|im_start|>assistant
  25. Hi there!<|im_end|>
  26. <|im_start|>user
  27. What is the capital of""")
  28. ]
  29. TEST_MESSAGES = [
  30. {
  31. 'role': 'user',
  32. 'content': 'Hello'
  33. },
  34. {
  35. 'role': 'assistant',
  36. 'content': 'Hi there!'
  37. },
  38. {
  39. 'role': 'user',
  40. 'content': 'What is the capital of'
  41. },
  42. ]
  43. client = TestClient(app)
  44. @dataclass
  45. class MockTokenizer:
  46. chat_template = None
  47. def test_load_chat_template():
  48. # Testing chatml template
  49. template = "../../examples/chatml_template.jinja"
  50. mock_args = Namespace(chat_template=template)
  51. tokenizer = MockTokenizer()
  52. # Call the function with the mocked args
  53. load_chat_template(mock_args, tokenizer)
  54. template_content = tokenizer.chat_template
  55. # Test assertions
  56. assert template_content is not None
  57. # Hard coded value for chatml_template.jinja
  58. assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
  59. {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
  60. def test_no_load_chat_template():
  61. # Testing chatml template
  62. template = "../../examples/does_not_exist"
  63. mock_args = Namespace(chat_template=template)
  64. tokenizer = MockTokenizer()
  65. # Call the function with the mocked args
  66. load_chat_template(mock_args, tokenizer=tokenizer)
  67. template_content = tokenizer.chat_template
  68. # Test assertions
  69. assert template_content is not None
  70. # Hard coded value for chatml_template.jinja
  71. assert template_content == """../../examples/does_not_exist"""
  72. @pytest.mark.asyncio
  73. @pytest.mark.parametrize(
  74. "model,template,add_generation_prompt,expected_output",
  75. MODEL_TEMPLATE_GENERATON_OUTPUT)
  76. async def test_get_gen_prompt(model, template, add_generation_prompt,
  77. expected_output):
  78. # Initialize the tokenizer
  79. tokenizer = get_tokenizer(tokenizer_name=model)
  80. mock_args = Namespace(chat_template=template)
  81. load_chat_template(mock_args, tokenizer)
  82. # Create a mock request object using keyword arguments
  83. mock_request = ChatCompletionRequest(
  84. model=model,
  85. messages=TEST_MESSAGES,
  86. add_generation_prompt=add_generation_prompt)
  87. # Call the function and get the result
  88. result = tokenizer.apply_chat_template(
  89. conversation=mock_request.messages,
  90. tokenize=False,
  91. add_generation_prompt=mock_request.add_generation_prompt)
  92. # Test assertion
  93. assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
  94. def test_health_endpoint():
  95. response = client.get("/health")
  96. assert response.status_code == 200