# test_chat_template.py
  1. import pytest
  2. from aphrodite.endpoints.chat_utils import apply_chat_template, load_chat_template
  3. from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
  4. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  5. from ..utils import APHRODITE_PATH
  6. chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
  7. assert chatml_jinja_path.exists()
  8. # Define models, templates, and their corresponding expected outputs
  9. MODEL_TEMPLATE_GENERATON_OUTPUT = [
  10. ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
  11. Hello<|im_end|>
  12. <|im_start|>assistant
  13. Hi there!<|im_end|>
  14. <|im_start|>user
  15. What is the capital of<|im_end|>
  16. <|im_start|>assistant
  17. """),
  18. ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
  19. Hello<|im_end|>
  20. <|im_start|>assistant
  21. Hi there!<|im_end|>
  22. <|im_start|>user
  23. What is the capital of""")
  24. ]
  25. TEST_MESSAGES = [
  26. {
  27. 'role': 'user',
  28. 'content': 'Hello'
  29. },
  30. {
  31. 'role': 'assistant',
  32. 'content': 'Hi there!'
  33. },
  34. {
  35. 'role': 'user',
  36. 'content': 'What is the capital of'
  37. },
  38. ]
  39. def test_load_chat_template():
  40. # Testing chatml template
  41. template_content = load_chat_template(chat_template=chatml_jinja_path)
  42. # Test assertions
  43. assert template_content is not None
  44. # Hard coded value for template_chatml.jinja
  45. assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
  46. {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
  47. def test_no_load_chat_template_filelike():
  48. # Testing chatml template
  49. template = "../../examples/does_not_exist"
  50. with pytest.raises(ValueError, match="looks like a file path"):
  51. load_chat_template(chat_template=template)
  52. def test_no_load_chat_template_literallike():
  53. # Testing chatml template
  54. template = "{{ messages }}"
  55. template_content = load_chat_template(chat_template=template)
  56. assert template_content == template
  57. @pytest.mark.parametrize(
  58. "model,template,add_generation_prompt,expected_output",
  59. MODEL_TEMPLATE_GENERATON_OUTPUT)
  60. def test_get_gen_prompt(model, template, add_generation_prompt,
  61. expected_output):
  62. # Initialize the tokenizer
  63. tokenizer = get_tokenizer(tokenizer_name=model)
  64. template_content = load_chat_template(chat_template=template)
  65. # Create a mock request object using keyword arguments
  66. mock_request = ChatCompletionRequest(
  67. model=model,
  68. messages=TEST_MESSAGES,
  69. add_generation_prompt=add_generation_prompt)
  70. # Call the function and get the result
  71. result = apply_chat_template(
  72. tokenizer,
  73. conversation=mock_request.messages,
  74. chat_template=mock_request.chat_template or template_content,
  75. add_generation_prompt=mock_request.add_generation_prompt,
  76. )
  77. # Test assertion
  78. assert result == expected_output, (
  79. f"The generated prompt does not match the expected output for "
  80. f"model {model} and template {template}")